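A MyBatis plugin is essentially an interceptor. It combines JDK dynamic proxies with the chain of responsibility pattern: each registered interceptor wraps the target object in another proxy, so a call passes through every interceptor in turn, and each one can run custom logic along the way.
The plugin mechanism is MyBatis's main extension point. Before a SQL statement is finally executed, it exposes four interception points that cover different extension scenarios: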
- Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)
- ParameterHandler (getParameterObject, setParameters)
- ResultSetHandler (handleResultSets, handleOutputParameters)
- StatementHandler (prepare, parameterize, batch, update, query)
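The details of these methods can be found in their signatures or in the source code shipped with the MyBatis distribution. If you want to do more than monitor method calls, you should understand the behavior of the method you are overriding well, because modifying or overriding existing behavior can easily break MyBatis's core modules. These are low-level classes and methods, so use plugins with care.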
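How it works
MyBatis's proxy class Plugin implements the InvocationHandler interface.
Every call on a proxied ParameterHandler, ResultSetHandler, StatementHandler, or Executor goes through Plugin's invoke method. There, Plugin uses the @Intercepts configuration (target type, method name, parameter types) to decide whether the method should be intercepted. If so, it wraps the target, the Method, and the arguments in an Invocation and calls the interceptor's intercept method; the interceptor in turn calls Invocation#proceed to continue to the target method. If not, the method is invoked on the target directly.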
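The dispatch described above can be sketched with nothing but the JDK. This is a simplified illustration, not MyBatis code: SqlHandler, MiniInterceptor, MiniPlugin, and PluginSketch are hypothetical names, and methods are matched by name only, whereas MyBatis matches the full signature declared in @Intercepts.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for one of the four interceptable MyBatis interfaces.
interface SqlHandler {
    String query(String sql);
    void close();
}

// Hypothetical, simplified counterpart of MyBatis's Interceptor.
interface MiniInterceptor {
    Object intercept(Object target, Method method, Object[] args) throws Throwable;
}

// Simplified counterpart of Plugin: routes only the configured methods
// through the interceptor and passes every other call straight to the target.
final class MiniPlugin implements InvocationHandler {
    private final Object target;
    private final MiniInterceptor interceptor;
    private final Set<String> interceptedMethods;

    private MiniPlugin(Object target, MiniInterceptor interceptor, Set<String> interceptedMethods) {
        this.target = target;
        this.interceptor = interceptor;
        this.interceptedMethods = interceptedMethods;
    }

    static SqlHandler wrap(SqlHandler target, MiniInterceptor interceptor, Set<String> methods) {
        return (SqlHandler) Proxy.newProxyInstance(
                SqlHandler.class.getClassLoader(),
                new Class<?>[] {SqlHandler.class},
                new MiniPlugin(target, interceptor, methods));
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (interceptedMethods.contains(method.getName())) {
            return interceptor.intercept(target, method, args);
        }
        return method.invoke(target, args);
    }
}

final class PluginSketch {
    // Records the order in which interceptors and the real object are reached.
    static final List<String> LOG = new ArrayList<>();

    static SqlHandler build() {
        SqlHandler real = new SqlHandler() {
            @Override public String query(String sql) { LOG.add("real"); return "rows for " + sql; }
            @Override public void close() { LOG.add("close"); }
        };
        // Calling method.invoke(target, args) plays the role of invocation.proceed().
        MiniInterceptor first = (target, method, args) -> { LOG.add("first"); return method.invoke(target, args); };
        MiniInterceptor second = (target, method, args) -> { LOG.add("second"); return method.invoke(target, args); };
        // Each interceptor wraps the previous result, so the one added last is outermost.
        return MiniPlugin.wrap(MiniPlugin.wrap(real, first, Set.of("query")), second, Set.of("query"));
    }

    public static void main(String[] args) {
        SqlHandler handler = build();
        System.out.println(handler.query("select 1"));
        handler.close();
        System.out.println(LOG);
    }
}
```

In this sketch a call to query passes through second, then first, then the real object, while close skips both interceptors because it is not in the intercepted set. If an interceptor returned without calling method.invoke, everything behind it would be skipped, which is the reason for the proceed() rule.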
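Some caveats:
- Do not define too many plugins. Every plugin adds another proxy layer, and deeply nested proxies make each method call more expensive.
- Do not forget to call invocation.proceed() in your Interceptor#intercept implementation (usually as its last step). Otherwise, when several interceptors are registered, the chain breaks: the inner interceptors and the target method never run.
- With multiple plugins, an interceptor may need to unwrap the proxies to reach the real target object. SystemMetaObject can be used to read the h and target properties layer by layer until the innermost object is reached.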
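```java
// target: the (possibly multiply proxied) object the interceptor received.
MetaObject metaObject = SystemMetaObject.forObject(target);
// A JDK proxy keeps its InvocationHandler (here a Plugin) in the field "h",
// and a Plugin keeps the object it wraps in the field "target".
while (metaObject.hasGetter("h")) {
    Object h = metaObject.getValue("h");
    Object wrapped = SystemMetaObject.forObject(h).getValue("target");
    metaObject = SystemMetaObject.forObject(wrapped);
}
// metaObject now wraps the innermost, unproxied object.
MetaObject targetObj = metaObject;
Object obj = targetObj.getOriginalObject();
```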
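The same unwrapping loop can be sketched with the JDK alone, using Proxy.isProxyClass and Proxy.getInvocationHandler instead of reading the h field reflectively. TargetAwareHandler and UnwrapSketch are hypothetical names, and the handler exposing its target through a getter is an assumption of this sketch.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

// Hypothetical handler that exposes the object it wraps through a getter.
final class TargetAwareHandler implements InvocationHandler {
    private final Object target;

    TargetAwareHandler(Object target) {
        this.target = target;
    }

    Object getTarget() {
        return target;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        return method.invoke(target, args);
    }
}

final class UnwrapSketch {
    // Follows proxy -> handler -> target until a non-proxy object is reached.
    static Object unwrap(Object candidate) {
        Object current = candidate;
        while (Proxy.isProxyClass(current.getClass())
                && Proxy.getInvocationHandler(current) instanceof TargetAwareHandler) {
            current = ((TargetAwareHandler) Proxy.getInvocationHandler(current)).getTarget();
        }
        return current;
    }

    // Adds one proxy layer around the given Runnable.
    static Runnable wrap(Runnable target) {
        return (Runnable) Proxy.newProxyInstance(
                UnwrapSketch.class.getClassLoader(),
                new Class<?>[] {Runnable.class},
                new TargetAwareHandler(target));
    }

    public static void main(String[] args) {
        Runnable real = () -> System.out.println("run");
        Runnable nested = wrap(wrap(real));
        System.out.println(unwrap(nested) == real);
    }
}
```

The loop has the same shape as the MetaObject version: one iteration per proxy layer, ending at the first object that is not a proxy.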
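```java
// Excerpt from org.apache.ibatis.plugin.Invocation (getters omitted).
public class Invocation {

    private final Object target;
    private final Method method;
    private final Object[] args;

    public Invocation(Object target, Method method, Object[] args) {
        this.target = target;
        this.method = method;
        this.args = args;
    }

    // Continues the chain by invoking the intercepted method on the target.
    public Object proceed() throws InvocationTargetException, IllegalAccessException {
        return method.invoke(target, args);
    }
}
```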
Thanks to this mechanism, using a plugin is straightforward: implement the Interceptor interface and declare the method signatures you want to intercept.
```java
// Excerpts from the MyBatis source; wrap(...) and invoke(...) of Plugin are omitted.
public class Plugin implements InvocationHandler {

    private final Object target;
    private final Interceptor interceptor;
    private final Map<Class<?>, Set<Method>> signatureMap;

    private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
        this.target = target;
        this.interceptor = interceptor;
        this.signatureMap = signatureMap;
    }
}

public interface Interceptor {

    // Intercepts the execution of the target method
    Object intercept(Invocation invocation) throws Throwable;

    // Wraps the target object in a dynamic proxy
    default Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    // Receives the properties declared for the plugin in the MyBatis configuration file
    default void setProperties(Properties properties) {
    }
}

public class InterceptorChain {

    private final List<Interceptor> interceptors = new ArrayList<>();

    // Lets every interceptor wrap the target in turn; the last one ends up outermost
    public Object pluginAll(Object target) {
        for (Interceptor interceptor : interceptors) {
            target = interceptor.plugin(target);
        }
        return target;
    }

    public void addInterceptor(Interceptor interceptor) {
        interceptors.add(interceptor);
    }

    public List<Interceptor> getInterceptors() {
        return Collections.unmodifiableList(interceptors);
    }
}

// Only the four factory methods that apply the interceptor chain are shown.
public class Configuration {

    protected final InterceptorChain interceptorChain = new InterceptorChain();

    public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
        ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
        parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
        return parameterHandler;
    }

    public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
            ResultHandler resultHandler, BoundSql boundSql) {
        ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
        resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
        return resultSetHandler;
    }

    public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
        statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
        return statementHandler;
    }

    public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
        executorType = executorType == null ? defaultExecutorType : executorType;
        executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
        Executor executor;
        if (ExecutorType.BATCH == executorType) {
            executor = new BatchExecutor(this, transaction);
        } else if (ExecutorType.REUSE == executorType) {
            executor = new ReuseExecutor(this, transaction);
        } else {
            executor = new SimpleExecutor(this, transaction);
        }
        if (cacheEnabled) {
            executor = new CachingExecutor(executor);
        }
        executor = (Executor) interceptorChain.pluginAll(executor);
        return executor;
    }
}
```
Custom plugin
```java
@Slf4j
@Component
@Intercepts({
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
    @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
public class ExecuteStatInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement statement = (MappedStatement) invocation.getArgs()[0];
        BoundSql sql = statement.getBoundSql(invocation.getArgs()[1]);
        List<Object> params = resolveParams(statement.getConfiguration(), sql);
        long start = System.currentTimeMillis();
        try {
            return invocation.proceed();
        } finally {
            log.info("sql: {} | args: {} | cost: {} ms", sql.getSql(), params, System.currentTimeMillis() - start);
        }
    }

    // Resolves the value of every bound parameter. MetaObject is the MyBatis helper for
    // reading object properties; simple-type and null parameters are used as they are.
    private List<Object> resolveParams(Configuration configuration, BoundSql sql) {
        Object parameter = sql.getParameterObject();
        List<Object> params = new ArrayList<>();
        for (ParameterMapping mapping : sql.getParameterMappings()) {
            String property = mapping.getProperty();
            if (sql.hasAdditionalParameter(property)) {
                params.add(sql.getAdditionalParameter(property));
            } else if (parameter == null || configuration.getTypeHandlerRegistry().hasTypeHandler(parameter.getClass())) {
                params.add(parameter);
            } else {
                params.add(configuration.newMetaObject(parameter).getValue(property));
            }
        }
        return params;
    }
}
```
```xml
<!-- mybatis-config.xml -->
<plugins>
    <plugin interceptor="org.ynthm.demo.mybatis.ExecuteStatInterceptor">
        <property name="someProperty" value="100"/>
    </plugin>
</plugins>
```
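With Spring Boot, it is enough to define the Interceptor implementation as a bean: the auto-configuration of the MyBatis Spring Boot starter registers the Interceptor beans it finds.
If you build the SqlSessionFactory yourself, register the plugin on the SqlSessionFactoryBean: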
```java
@Bean(name = "sqlSessionFactory")
public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
    SqlSessionFactoryBean bean = new SqlSessionFactoryBean();
    bean.setDataSource(dataSource);
    // Location of the mapper XML files (this project uses annotation-based mappers,
    // so there are no XML files to match)
    bean.setMapperLocations(
        new PathMatchingResourcePatternResolver().getResources("classpath*:mapper/*.xml"));
    // Register the type handler for global use
    bean.setTypeHandlers(new Timestamp2LongHandler());
    // Register the interceptor as a MyBatis plugin
    bean.setPlugins(new ExecuteStatInterceptor());
    return bean.getObject();
}
```
Custom pagination plugin
- SQL pagination with limit and offset
- MyBatis interceptor pagination
- RowBounds pagination: a reasonable choice when the data volume is small.
```java
@Override
@Transactional(isolation = Isolation.READ_COMMITTED, propagation = Propagation.SUPPORTS)
public List<RoleBean> queryRolesByPage(String roleName, int start, int limit) {
    return roleDao.queryRolesByPage(roleName, new RowBounds(start, limit));
}
```
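MyBatis paginates with the RowBounds object in memory, against the ResultSet, rather than physically in the database: the SQL is sent unchanged, so the database still runs the full query, and MyBatis then skips the first offset rows and keeps at most limit rows.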
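To make the cost of in-memory pagination concrete, here is a rough sketch of the skip-then-take step over a cursor. MemoryPagingSketch and its rowsRead counter are hypothetical, and an Iterator stands in for the JDBC ResultSet; MyBatis's own result set handling is more involved.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.IntStream;

final class MemoryPagingSketch {
    // Number of rows pulled from the cursor by the last call; for illustration only.
    static int rowsRead;

    // Skips the first `offset` rows, then collects at most `limit` rows.
    static <T> List<T> page(Iterator<T> cursor, int offset, int limit) {
        rowsRead = 0;
        for (int i = 0; i < offset && cursor.hasNext(); i++) {
            cursor.next();
            rowsRead++;
        }
        List<T> page = new ArrayList<>();
        while (page.size() < limit && cursor.hasNext()) {
            page.add(cursor.next());
            rowsRead++;
        }
        return page;
    }

    public static void main(String[] args) {
        // Stand-in for the result of an unrestricted query returning 1000 rows.
        Iterator<Integer> cursor = IntStream.rangeClosed(1, 1000).iterator();
        System.out.println(page(cursor, 990, 5));
        System.out.println("rows read: " + rowsRead);
    }
}
```

Every row before the requested page still has to be produced by the database and walked past by the client, so the work grows with the offset. That is acceptable for small tables and wasteful for large ones.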
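For physical pagination, either write the paging clause and its parameters directly in the SQL, or use a pagination plugin.
Inside its intercept method, such a plugin captures the SQL about to be executed and rewrites it, appending the physical paging clause and parameters appropriate for the database dialect (for example, a limit clause). It relies on the same JDK dynamic proxy and chain of responsibility mechanism described earlier.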
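The rewriting step on its own might look like the sketch below. PagingSqlSketch, its Dialect enum, and the method names are hypothetical; only the MySQL and PostgreSQL paging syntaxes are covered, and a real plugin would also have to deal with parameter binding and more complex statements.

```java
final class PagingSqlSketch {
    // Hypothetical dialect switch; a real plugin would support more databases.
    enum Dialect { MYSQL, POSTGRESQL }

    // Appends the dialect-specific paging clause to the original statement.
    static String toPageSql(String sql, Dialect dialect, int offset, int limit) {
        if (offset < 0 || limit < 0) {
            throw new IllegalArgumentException("offset and limit must not be negative");
        }
        String base = stripTrailingSemicolon(sql);
        switch (dialect) {
            case MYSQL:
                return base + " LIMIT " + offset + ", " + limit;
            case POSTGRESQL:
                return base + " LIMIT " + limit + " OFFSET " + offset;
            default:
                throw new IllegalArgumentException("unsupported dialect: " + dialect);
        }
    }

    // Wraps the original statement in a count query to obtain the total number of rows.
    static String toCountSql(String sql) {
        return "SELECT COUNT(*) FROM (" + stripTrailingSemicolon(sql) + ") tmp_count";
    }

    private static String stripTrailingSemicolon(String sql) {
        String trimmed = sql.trim();
        while (trimmed.endsWith(";")) {
            trimmed = trimmed.substring(0, trimmed.length() - 1).trim();
        }
        return trimmed;
    }

    public static void main(String[] args) {
        String sql = "select * from role where name like ?;";
        System.out.println(toPageSql(sql, Dialect.MYSQL, 20, 10));
        System.out.println(toPageSql(sql, Dialect.POSTGRESQL, 20, 10));
        System.out.println(toCountSql(sql));
    }
}
```

Because offset and limit are plain int values, concatenating them into the SQL text is safe here; any user-supplied string would have to go through bound parameters instead.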
MyBatis Pagination - PageHelper
```groovy
// https://mvnrepository.com/artifact/com.github.pagehelper/pagehelper-spring-boot-starter
implementation 'com.github.pagehelper:pagehelper-spring-boot-starter:1.4.2'
```
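To use in-memory pagination, add a RowBounds parameter to the Mapper method.
For the custom plugin, we first define our own RowBounds. MyBatis's native RowBounds paginates in memory and cannot report the total number of records, which paged queries usually need, so we define PageRowBounds to extend the native RowBounds.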
```java
public class PageRowBounds extends RowBounds {

    private Long total;

    public PageRowBounds(int offset, int limit) {
        super(offset, limit);
    }

    public PageRowBounds() {
    }

    public Long getTotal() {
        return total;
    }

    public void setTotal(Long total) {
        this.total = total;
    }
}
```
```java
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

@Intercepts(@Signature(
    type = Executor.class,
    method = "query",
    args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
))
public class PageInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        RowBounds rowBounds = (RowBounds) args[2];
        ResultHandler resultHandler = (ResultHandler) args[3];

        if (rowBounds == RowBounds.DEFAULT) {
            // No pagination requested: continue down the chain unchanged
            return invocation.proceed();
        }

        Executor executor = (Executor) invocation.getTarget();
        BoundSql boundSql = ms.getBoundSql(parameterObject);
        Map<String, Object> additionalParameters = readAdditionalParameters(boundSql);

        if (rowBounds instanceof PageRowBounds) {
            // Run a count query over the original statement to obtain the total
            MappedStatement countMs = newMappedStatement(ms, Long.class);
            CacheKey countKey = executor.createCacheKey(countMs, parameterObject, RowBounds.DEFAULT, boundSql);
            String countSql = "select count(*) from (" + boundSql.getSql() + ") temp";
            BoundSql countBoundSql = copyBoundSql(ms, boundSql, countSql, parameterObject, additionalParameters);
            List<Object> countQueryResult = executor.query(countMs, parameterObject, RowBounds.DEFAULT, resultHandler, countKey, countBoundSql);
            Long count = (Long) countQueryResult.get(0);
            ((PageRowBounds) rowBounds).setTotal(count);
        }

        // Physical pagination: append a MySQL-style limit clause and pass RowBounds.DEFAULT
        // so that MyBatis does not paginate a second time in memory
        CacheKey pageKey = executor.createCacheKey(ms, parameterObject, rowBounds, boundSql);
        pageKey.update("RowBounds");
        String pageSql = boundSql.getSql() + " limit " + rowBounds.getOffset() + "," + rowBounds.getLimit();
        BoundSql pageBoundSql = copyBoundSql(ms, boundSql, pageSql, parameterObject, additionalParameters);
        return executor.query(ms, parameterObject, RowBounds.DEFAULT, resultHandler, pageKey, pageBoundSql);
    }

    // BoundSql keeps the parameters generated by dynamic SQL (e.g. foreach) in a private field
    @SuppressWarnings("unchecked")
    private Map<String, Object> readAdditionalParameters(BoundSql boundSql) throws ReflectiveOperationException {
        Field field = BoundSql.class.getDeclaredField("additionalParameters");
        field.setAccessible(true);
        return (Map<String, Object>) field.get(boundSql);
    }

    // Builds a BoundSql for the rewritten SQL that reuses the original parameter mappings
    private BoundSql copyBoundSql(MappedStatement ms, BoundSql source, String sql, Object parameterObject,
            Map<String, Object> additionalParameters) {
        BoundSql copy = new BoundSql(ms.getConfiguration(), sql, source.getParameterMappings(), parameterObject);
        additionalParameters.forEach(copy::setAdditionalParameter);
        return copy;
    }

    // Clones the statement under the id "<id>_count" with a single-column result of the given type
    private MappedStatement newMappedStatement(MappedStatement ms, Class<?> resultType) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
            ms.getConfiguration(), ms.getId() + "_count", ms.getSqlSource(), ms.getSqlCommandType());
        ResultMap resultMap = new ResultMap.Builder(
            ms.getConfiguration(), ms.getId(), resultType, new ArrayList<>(0)).build();
        builder.resource(ms.getResource())
            .fetchSize(ms.getFetchSize())
            .statementType(ms.getStatementType())
            .timeout(ms.getTimeout())
            .parameterMap(ms.getParameterMap())
            .resultSetType(ms.getResultSetType())
            .cache(ms.getCache())
            .flushCacheRequired(ms.isFlushCacheRequired())
            .useCache(ms.isUseCache())
            .resultMaps(Collections.singletonList(resultMap));
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        return builder.build();
    }
}
```
As an optimization, the total count can be queried only when the first page is requested.
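One way to express that rule is sketched below, assuming the caller sends back the total it received with the first page. FirstPageCountSketch and resolveTotal are hypothetical names, and the LongSupplier stands in for the count query.

```java
import java.util.function.LongSupplier;

final class FirstPageCountSketch {
    // Runs the count query on the first page, or when no total is known yet;
    // otherwise reuses the total supplied by the caller.
    static long resolveTotal(int offset, Long knownTotal, LongSupplier countQuery) {
        boolean firstPage = offset == 0;
        if (firstPage || knownTotal == null) {
            return countQuery.getAsLong();
        }
        return knownTotal;
    }

    public static void main(String[] args) {
        LongSupplier countQuery = () -> {
            System.out.println("running count query");
            return 42L;
        };
        long total = resolveTotal(0, null, countQuery);                // first page: counts
        System.out.println(resolveTotal(10, total, countQuery));       // later page: reuses the total
    }
}
```

The trade-off is that the total can go stale while the user pages through the results, which is usually acceptable for list views.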
Appendix