说明
MyBatis 允许在映射语句执行过程中的某些点拦截方法调用。默认情况下,MyBatis 允许插件拦截以下方法调用:
拦截器 | 方法 | 说明 |
---|---|---|
Executor | update, query, flushStatements, commit, rollback, getTransaction, close, isClosed | 执行 SQL 的核心对象(增删改查、事务控制等),可以在这里打印 SQL 语句,监控 SQL 执行时间等 |
StatementHandler | prepare, parameterize, batch, update, query | 封装了 JDBC Statement 操作,是 SQL 语法的构建器,可以在这里对 SQL 语句进行修改 |
ParameterHandler | getParameterObject, setParameters | 处理 SQL 入参绑定 |
ResultSetHandler | handleResultSets, handleOutputParameters | 处理查询结果 |
注:拦截需求并不一定要跟拦截器类型严格对应。例如,一般在 `StatementHandler` 中修改 SQL,但也可以在 `ParameterHandler` 中修改 SQL,只要能在执行前获取并修改 SQL 即可。
代码
NotInjectSQL注解:
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a mapper method whose SQL must be executed exactly as written.
 *
 * <p>{@code MybatisStatementHandlerInterceptor} looks this annotation up reflectively on the
 * mapper method and, when present, skips its SQL rewriting. It therefore has to be retained at
 * runtime and may only be placed on methods.
 *
 * @author java_t_t
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface NotInjectSQL {
}
Desensitization注解:
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks an entity field whose value must be masked after a query.
 *
 * <p>{@code MybatisResultSetHandlerInterceptor} reads this annotation reflectively from the
 * fields of each mapped result object, so it is retained at runtime and applies to fields only.
 *
 * @author java_t_t
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface Desensitization {
}
Executor拦截器示例:
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.format.DateTimeFormatter;
import java.util.Date;
import java.util.List;
/**
 * Executor plugin that logs every statement with its parameter values inlined, plus the time it took.
 *
 * <p>Intercepts {@code Executor.update} (INSERT/UPDATE/DELETE) and the 4-argument
 * {@code Executor.query}. The "Real SQL" line is for reading only: values are rendered in a
 * MySQL-flavoured syntax (double-quoted strings, {@code STR_TO_DATE(...)} for temporal values), and
 * placeholders are located by a plain split on '?', so a literal '?' inside the SQL text shifts
 * the substitution.
 */
@Slf4j
@Intercepts({
        @Signature(
                type = Executor.class, // intercepted type
                method = "update", // intercepted method; insert, update and delete all go through update
                args = {MappedStatement.class, Object.class} // method parameter types
        ),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class})
})
public class MybatisExecutorInterceptor implements Interceptor {
    private static final DateTimeFormatter LOCAL_DATE_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
    private static final DateTimeFormatter LOCAL_DATE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd");
    private static final DateTimeFormatter LOCAL_TIME_FORMAT = DateTimeFormatter.ofPattern("HH:mm:ss");
    /**
     * Pattern for {@link Date} values. SimpleDateFormat is not thread-safe and this single plugin
     * instance may serve statements on several threads at once, so a fresh formatter is created
     * per value instead of sharing a static one.
     */
    private static final String DATE_PATTERN = "yyyy-MM-dd HH:mm:ss";

