概述

顾名思义,就是一个拦截器,和springmvc的拦截器,servlet的过滤器差不多,就是在执行前拦了一道,里面可以做一些自己的事情。

平时用的mybatisPlus较多,直接以com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor
为例。其内部维护了一个拦截器List,在拦截的时候for循环依次去调用这些拦截器,这时候的执行顺序就是list中的元素下标;

业务上有时候需要做一些全局的权限隔离,逐行修改代码的方式有点麻烦,而且上新功能的时候还得手动加,属于重操作;这个时候可以利用mybatis的拦截器,配合开源的JSqlParser解析器,可以做到动态拼接条件;

mybatis plus的租户隔离也是这样做的;

mybatis plus:3.5.2

在这里插入图片描述

mybatis官方介绍中可以拦截的类型共4种:

  1. Executor(拦截执行器的方法),method=update包括了增删改,可以从MappedStatement中获取实际执行的是哪种类型
  2. ParameterHandler(拦截参数的处理)
  3. ResultSetHandler(拦截结果集的处理)
  4. StatementHandler(拦截Sql语法构建的处理)

平时业务中拦截较多的就是增删改查,经典的就是分页拦截器–查询拦截器。

mybatisPlus拦截器Demo

参考了mybatisPlus的com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor

写完拦截器后记得要放到mybatisPlus的拦截器集合中去。

如果要从拦截器中方便的获取参数,拦截mybatismapper方法简单一点,mybatisPluslambda方式的参数获取比较复杂,前期虽然写起来方面了,但是后期迭代要动这方面时就会很麻烦,也提醒大家要保持良好的封装,以及重要操作统一入口的习惯

package com.xxx.xxx.xxx;

import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
import com.baomidou.mybatisplus.core.conditions.update.Update;
import com.baomidou.mybatisplus.core.enums.SqlKeyword;
import com.baomidou.mybatisplus.core.mapper.Mapper;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;

import java.lang.reflect.Field;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;


public class MyInnerInterceptor implements InnerInterceptor {


    /**
     *
     * @param executor  Executor(可能是代理对象)
     * @param ms        MappedStatement
     * @param parameter parameter
     * @throws SQLException
     */
    @Override
    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        SqlCommandType sqlCommandType = ms.getSqlCommandType();
        String msId = ms.getId();

        if(!Objects.equals(SqlCommandType.UPDATE, sqlCommandType)){
            return;
        }

        //更新
        if (parameter instanceof Map) {
            Map<String, Object> map = (Map<String, Object>) parameter;

            //被更新的实体类的对象在mybatis这里都是用et做别名
            Object et = map.getOrDefault("et", null);

            //对应mapper的class
            final String className = getMapperClassName(msId);

            //mapper使用的更新方法
            final String methodName = getMapperMethodName(msId);

            //实体类的class
            Class<?> entityClass = getEntityClass(className);

            //获取实体类的字段信息
            TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);

            //当前实体类的属性集合
            List<TableFieldInfo> fieldList = tableInfo.getFieldList();

