MyBatis Plugins

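MyBatis plugins are essentially interceptors. The mechanism combines JDK dynamic proxies with the chain-of-responsibility pattern: each interceptor wraps the target object in a dynamic proxy, the proxies nest to form a chain, and every interceptor in that chain can run its own logic around the intercepted call.

The plugin mechanism gives MyBatis a powerful extension point. It exposes four interception points along the SQL execution path, each suited to different extension scenarios: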

  • Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)
  • ParameterHandler (getParameterObject, setParameters)
  • ResultSetHandler (handleResultSets, handleOutputParameters)
  • StatementHandler (prepare, parameterize, batch, update, query)

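The details of these methods can be found by looking at each method's signature or by reading the source code in the MyBatis distribution. If you intend to do more than monitor method calls, you should understand the behavior of the methods you override quite well, because modifying or overriding existing behavior can easily break the MyBatis core. These are low-level classes and methods, so use plugins with caution.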

How it works

The MyBatis proxy class Plugin implements the InvocationHandler interface.

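When a method is called on a proxied ParameterHandler, ResultSetHandler, StatementHandler, or Executor, Plugin's invoke method runs. Inside invoke, Plugin uses the @Intercepts configuration (method name, parameter types, and so on) to decide whether the method should be intercepted. If it should, Plugin wraps the target object, the Method, and the arguments in an Invocation and passes it to the interceptor's intercept method; the interceptor in turn calls the Invocation's proceed method to run the original method. Otherwise the method is invoked on the target directly. This is how the target method ends up being intercepted.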
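
The dispatch decision described above can be modeled with nothing but the JDK. The following is a minimal sketch, not the MyBatis source: DemoExecutor, SimpleDemoExecutor, DemoPlugin, and PluginDispatchDemo are hypothetical names, and the handler matches on method names only, whereas MyBatis matches on the full signature declared in @Signature.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for one of the four interceptable interfaces.
interface DemoExecutor {
  String query(String sql);
  void close();
}

class SimpleDemoExecutor implements DemoExecutor {
  @Override public String query(String sql) { return "rows of [" + sql + "]"; }
  @Override public void close() { }
}

// Simplified model of a Plugin-style InvocationHandler.
class DemoPlugin implements InvocationHandler {
  private final Object target;
  private final Set<String> interceptedMethods; // assumption: names instead of full signatures
  final List<String> log = new ArrayList<>();

  DemoPlugin(Object target, Set<String> interceptedMethods) {
    this.target = target;
    this.interceptedMethods = interceptedMethods;
  }

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    if (interceptedMethods.contains(method.getName())) {
      // Configured method: this is where the interceptor logic would run.
      log.add("intercepted " + method.getName());
    }
    // Configured or not, the call finally reaches the wrapped target.
    return method.invoke(target, args);
  }
}

class PluginDispatchDemo {
  // Create a JDK proxy for DemoExecutor whose calls all go through the plugin.
  static DemoExecutor wrap(DemoPlugin plugin) {
    return (DemoExecutor) Proxy.newProxyInstance(
        DemoExecutor.class.getClassLoader(), new Class<?>[] {DemoExecutor.class}, plugin);
  }

  public static void main(String[] args) {
    DemoPlugin plugin = new DemoPlugin(new SimpleDemoExecutor(), Collections.singleton("query"));
    DemoExecutor executor = wrap(plugin);
    executor.query("select 1"); // matches the configured method, so it is recorded
    executor.close();           // not configured, so it passes straight through
    System.out.println(plugin.log);
  }
}
```

Only the query call leaves a trace in the log; close still reaches the target but bypasses the interceptor logic.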

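Some caveats:

  • Do not define too many plugins: every plugin adds another proxy layer, and deeply nested proxies make each method call more expensive;
  • At the end of Interceptor#intercept in your implementation, remember to call invocation.proceed(); otherwise, when several interceptors are configured, the execution chain is cut off;
  • With several plugins, a plugin may need to unwrap the proxies to reach the real target object. The SystemMetaObject class can be used to follow the h and target properties down to the innermost layer.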
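
MetaObject metaObject = SystemMetaObject.forObject(target);
// Each proxy layer keeps its Plugin in field "h"; the Plugin keeps the wrapped object in "target"
while (metaObject.hasGetter("h")) {
  Object h = metaObject.getValue("h");
  Object inner = SystemMetaObject.forObject(h).getValue("target");
  metaObject = SystemMetaObject.forObject(inner);
}
// No proxy layer is left: this is the real target object
Object obj = metaObject.getOriginalObject();
MetaObject targetObj = metaObject;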
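
The same unwrapping idea can be sketched with public JDK API only, which may help to see why a loop is needed: every plugin adds one proxy layer. This is an illustrative model rather than MyBatis code. LayerHandler and ProxyUnwrapDemo are hypothetical names, and LayerHandler exposes its target field directly, whereas MyBatis's Plugin keeps it private (hence the reflection-based SystemMetaObject approach).

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

// Hypothetical handler that, like a plugin proxy, remembers the object it wraps.
class LayerHandler implements InvocationHandler {
  final Object target;

  LayerHandler(Object target) {
    this.target = target;
  }

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    return method.invoke(target, args);
  }
}

class ProxyUnwrapDemo {
  // Add one proxy layer around the given Runnable.
  static Runnable wrap(Runnable target) {
    return (Runnable) Proxy.newProxyInstance(
        ProxyUnwrapDemo.class.getClassLoader(),
        new Class<?>[] {Runnable.class},
        new LayerHandler(target));
  }

  // Peel off proxy layers one by one until a non-proxy object is reached.
  static Object unwrap(Object candidate) {
    Object current = candidate;
    while (Proxy.isProxyClass(current.getClass())
        && Proxy.getInvocationHandler(current) instanceof LayerHandler) {
      current = ((LayerHandler) Proxy.getInvocationHandler(current)).target;
    }
    return current;
  }

  public static void main(String[] args) {
    Runnable real = () -> { };
    Runnable twiceWrapped = wrap(wrap(real));
    System.out.println(unwrap(twiceWrapped) == real);
  }
}
```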
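
// Excerpt from org.apache.ibatis.plugin.Invocation (getters omitted)
public class Invocation {

  private final Object target;
  private final Method method;
  private final Object[] args;

  public Invocation(Object target, Method method, Object[] args) {
    this.target = target;
    this.method = method;
    this.args = args;
  }

  // Runs the intercepted method on the wrapped target
  public Object proceed() throws InvocationTargetException, IllegalAccessException {
    return method.invoke(target, args);
  }
}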

With the mechanism MyBatis provides, using a plugin is simple: implement the Interceptor interface and declare the method signatures you want to intercept.

public class Plugin implements InvocationHandler {

  private final Object target;
  private final Interceptor interceptor;
  private final Map<Class<?>, Set<Method>> signatureMap;

  private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
    this.target = target;
    this.interceptor = interceptor;
    this.signatureMap = signatureMap;
  }
}

public interface Interceptor {
  // Intercepts the execution of the target method
  Object intercept(Invocation invocation) throws Throwable;
  // Wraps the target object and returns a dynamic proxy for it
  default Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }
  // Receives the properties specified in the MyBatis configuration file
  default void setProperties(Properties properties) {
  }
}

public class InterceptorChain {

  private final List<Interceptor> interceptors = new ArrayList<>();

  public Object pluginAll(Object target) {
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }

  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }

  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }
}
public class Configuration {
  protected final InterceptorChain interceptorChain = new InterceptorChain();
  public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
    ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
    parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
    return parameterHandler;
  }

  public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
      ResultHandler resultHandler, BoundSql boundSql) {
    ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
    resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
    return resultSetHandler;
  }

  public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
    statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
    return statementHandler;
  }
  public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    executorType = executorType == null ? defaultExecutorType : executorType;
    executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
    Executor executor;
    if (ExecutorType.BATCH == executorType) {
      executor = new BatchExecutor(this, transaction);
    } else if (ExecutorType.REUSE == executorType) {
      executor = new ReuseExecutor(this, transaction);
    } else {
      executor = new SimpleExecutor(this, transaction);
    }
    if (cacheEnabled) {
      executor = new CachingExecutor(executor);
    }
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
  }
}
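
One consequence of the pluginAll loop is worth spelling out: because each interceptor wraps the result of the previous one, the interceptor registered last becomes the outermost proxy and is entered first. The sketch below models this with plain JDK proxies; ChainOrderDemo and its methods are hypothetical, and a Supplier stands in for the Executor.

```java
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;

class ChainOrderDemo {
  static final List<String> trace = new ArrayList<>();

  // Wrap the target in a proxy that records when the named interceptor is entered and left.
  @SuppressWarnings("unchecked")
  static Supplier<String> plugin(Supplier<String> target, String name) {
    return (Supplier<String>) Proxy.newProxyInstance(
        ChainOrderDemo.class.getClassLoader(),
        new Class<?>[] {Supplier.class},
        (proxy, method, args) -> {
          trace.add(name + " before");
          Object result = method.invoke(target, args); // plays the role of invocation.proceed()
          trace.add(name + " after");
          return result;
        });
  }

  // Same loop shape as InterceptorChain.pluginAll: each interceptor wraps the previous result.
  static Supplier<String> pluginAll(Supplier<String> target, List<String> interceptorNames) {
    for (String name : interceptorNames) {
      target = plugin(target, name);
    }
    return target;
  }

  public static void main(String[] args) {
    Supplier<String> executor = () -> "result";
    Supplier<String> proxied = pluginAll(executor, Arrays.asList("A", "B"));
    proxied.get();
    System.out.println(trace); // [B before, A before, A after, B after]
  }
}
```

If the order in which plugins see a call matters, for example a pagination plugin that must rewrite the SQL before a logging plugin prints it, the registration order has to be chosen with this nesting in mind.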

Custom plugins

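import java.util.ArrayList;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;

@Slf4j
@Component
@Intercepts({
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
    @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
public class ExecuteStatInterceptor implements Interceptor {
  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    MappedStatement statement = (MappedStatement) invocation.getArgs()[0];
    Object parameter = invocation.getArgs()[1];
    BoundSql sql = statement.getBoundSql(parameter);

    // Resolve each bound value the same way DefaultParameterHandler does
    List<Object> params = new ArrayList<>();
    for (ParameterMapping mapping : sql.getParameterMappings()) {
      String property = mapping.getProperty();
      if (sql.hasAdditionalParameter(property)) {
        params.add(sql.getAdditionalParameter(property)); // e.g. <foreach> items
      } else if (parameter == null) {
        params.add(null);
      } else if (statement.getConfiguration().getTypeHandlerRegistry().hasTypeHandler(parameter.getClass())) {
        params.add(parameter); // a single simple value such as a Long or a String
      } else {
        // MetaObject is the MyBatis helper for reading object properties by name
        params.add(statement.getConfiguration().newMetaObject(parameter).getValue(property));
      }
    }

    long start = System.currentTimeMillis();
    try {
      return invocation.proceed();
    } finally {
      log.info("sql: {}\nargs: {}\ncost: {} ms", sql.getSql(), params, System.currentTimeMillis() - start);
    }
  }
}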

The plugin is then registered in the MyBatis configuration file:

<!-- mybatis-config.xml -->
<plugins>
  <plugin interceptor="org.ynthm.demo.mybatis.ExecuteStatInterceptor">
    <property name="someProperty" value="100"/>
  </plugin>
</plugins>

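With Spring Boot (mybatis-spring-boot-starter), declaring the Interceptor implementation as a bean is enough.

It can also be registered explicitly when building the SqlSessionFactory: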

@Bean(name = "sqlSessionFactory")
public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
    SqlSessionFactoryBean bean = new SqlSessionFactoryBean();
    bean.setDataSource(dataSource);
    // Location of the MyBatis mapper XML files
    bean.setMapperLocations(
            new PathMatchingResourcePatternResolver().getResources("classpath*:mapper/*.xml"));
    // Register the type handler for global use
    bean.setTypeHandlers(new Timestamp2LongHandler());
    // Register the interceptor
    bean.setPlugins(new ExecuteStatInterceptor());
    return bean.getObject();
}

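Custom pagination plugin

  • SQL pagination with limit and offset
  • Pagination with a MyBatis interceptor
  • RowBounds pagination. For small data sets, RowBounds is a perfectly reasonable approach.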

@Override
@Transactional(isolation = Isolation.READ_COMMITTED, propagation = Propagation.SUPPORTS)
public List<RoleBean> queryRolesByPage(String roleName, int start, int limit) {
    return roleDao.queryRolesByPage(roleName, new RowBounds(start, limit));
}

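MyBatis paginates with the RowBounds object by working on the ResultSet in memory rather than in the database: the SQL runs without any limit clause, and MyBatis then skips the first offset rows of the result set and keeps the next limit rows. This is in-memory (logical) pagination, not physical pagination.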
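
To make the cost of in-memory pagination concrete, here is a minimal sketch of the skip-then-take logic over a generic row iterator. It is an illustration under simplifying assumptions, not the MyBatis implementation; InMemoryPagingDemo and page are hypothetical names, and an Iterator stands in for the JDBC ResultSet.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

class InMemoryPagingDemo {
  // Skip `offset` rows, then keep at most `limit` rows.
  // Every skipped row still has to be produced by the query and walked over.
  static <T> List<T> page(Iterator<T> rows, int offset, int limit) {
    for (int i = 0; i < offset && rows.hasNext(); i++) {
      rows.next();
    }
    List<T> result = new ArrayList<>();
    while (result.size() < limit && rows.hasNext()) {
      result.add(rows.next());
    }
    return result;
  }

  public static void main(String[] args) {
    List<Integer> allRows = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
    System.out.println(page(allRows.iterator(), 3, 4)); // [4, 5, 6, 7]
  }
}
```

The deeper the page, the more rows are read only to be thrown away, which is why this approach is best kept to small result sets.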

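Physical pagination can be achieved either by writing the paging parameters directly in the SQL or by using a pagination plugin.

A pagination plugin intercepts the SQL that is about to be executed inside its intercept method and rewrites it, appending the physical paging clause and parameters (for example limit) that match the database dialect. It relies on the same JDK dynamic proxy and chain-of-responsibility machinery as any other plugin.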
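
The dialect-specific rewrite can be sketched as a small pure function. The names below (PageSqlDemo, Dialect, toPageSql) are hypothetical and only two dialects are covered; real plugins such as PageHelper support many more and typically bind the paging values as parameters instead of concatenating them.

```java
class PageSqlDemo {
  // Hypothetical dialect switch; a real plugin would usually detect this from the JDBC URL.
  enum Dialect { MYSQL, POSTGRESQL }

  // Append the physical paging clause that matches the dialect.
  static String toPageSql(String sql, int offset, int limit, Dialect dialect) {
    switch (dialect) {
      case MYSQL:
        return sql + " LIMIT " + offset + ", " + limit;
      case POSTGRESQL:
        return sql + " LIMIT " + limit + " OFFSET " + offset;
      default:
        throw new IllegalArgumentException("Unsupported dialect: " + dialect);
    }
  }

  public static void main(String[] args) {
    String sql = "select * from role";
    System.out.println(toPageSql(sql, 20, 10, Dialect.MYSQL));
    System.out.println(toPageSql(sql, 20, 10, Dialect.POSTGRESQL));
  }
}
```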

MyBatis Pagination - PageHelper

// https://mvnrepository.com/artifact/com.github.pagehelper/pagehelper-spring-boot-starter
implementation 'com.github.pagehelper:pagehelper-spring-boot-starter:1.4.2'

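To use in-memory pagination, first add a RowBounds parameter to the Mapper method.

To build the plugin, we first define a custom RowBounds. The native MyBatis RowBounds paginates in memory and has no way to return the total record count, which paged queries usually need as well, so we define PageRowBounds to extend the native RowBounds.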

public class PageRowBounds extends RowBounds {
    private Long total;

    public PageRowBounds(int offset, int limit) {
        super(offset, limit);
    }

    public PageRowBounds() {
    }

    public Long getTotal() {
        return total;
    }

    public void setTotal(Long total) {
        this.total = total;
    }
}

The interceptor below intercepts Executor#query and, when a non-default RowBounds is passed, runs a count query (for PageRowBounds) followed by the original query with a limit clause appended:

@Intercepts(@Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
))
public class PageInterceptor implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        RowBounds rowBounds = (RowBounds) args[2];
        if (rowBounds != RowBounds.DEFAULT) {
            Executor executor = (Executor) invocation.getTarget();
            BoundSql boundSql = ms.getBoundSql(parameterObject);
            Field additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
            additionalParametersField.setAccessible(true);
            Map<String, Object> additionalParameters = (Map<String, Object>) additionalParametersField.get(boundSql);
            if (rowBounds instanceof PageRowBounds) {
                MappedStatement countMs = newMappedStatement(ms, Long.class);
                CacheKey countKey = executor.createCacheKey(countMs, parameterObject, RowBounds.DEFAULT, boundSql);
                String countSql = "select count(*) from (" + boundSql.getSql() + ") temp";
                BoundSql countBoundSql = new BoundSql(ms.getConfiguration(), countSql, boundSql.getParameterMappings(), parameterObject);
                Set<String> keySet = additionalParameters.keySet();
                for (String key : keySet) {
                    countBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
                }
                List<Object> countQueryResult = executor.query(countMs, parameterObject, RowBounds.DEFAULT, (ResultHandler) args[3], countKey, countBoundSql);
                Long count = (Long) countQueryResult.get(0);
                ((PageRowBounds) rowBounds).setTotal(count);
            }
            CacheKey pageKey = executor.createCacheKey(ms, parameterObject, rowBounds, boundSql);
            pageKey.update("RowBounds");
            String pageSql = boundSql.getSql() + " limit " + rowBounds.getOffset() + "," + rowBounds.getLimit();
            BoundSql pageBoundSql = new BoundSql(ms.getConfiguration(), pageSql, boundSql.getParameterMappings(), parameterObject);
            Set<String> keySet = additionalParameters.keySet();
            for (String key : keySet) {
                pageBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
            }
            List list = executor.query(ms, parameterObject, RowBounds.DEFAULT, (ResultHandler) args[3], pageKey, pageBoundSql);
            return list;
        }
        // No pagination requested: run the original query unchanged
        return invocation.proceed();
    }

    private MappedStatement newMappedStatement(MappedStatement ms, Class<Long> longClass) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(), ms.getId() + "_count", ms.getSqlSource(), ms.getSqlCommandType()
        );
        ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(), ms.getId(), longClass, new ArrayList<>(0)).build();
        builder.resource(ms.getResource())
                .fetchSize(ms.getFetchSize())
                .statementType(ms.getStatementType())
                .timeout(ms.getTimeout())
                .parameterMap(ms.getParameterMap())
                .resultSetType(ms.getResultSetType())
                .cache(ms.getCache())
                .flushCacheRequired(ms.isFlushCacheRequired())
                .useCache(ms.isUseCache())
                .resultMaps(Arrays.asList(resultMap));
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            StringBuilder keyProperties = new StringBuilder();
            for (String keyProperty : ms.getKeyProperties()) {
                keyProperties.append(keyProperty).append(",");
            }
            keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
            builder.keyProperty(keyProperties.toString());
        }
        return builder.build();
    }
}

As an optimization, the total count can be queried only when the first page is requested.

Appendix