    /**
     * Logs the statement, runs it, then logs the elapsed time (also when the statement fails).
     *
     * @param invocation the intercepted Executor call
     * @return whatever the intercepted call returns
     * @throws Throwable whatever the intercepted call throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        log.info("=====> Executor");
        try {
            printSql(invocation);
        } catch (RuntimeException e) {
            // SQL logging is diagnostic only; it must never prevent the statement from running
            log.warn("===> Executor ===> failed to print SQL", e);
        }
        long startNanos = System.nanoTime();
        try {
            return invocation.proceed();
        } finally {
            log.info("===> Executor ===> SQL expend {} ms", (System.nanoTime() - startNanos) / 1_000_000);
        }
    }

    /**
     * Logs the original SQL, the SQL with every '?' replaced by its value, and the parameter list.
     */
    private void printSql(Invocation invocation) {
        Object[] queryArgs = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) queryArgs[0];
        // BoundSql holds the SQL with '?' placeholders plus the ordered parameter mappings
        BoundSql boundSql = mappedStatement.getBoundSql(queryArgs[1]);
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Configuration configuration = mappedStatement.getConfiguration();
        // collapse whitespace runs so the statement fits on one log line
        String originalSql = boundSql.getSql().replaceAll("[\\s]+", " ");
        // limit -1 keeps trailing empty segments: segments.length == number of '?' + 1
        String[] segments = originalSql.split("\\?", -1);
        StringBuilder sql = new StringBuilder(segments[0]);
        StringBuilder params = new StringBuilder();
        int nextSegment = 1;
        if (parameterMappings != null && !parameterMappings.isEmpty()) {
            // A parameter object that has its own TypeHandler (String, Integer, ...) is the value of
            // every placeholder, exactly as MyBatis binds it; otherwise values are its properties.
            boolean singleValue = parameterObject != null
                    && configuration.getTypeHandlerRegistry().hasTypeHandler(parameterObject.getClass());
            MetaObject metaObject = parameterObject == null || singleValue
                    ? null
                    : configuration.newMetaObject(parameterObject);
            for (ParameterMapping parameterMapping : parameterMappings) {
                String propertyName = parameterMapping.getProperty();
                Object value = resolveValue(propertyName, boundSql, parameterObject, singleValue, metaObject);
                sql.append(toSqlLiteral(value));
                if (nextSegment < segments.length) {
                    sql.append(segments[nextSegment++]);
                }
                if (value == null) {
                    params.append(propertyName).append(" = null, ");
                } else {
                    params.append(propertyName).append('=').append(value)
                            .append('(').append(value.getClass().getSimpleName()).append("), ");
                }
            }
        }
        // keep placeholders without a mapping (e.g. a literal '?') and all SQL text after them
        while (nextSegment < segments.length) {
            sql.append('?').append(segments[nextSegment++]);
        }
        if (params.length() >= 2) {
            // drop the trailing ", "
            params.setLength(params.length() - 2);
        }
        log.info("===> Executor ===> Original SQL: {}", originalSql);
        log.info("===> Executor ===> Real SQL: {}", sql);
        log.info("===> Executor ===> Real PARAM: {}", params);
    }

    /**
     * Resolves the value bound to one placeholder, in the same order as MyBatis'
     * DefaultParameterHandler: additional (foreach/bind) parameters, a null parameter object,
     * a single-value parameter object, then a property of the parameter object.
     *
     * @return the value, or null when it is null or cannot be resolved
     */
    private static Object resolveValue(String propertyName, BoundSql boundSql, Object parameterObject,
                                       boolean singleValue, MetaObject metaObject) {
        if (boundSql.hasAdditionalParameter(propertyName)) {
            return boundSql.getAdditionalParameter(propertyName);
        }
        if (parameterObject == null) {
            return null;
        }
        if (singleValue) {
            return parameterObject;
        }
        return metaObject.hasGetter(propertyName) ? metaObject.getValue(propertyName) : null;
    }

    /**
     * Renders a parameter value as a (MySQL-style) SQL literal for logging.
     */
    private static String toSqlLiteral(Object value) {
        if (value == null) {
            return "null";
        }
        if (value instanceof String) {
            return "\"" + value + "\"";
        }
        if (value instanceof LocalDateTime) {
            return String.format("STR_TO_DATE(\"%s\",\"%%Y-%%m-%%d %%H:%%i:%%s\")",
                    LOCAL_DATE_TIME_FORMAT.format((LocalDateTime) value));
        }
        if (value instanceof LocalDate) {
            return String.format("STR_TO_DATE(\"%s\",\"%%Y-%%m-%%d\")",
                    LOCAL_DATE_FORMAT.format((LocalDate) value));
        }
        if (value instanceof LocalTime) {
            return String.format("STR_TO_DATE(\"%s\",\"%%H:%%i:%%s\")",
                    LOCAL_TIME_FORMAT.format((LocalTime) value));
        }
        if (value instanceof Date) {
            // instanceof also covers java.sql.Date/Time/Timestamp, which an exact class-name match missed
            return String.format("STR_TO_DATE(\"%s\",\"%%Y-%%m-%%d %%H:%%i:%%s\")",
                    new SimpleDateFormat(DATE_PATTERN).format((Date) value));
        }
        return value.toString();
    }
}
StatementHandler拦截器示例:
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.JdbcParameter;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.Values;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.update.UpdateSet;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import java.io.StringReader;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
/**
 * StatementHandler plugin that rewrites SQL just before it is prepared, adding soft-delete and
 * audit-time handling:
 * <ul>
 *   <li>INSERT ... VALUES: appends UPDATE_TIME / CREATE_TIME = CURRENT_TIMESTAMP() and IS_DELETE = 0
 *       for whichever of those columns the statement does not already list;</li>
 *   <li>UPDATE: adds {@code IS_DELETE = 0} to the WHERE clause and sets UPDATE_TIME unless it is set;</li>
 *   <li>DELETE: adds {@code IS_DELETE = 0} to the WHERE clause;</li>
 *   <li>SELECT: adds {@code IS_DELETE = 0} to join-free selects that read directly from a table,
 *       including each branch of a set operation such as UNION.</li>
 * </ul>
 * A WHERE clause whose text already mentions IS_DELETE is left alone. Statements of mapper
 * methods annotated with {@code @NotInjectSQL}, and SQL that JSqlParser cannot parse, run unchanged.
 */
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MybatisStatementHandlerInterceptor implements Interceptor {
    private static final String IS_DELETE = "IS_DELETE";
    private static final String UPDATE_TIME = "UPDATE_TIME";
    private static final String CREATE_TIME = "CREATE_TIME";
    /** Path, relative to the intercepted StatementHandler, of the SQL text about to be prepared. */
    private static final String SQL_PATH = "delegate.boundSql.sql";
    /** Printed verbatim as CURRENT_TIMESTAMP(); a Column is used only because it renders its name unchanged. */
    private static final Column INJECT_CURRENT_TIMESTAMP = new Column("CURRENT_TIMESTAMP()");
    /** The soft-delete filter {@code IS_DELETE = 0}; its two sides double as the INSERT column and value. */
    private static final EqualsTo INJECT_IS_DELETE = new EqualsTo()
            .withLeftExpression(new Column(IS_DELETE))
            .withRightExpression(new LongValue(0));
    private static final CCJSqlParserManager SQL_PARSER = new CCJSqlParserManager();

    /**
     * Logs the SQL, rewrites it in place, logs the result and continues with {@code prepare}.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        log.info("=====> StatementHandler");
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        log.info("===> StatementHandler ===> Original Sql:{}", compactSql(metaObject));
        injectSql(metaObject);
        log.info("===> StatementHandler ===> Modified Sql:{}", compactSql(metaObject));
        // Other values reachable through metaObject can be changed here as well, for example:
        //   metaObject.setValue("delegate.parameterHandler.parameterObject.pageSize", 5);
        //   metaObject.setValue("delegate.parameterHandler.parameterObject.addValue", "add_value");
        //   // Disable in-memory paging. In-memory (logical) paging loads the rows into memory and then
        //   // slices them; physical paging adds "limit offset, size" to the SQL.
        //   metaObject.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
        //   metaObject.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
        return invocation.proceed();
    }

    /** Current SQL with every whitespace run collapsed to a single space, for logging. */
    private static String compactSql(MetaObject metaObject) {
        return metaObject.getValue(SQL_PATH).toString().replaceAll("[\\s]+", " ");
    }

    /**
     * Parses the current SQL and dispatches to the matching rewrite; unparseable SQL is kept as is.
     */
    private void injectSql(MetaObject metaObject) {
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (hasAnnotation(mappedStatement)) {
            return;
        }
        // mappedStatement.getSqlCommandType() reflects the XML tag, which may not match the real
        // statement, so dispatch on the parsed result instead
        String originalSql = (String) metaObject.getValue(SQL_PATH);
        try (StringReader sr = new StringReader(originalSql)) {
            Statement parse = SQL_PARSER.parse(sr);
            if (parse instanceof Insert) {
                injectInsert((Insert) parse, metaObject);
            } else if (parse instanceof Update) {
                injectUpdate((Update) parse, metaObject);
            } else if (parse instanceof Delete) {
                injectDelete((Delete) parse, metaObject);
            } else if (parse instanceof Select) {
                injectSelect((Select) parse, metaObject);
            }
        } catch (JSQLParserException e) {
            log.warn("===> StatementHandler ===> SQL not rewritten, parse failed: {}", originalSql, e);
        }
    }

    /**
     * Adds the missing UPDATE_TIME / CREATE_TIME / IS_DELETE columns (and values) to an INSERT ... VALUES.
     */
    private void injectInsert(Insert insert, MetaObject metaObject) {
        ExpressionList<Column> columns = insert.getColumns();
        // only "INSERT INTO t (cols) VALUES (...)" is rewritten; INSERT ... SELECT and other forms are left alone
        if (columns == null || columns.isEmpty() || !(insert.getSelect() instanceof Values)) {
            return;
        }
        List<?> valueExpressions = ((Values) insert.getSelect()).getExpressions();
        if (valueExpressions == null || valueExpressions.isEmpty()) {
            return;
        }
        boolean noUpdateTime = true;
        boolean noCreateTime = true;
        boolean noIsDelete = true;
        for (Column column : columns) {
            String columnName = normalizeColumnName(column);
            if (UPDATE_TIME.equals(columnName)) {
                noUpdateTime = false;
            }
            if (CREATE_TIME.equals(columnName)) {
                noCreateTime = false;
            }
            if (IS_DELETE.equals(columnName)) {
                noIsDelete = false;
            }
        }
        if (noUpdateTime) {
            columns.add(new Column(UPDATE_TIME));
            addValues(valueExpressions, INJECT_CURRENT_TIMESTAMP);
        }
        if (noCreateTime) {
            columns.add(new Column(CREATE_TIME));
            addValues(valueExpressions, INJECT_CURRENT_TIMESTAMP);
        }
        if (noIsDelete) {
            columns.add((Column) INJECT_IS_DELETE.getLeftExpression());
            addValues(valueExpressions, INJECT_IS_DELETE.getRightExpression());
        }
        metaObject.setValue(SQL_PATH, insert.toString());
    }

    /**
     * Appends {@code expression} to every VALUES row, regardless of what the row contains.
     * When every element is itself a list, each element is a row (multi-row VALUES); otherwise
     * the list is the single row.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void addValues(List expressions, Expression expression) {
        if (expressions == null || expressions.isEmpty()) {
            return;
        }
        boolean listOfRows = true;
        for (Object element : expressions) {
            if (!(element instanceof List)) {
                listOfRows = false;
                break;
            }
        }
        if (listOfRows) {
            for (Object row : expressions) {
                ((List) row).add(expression);
            }
        } else {
            expressions.add(expression);
        }
    }

    /**
     * Adds the IS_DELETE filter to an UPDATE and sets UPDATE_TIME when the statement does not set it.
     */
    private void injectUpdate(Update update, MetaObject metaObject) {
        update.setWhere(withIsDeleteCondition(update.getWhere()));
        List<UpdateSet> updateSets = update.getUpdateSets();
        if (updateSets != null && !setsColumn(updateSets, UPDATE_TIME)) {
            UpdateSet updateSet = new UpdateSet();
            updateSet.setColumns(new ExpressionList<>().withExpressions(new Column(UPDATE_TIME)));
            updateSet.setValues(new ExpressionList().withExpressions(INJECT_CURRENT_TIMESTAMP));
            updateSets.add(updateSet);
        }
        metaObject.setValue(SQL_PATH, update.toString());
    }

    /** True when one of the SET clauses assigns {@code columnName} (backticks and case ignored). */
    private static boolean setsColumn(List<UpdateSet> updateSets, String columnName) {
        for (UpdateSet updateSet : updateSets) {
            for (Column column : updateSet.getColumns()) {
                if (columnName.equals(normalizeColumnName(column))) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Adds the IS_DELETE filter to a DELETE. */
    private void injectDelete(Delete delete, MetaObject metaObject) {
        delete.setWhere(withIsDeleteCondition(delete.getWhere()));
        metaObject.setValue(SQL_PATH, delete.toString());
    }

    /** Adds the IS_DELETE filter to a SELECT (see {@link #addWhereToSelect(Select)}). */
    private void injectSelect(Select select, MetaObject metaObject) {
        addWhereToSelect(select);
        metaObject.setValue(SQL_PATH, select.toString());
    }

    /**
     * Filters a plain select, or every branch of a set operation (UNION etc.); other select
     * forms (VALUES, parenthesized selects, ...) are left unchanged.
     */
    private void addWhereToSelect(Select select) {
        if (select instanceof PlainSelect) {
            addWhere((PlainSelect) select);
        } else if (select instanceof SetOperationList) {
            for (Select branch : ((SetOperationList) select).getSelects()) {
                addWhereToSelect(branch);
            }
        }
    }

    /** Adds the IS_DELETE filter to a join-free select that reads directly from a table. */
    private void addWhere(PlainSelect plainSelect) {
        // multi-table selects (with JOIN) are not filtered automatically
        if (plainSelect.getJoins() != null && !plainSelect.getJoins().isEmpty()) {
            return;
        }
        // "SELECT NOW()" (e.g. a <selectKey>) or a derived table such as
        // "SELECT COUNT(*) FROM (...) tmp" has no IS_DELETE column to filter on
        if (!(plainSelect.getFromItem() instanceof Table)) {
            return;
        }
        plainSelect.setWhere(withIsDeleteCondition(plainSelect.getWhere()));
    }

    /**
     * Returns the WHERE condition extended with {@code IS_DELETE = 0}.
     *
     * @param where current condition, may be null
     * @return {@code IS_DELETE = 0} when there is no condition; {@code where} unchanged when its text
     *         already mentions IS_DELETE; otherwise {@code (where) AND IS_DELETE = 0}. The parentheses
     *         matter: without them "a = 1 OR b = 2 AND IS_DELETE = 0" binds the filter to b only.
     */
    private static Expression withIsDeleteCondition(Expression where) {
        if (where == null) {
            return INJECT_IS_DELETE;
        }
        if (where.toString().toUpperCase(Locale.ROOT).contains(IS_DELETE)) {
            return where;
        }
        Expression left = where;
        if (!(where instanceof ParenthesedExpressionList)) {
            left = new ParenthesedExpressionList<Expression>(where);
        }
        return BinaryExpression.and(left, INJECT_IS_DELETE);
    }

    /** Column name without backticks or surrounding blanks, upper-cased for comparison. */
    private static String normalizeColumnName(Column column) {
        return column.getColumnName().replace("`", "").trim().toUpperCase(Locale.ROOT);
    }

    /**
     * Tells whether the mapper method behind {@code mappedStatement} carries {@code @NotInjectSQL}.
     * Relies on the statement id being {@code ${package}.${mapper}.${method}}; the mapper type is loaded
     * reflectively. Ids that do not name a loadable class (e.g. XML-only namespaces) count as not annotated.
     *
     * @param mappedStatement mappedStatement
     * @return true if a public method of the mapper with that name has {@code @NotInjectSQL}, else false
     */
    private boolean hasAnnotation(MappedStatement mappedStatement) {
        String sqlId = mappedStatement.getId();
        // ids of <selectKey> statements carry a "!selectKey" suffix after the owning method name
        int suffixStart = sqlId.indexOf('!');
        String statementId = suffixStart >= 0 ? sqlId.substring(0, suffixStart) : sqlId;
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return false;
        }
        String className = statementId.substring(0, lastDot);
        String methodName = statementId.substring(lastDot + 1);
        try {
            Class<?> mapperClass = Class.forName(className);
            // getMethods() also covers methods inherited from super-interfaces (e.g. a generic base mapper)
            for (Method method : mapperClass.getMethods()) {
                if (method.getName().equals(methodName) && method.getAnnotation(NotInjectSQL.class) != null) {
                    return true;
                }
            }
            return false;
        } catch (ClassNotFoundException e) {
            log.debug("===> StatementHandler ===> namespace {} is not a class, SQL will be rewritten", className);
            return false;
        }
    }
}
ParameterHandler拦截器示例:
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import java.sql.PreparedStatement;
import java.util.Map;
/**
 * ParameterHandler plugin that caps a "pageSize" query parameter at {@value #MAX_PAGE_SIZE}
 * just before the parameter values are bound to the PreparedStatement.
 *
 * <p>Only Map-shaped parameter objects (e.g. {@code @Param}-annotated mapper arguments) are
 * inspected; any other parameter object is left untouched.
 */
@Slf4j
@Intercepts({@Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class})})
public class MybatisParameterHandlerInterceptor implements Interceptor {
    private static final String PAGE_SIZE = "pageSize";
    private static final int MAX_PAGE_SIZE = 10;