            // updateById(et), update(et, wrapper);
            if(Objects.nonNull(et)){

                try {


                    for (TableFieldInfo fieldInfo : fieldList) {

                        //field
                        Field field = fieldInfo.getField();

                        //获取column
                        String column = fieldInfo.getColumn();

                        //旧的value
                        Object oldValue = field.get(et);


                    }



                } catch (IllegalAccessException e) {
                    throw ExceptionUtils.mpe(e);
                }


                // update(LambdaUpdateWrapper) or update(UpdateWrapper)
            }else if (map.entrySet().stream().anyMatch(t -> Objects.equals(t.getKey(), "ew"))) {
				

                Object ew = map.get("ew");

                if (!(ew instanceof AbstractWrapper && ew instanceof Update)) {
                    return;
                }

                final Map<String, Object> paramNameValuePairs = ((AbstractWrapper<?, ?, ?>) ew).getParamNameValuePairs();


                for (TableFieldInfo fieldInfo : fieldList) {

                    //field
                    Field field = fieldInfo.getField();

                    //获取column
                    String column = fieldInfo.getColumn();


                    Wrapper<?> wrapper = (Wrapper<?>) ew;

                    //查询更新条件中指定的列名对应的值----下面方法只能获取where条件中等于的字段值
                    String valueKey = getValueKey(column, wrapper);

                    final Object conditionValue = paramNameValuePairs.get(valueKey);



                }


            }else if(map.entrySet().stream().noneMatch(t -> Objects.equals(t.getKey(), "ew"))){
                //监听 mapper 的update方法
                //比如mapper方法是 updateStatus(@Param("id") Long id, @Param("status") String status)
                //那么这里的map里会包含key = id和status对应的更新前的值,还有想要更新的值

                try {

                    Long id = null;
                    Object idObj = map.get(keyProperty);
                    if(idObj instanceof Integer){
                        id = (long)(Integer)idObj;
                    }else {
                        id = (Long) idObj;
                    }

                    String state = (String) map.get("state");

                    if(Objects.equals(keyProperty, "targetFillRecordId")){
                        TargetStatusLogUtil.recordStatus(id, 1, state);
                    }


                    if(Objects.equals(keyProperty, "groupCollectId")){
                        TargetStatusLogUtil.recordStatus(id, 2, state);
                    }


                    if(Objects.equals(keyProperty, "leaderCollectId")){
                        TargetStatusLogUtil.recordStatus(id, 3, state);
                    }


                } catch (Exception e) {
                    log.error(e.getMessage(), e);
                }


            }



        }


    }


    /**
     * 查询更新条件中指定的列名对应的值
     * 即查询 where xxx = value 这个条件中的 xxx 对应的 value
     * @param column
     * @param wrapper
     * @return
     */
    private String getValueKey(String column, Wrapper<?> wrapper){
        Pattern pattern = Pattern.compile("#\\{ew\\.paramNameValuePairs\\.(" + "MPGENVAL" + "\\d+)\\}");

        final NormalSegmentList segments = wrapper.getExpression().getNormal();

        String fieldName = null;
        ISqlSegment eq = null;
        String valueKey = null;

        for (ISqlSegment segment : segments) {

            String sqlSegment = segment.getSqlSegment();

                //如果字段已找到并且当前segment为EQ
            if(Objects.nonNull(fieldName) && segment == SqlKeyword.EQ){

                eq = segment;

                //如果EQ找到并且value已找到
            }else if(Objects.nonNull(fieldName) && Objects.nonNull(eq)){

                Matcher matcher = pattern.matcher(sqlSegment);
                if(matcher.matches()){
                    valueKey = matcher.group(1);
                    return valueKey;
                }


                //处理嵌套
            }else if (segment instanceof Wrapper){

                if(null != (valueKey = getValueKey(column, ((Wrapper<?>) segment)))){
                    return valueKey;
                }

                //判断字段是否是要查找字段
            }else if(Objects.equals(column, sqlSegment)){
                fieldName = sqlSegment;
            }

        }

        return valueKey;
    }






    private String getMapperMethodName(String msId){
        return msId.substring(msId.lastIndexOf('.') + 1);
    }

    private String getMapperClassName(String msId){
        return msId.substring(0, msId.lastIndexOf('.'));
    }

    /**
     * 通过mapper上实体类信息获取实体类class
     * @param className
     * @return
     */
    private Class<?> getEntityClass(String className){
        try {
            return ReflectionKit.getSuperClassGenericType(Class.forName(className), Mapper.class, 0);
        } catch (ClassNotFoundException e) {
            throw ExceptionUtils.mpe(e);
        }
    }
}

注册拦截器

@Bean
public MyInnerInterceptor myInnerInterceptor(ApplicationContext applicationContext){
    MyInnerInterceptor myInnerInterceptor = new MyInnerInterceptor();
    MybatisPlusInterceptor bean = applicationContext.getBean(MybatisPlusInterceptor.class);
    bean.addInnerInterceptor(myInnerInterceptor);
    return myInnerInterceptor;
}

权限拦截器demo

搞个自定义注解,然后拦截器里判断下是否使用了注解
一般业务中已经有现成的拦截器在使用了,在原来的拦截器链中加上自己的拦截器就ok了;

可以参考mybatis plus的租户拦截器,里面对各种子查询都有做处理
com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor


//自定义注解
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface UserDataPermission {
}




//拦截器,实际处理交给对应的handler
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.*;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.SQLException;
import java.util.List;