    /**
     * Caps "pageSize" in the parameter map (when present and numeric), then binds the parameters.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        log.info("=====> ParameterHandler");
        ParameterHandler parameterHandler = (ParameterHandler) invocation.getTarget();
        Object parameterObject = parameterHandler.getParameterObject();
        if (parameterObject instanceof Map<?, ?>) {
            @SuppressWarnings("unchecked")
            Map<String, Object> parameterMap = (Map<String, Object>) parameterObject;
            // check the key first: MyBatis' own ParamMap.get throws for keys it does not contain
            if (parameterMap.containsKey(PAGE_SIZE)) {
                Long pageSize = toLong(parameterMap.get(PAGE_SIZE));
                if (pageSize != null && pageSize > MAX_PAGE_SIZE) {
                    // modify the query parameter
                    parameterMap.put(PAGE_SIZE, MAX_PAGE_SIZE);
                }
            }
        }
        return invocation.proceed();
    }

    /**
     * Integer value of a page-size parameter.
     *
     * @return the value as a long, or null when it is null or not an integer (a warning is logged
     *         instead of failing the statement)
     */
    private static Long toLong(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        try {
            return Long.parseLong(value.toString().trim());
        } catch (NumberFormatException e) {
            log.warn("===> ParameterHandler ===> ignoring non-numeric {}: {}", PAGE_SIZE, value);
            return null;
        }
    }
}
ResultSetHandler拦截器示例:
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import java.lang.reflect.Field;
import java.sql.Statement;
import java.util.List;
import java.util.Locale;
/**
 * ResultSetHandler plugin that masks sensitive fields of every mapped result object in place:
 * fields whose name contains "password" (case-insensitive) and fields annotated with
 * {@code @Desensitization}. Fields declared in superclasses are included; static fields are skipped.
 */
@Slf4j
@Intercepts({@Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})})
public class MybatisResultSetHandlerInterceptor implements Interceptor {
    /**
     * Lets MyBatis map the result set, then masks each returned object.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        log.info("=====> ResultSetHandler");
        Object proceedResult = invocation.proceed();
        if (proceedResult instanceof List<?>) {
            for (Object item : (List<?>) proceedResult) {
                desensitization(item);
            }
        } else {
            desensitization(proceedResult);
        }
        return proceedResult;
    }

    /**
     * Masks the sensitive, non-null instance fields of {@code obj}; a null object (which a
     * result list may contain) is ignored.
     */
    private void desensitization(Object obj) throws IllegalAccessException {
        if (obj == null) {
            return;
        }
        for (Class<?> type = obj.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                // static fields (e.g. a "PASSWORD_PATTERN" constant) belong to no row; setting a
                // static final one would throw and fail the whole query
                if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
                    continue;
                }
                // mask fields that look like passwords, plus explicitly annotated ones
                boolean passwordLike = field.getName().toLowerCase(Locale.ROOT).contains("password");
                if (!passwordLike && field.getAnnotation(Desensitization.class) == null) {
                    continue;
                }
                field.setAccessible(true);
                Object fieldValue = field.get(obj);
                if (fieldValue == null) {
                    continue;
                }
                field.set(obj, getDesensitiveValue(fieldValue));
            }
        }
    }

    /**
     * Masked replacement for a field value.
     *
     * @param obj the current (non-null) value
     * @return first and last character around "***" for strings ("***" when 2 characters or fewer),
     *         zero of the same wrapper type for numbers, false / '*' for booleans / chars, null otherwise
     */
    private static Object getDesensitiveValue(Object obj) {
        if (obj instanceof String) {
            String str = (String) obj;
            // count code points so characters outside the BMP (e.g. emoji) are never split in half
            if (str.codePointCount(0, str.length()) <= 2) {
                return "***";
            }
            int firstEnd = str.offsetByCodePoints(0, 1);
            int lastStart = str.offsetByCodePoints(str.length(), -1);
            return str.substring(0, firstEnd) + "***" + str.substring(lastStart);
        }
        // Return zero of the SAME wrapper type: Field.set on a boxed field (Integer, Long, ...) rejects
        // any other wrapper, and for primitive fields the matching wrapper always works.
        if (obj instanceof Byte) {
            return (byte) 0;
        }
        if (obj instanceof Short) {
            return (short) 0;
        }
        if (obj instanceof Integer) {
            return 0;
        }
        if (obj instanceof Long) {
            return 0L;
        }
        if (obj instanceof Float) {
            return 0F;
        }
        if (obj instanceof Double) {
            return 0D;
        }
        if (obj instanceof Boolean) {
            return false;
        }
        if (obj instanceof Character) {
            return '*';
        }
        return null;
    }
}
如果需要解析 SQL,需要在 `pom.xml` 中引入 `jsqlparser`:
<!-- https://mvnrepository.com/artifact/com.github.jsqlparser/jsqlparser -->
<!-- JSqlParser: used by MybatisStatementHandlerInterceptor to parse and rewrite SQL -->
<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>5.2</version>
</dependency>
注册插件
- 使用 Spring Boot 默认连接池时,通过配置文件注册:
1.1. 在 `resources/mybatis` 目录下添加 `mybatis-config.xml` 文件:
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE configuration
PUBLIC "-//mybatis.org//DTD Config 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-config.dtd">
<configuration>
<!-- Registers the four example interceptors. Each value is a fully-qualified class name;
     com.example.study.config is the example project's package - adjust it to yours. -->