@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class MyDataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {

    /**
     * 数据权限处理器
     */
    private MyDataPermissionHandler dataPermissionHandler;

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        mpBs.sql(this.parserSingle(mpBs.sql(), ms.getId()));
    }

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, (String) obj);
        } else if (selectBody instanceof SetOperationList) {
            SetOperationList setOperationList = (SetOperationList) selectBody;
            List<SelectBody> selectBodyList = setOperationList.getSelects();
            selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, (String) obj));
        }
    }

    /**
     * 设置 where 条件
     *
     * @param plainSelect  查询对象
     * @param whereSegment 查询条件片段
     */
    private void setWhere(PlainSelect plainSelect, String whereSegment) {

        Expression sqlSegment = this.dataPermissionHandler.getSqlSegment(plainSelect, whereSegment);
        if (null != sqlSegment) {
            plainSelect.setWhere(sqlSegment);
        }
    }
}




//简单单表查询demo
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;

import java.lang.reflect.Method;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

@Slf4j
public class MyDataPermissionHandler {

    private RemoteRoleService remoteRoleService;
    private RemoteUserService remoteUserService;


    /**
     * 获取数据权限 SQL 片段
     *
     * @param plainSelect  查询对象
     * @param whereSegment 查询条件片段
     * @return JSqlParser 条件表达式
     */
    @SneakyThrows(Exception.class)
    public Expression getSqlSegment(PlainSelect plainSelect, String whereSegment) {
    	//自定义的方法,这里用的是service,一般用token中的权限比较方便快捷
        remoteRoleService = SpringUtil.getBean(RemoteRoleService.class);
        remoteUserService = SpringUtil.getBean(RemoteUserService.class);

        // 待执行 SQL Where 条件表达式
        Expression where = plainSelect.getWhere();
        if (where == null) {
            where = new HexValue(" 1 = 1 ");
        }
        log.info("开始进行权限过滤,where: {},mappedStatementId: {}", where, whereSegment);
        //获取mapper名称
        String className = whereSegment.substring(0, whereSegment.lastIndexOf("."));
        //获取方法名
        String methodName = whereSegment.substring(whereSegment.lastIndexOf(".") + 1);
        Table fromItem = (Table) plainSelect.getFromItem();
        // 有别名用别名,无别名用表名,防止字段冲突报错
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();
        //获取当前mapper 的方法
        Method[] methods = Class.forName(className).getMethods();
        //遍历判断mapper 的所以方法,判断方法上是否有 UserDataPermission
        for (Method m : methods) {
            if (Objects.equals(m.getName(), methodName)) {
                UserDataPermission annotation = m.getAnnotation(UserDataPermission.class);
                if (annotation == null) {
                    return where;
                }
                // 1、当前用户Code
                User user = SecurityUtils.getUser();
                // 2、当前角色即角色或角色类型(可能多种角色)
                Set<String> roleTypeSet = remoteRoleService.currentUserRoleType();
                
                DataScope scopeType = DataPermission.getScope(roleTypeSet);
                switch (scopeType) {
                    // 查看全部
                    case ALL:
                        return where;
                    case DEPT:
                        // 查看本部门用户数据
                        // 创建IN 表达式
                        // 创建IN范围的元素集合
                        List<String> deptUserList = remoteUserService.listUserCodesByDeptCodes(user.getDeptCode());
                        // 把集合转变为JSQLParser需要的元素列表
                        ItemsList deptList = new ExpressionList(deptUserList.stream().map(StringValue::new).collect(Collectors.toList()));
                        InExpression inExpressiondept = new InExpression(new Column(mainTableName + ".creator_code"), deptList);
                        return new AndExpression(where, inExpressiondept);
                    case MYSELF:
                        // 查看自己的数据
                        //  = 表达式
                        EqualsTo usesEqualsTo = new EqualsTo();
                        usesEqualsTo.setLeftExpression(new Column(mainTableName + ".creator_code"));
                        usesEqualsTo.setRightExpression(new StringValue(user.getUserCode()));
                        return new AndExpression(where, usesEqualsTo);
                    default:
                        break;
                }
            }

        }
        //说明无权查看,
        where = new HexValue(" 1 = 2 ");
        return where;
    }
}





//注册拦截器
@Bean
public MyDataPermissionInterceptor myInterceptor(MybatisPlusInterceptor mybatisPlusInterceptor) {
    MyDataPermissionInterceptor sql = new MyDataPermissionInterceptor();
    sql.setDataPermissionHandler(new MyDataPermissionHandler());
	//拦截器实际执行的时候是按照list中的顺序调用
    List<InnerInterceptor> list = new ArrayList<>();
    // 添加数据权限插件
    list.add(sql);
    // 分页插件
    mybatisPlusInterceptor.setInterceptors(list);
    list.add(new PaginationInnerInterceptor(DbType.MYSQL));
    return sql;
}
Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