<plugins>
<plugin interceptor="com.example.study.config.MybatisExecutorInterceptor"/>
<plugin interceptor="com.example.study.config.MybatisStatementHandlerInterceptor"/>
<plugin interceptor="com.example.study.config.MybatisParameterHandlerInterceptor"/>
<plugin interceptor="com.example.study.config.MybatisResultSetHandlerInterceptor"/>
</plugins>
</configuration>
1.2. 在 `application.properties` 中添加以下配置:
# Load the MyBatis configuration (plugin registrations) from classpath:mybatis/mybatis-config.xml
mybatis.config-location=classpath:mybatis/mybatis-config.xml
- 使用 Druid 连接池时,通过配置类注册:
在配置类中,通过 `SqlSessionFactoryBean` 对象调用 `setPlugins()` 方法注册插件(参考 SpringBoot 连接 MySQL 数据库):
/**
 * Builds the SqlSessionFactory on the given data source, loads the mapper XML files and
 * registers the four example interceptors.
 * NOTE(review): {@code mapperLocations} is presumably a field of the enclosing configuration
 * class (not shown here) holding the mapper resource pattern.
 */
@Bean
public SqlSessionFactory sqlSessionFactory(@Qualifier("dataSource") DataSource dataSource) throws Exception {
    PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
    SqlSessionFactoryBean sessionFactoryBean = new SqlSessionFactoryBean();
    sessionFactoryBean.setDataSource(dataSource);
    sessionFactoryBean.setMapperLocations(resolver.getResources(mapperLocations));
    sessionFactoryBean.setPlugins(
            new MybatisExecutorInterceptor(),
            new MybatisStatementHandlerInterceptor(),
            new MybatisParameterHandlerInterceptor(),
            new MybatisResultSetHandlerInterceptor());
    return sessionFactoryBean.getObject();
}
注:以上拦截器也可以合并到一个类中:在 `@Intercepts` 中添加对应的 `@Signature`,然后在 `intercept(invocation)` 方法中用 `invocation.getTarget() instanceof`(如 `invocation.getTarget() instanceof Executor`、`invocation.getTarget() instanceof StatementHandler`)判断当前被拦截的是哪类对象,再执行对应的拦截逻辑。