Browse Source

【代码优化】全局:复用 MyBatis Plus 数据权限插件,简化项目的数据权限实现

YunaiV 8 months ago
parent
commit
bea66e5c6f

+ 8 - 6
yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/config/YudaoDataPermissionAutoConfiguration.java

@@ -1,12 +1,13 @@
 package cn.iocoder.yudao.framework.datapermission.config;
 
 import cn.iocoder.yudao.framework.datapermission.core.aop.DataPermissionAnnotationAdvisor;
-import cn.iocoder.yudao.framework.datapermission.core.db.DataPermissionDatabaseInterceptor;
+import cn.iocoder.yudao.framework.datapermission.core.db.DataPermissionRuleHandler;
 import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
 import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
 import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactoryImpl;
 import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
 import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
+import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
 import org.springframework.boot.autoconfigure.AutoConfiguration;
 import org.springframework.context.annotation.Bean;
 
@@ -26,14 +27,15 @@ public class YudaoDataPermissionAutoConfiguration {
     }
 
     @Bean
-    public DataPermissionDatabaseInterceptor dataPermissionDatabaseInterceptor(MybatisPlusInterceptor interceptor,
-                                                                               DataPermissionRuleFactory ruleFactory) {
-        // 创建 DataPermissionDatabaseInterceptor 拦截器
-        DataPermissionDatabaseInterceptor inner = new DataPermissionDatabaseInterceptor(ruleFactory);
+    public DataPermissionRuleHandler dataPermissionRuleHandler(MybatisPlusInterceptor interceptor,
+                                                               DataPermissionRuleFactory ruleFactory) {
+        // 创建 DataPermissionInterceptor 拦截器
+        DataPermissionRuleHandler handler = new DataPermissionRuleHandler(ruleFactory);
+        DataPermissionInterceptor inner = new DataPermissionInterceptor(handler);
         // 添加到 interceptor 中
         // 需要加在首个,主要是为了在分页插件前面。这个是 MyBatis Plus 的规定
         MyBatisUtils.addInterceptor(interceptor, inner, 0);
-        return inner;
+        return handler;
     }
 
     @Bean

+ 0 - 297
yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptor.java

@@ -1,297 +0,0 @@
-package cn.iocoder.yudao.framework.datapermission.core.db;
-
-import cn.hutool.core.collection.CollUtil;
-import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
-import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
-import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
-import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
-import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
-import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
-import com.baomidou.mybatisplus.extension.plugins.inner.BaseMultiTableInnerInterceptor;
-import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
-import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
-import lombok.Getter;
-import lombok.RequiredArgsConstructor;
-import net.sf.jsqlparser.expression.Expression;
-import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
-import net.sf.jsqlparser.schema.Table;
-import net.sf.jsqlparser.statement.delete.Delete;
-import net.sf.jsqlparser.statement.select.Select;
-import net.sf.jsqlparser.statement.select.WithItem;
-import net.sf.jsqlparser.statement.update.Update;
-import org.apache.ibatis.executor.Executor;
-import org.apache.ibatis.executor.statement.StatementHandler;
-import org.apache.ibatis.mapping.BoundSql;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.apache.ibatis.mapping.SqlCommandType;
-import org.apache.ibatis.session.ResultHandler;
-import org.apache.ibatis.session.RowBounds;
-
-import java.sql.Connection;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
-
-/**
- * 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
- * 主要的 SQL 重写方法,可见 {@link #buildTableExpression(Table, Expression, String)} 方法
- *
- * 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。
- * 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更!
- *
- * 为什么不直接基于 {@link DataPermissionInterceptor} 实现?因为它是 MyBatis Plus 3.5.2 版本才出来,当时 yudao 已经实现数据权限了!
- *
- * @author 芋道源码
- */
-@RequiredArgsConstructor
-public class DataPermissionDatabaseInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
-
-    private final DataPermissionRuleFactory ruleFactory;
-
-    @Getter
-    private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
-
-    @Override // SELECT 场景
-    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
-        // 获得 Mapper 对应的数据权限的规则
-        List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
-        if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
-            return;
-        }
-
-        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
-        try {
-            // 初始化上下文
-            ContextHolder.init(rules);
-            // 处理 SQL
-            mpBs.sql(parserSingle(mpBs.sql(), null));
-        } finally {
-            // 添加是否需要重写的缓存
-            addMappedStatementCache(ms);
-            // 清空上下文
-            ContextHolder.clear();
-        }
-    }
-
-    @Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限)
-    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
-        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
-        MappedStatement ms = mpSh.mappedStatement();
-        SqlCommandType sct = ms.getSqlCommandType();
-        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
-            // 获得 Mapper 对应的数据权限的规则
-            List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
-            if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
-                return;
-            }
-
-            PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
-            try {
-                // 初始化上下文
-                ContextHolder.init(rules);
-                // 处理 SQL
-                mpBs.sql(parserMulti(mpBs.sql(), null));
-            } finally {
-                // 添加是否需要重写的缓存
-                addMappedStatementCache(ms);
-                // 清空上下文
-                ContextHolder.clear();
-            }
-        }
-    }
-
-    @Override
-    protected void processSelect(Select select, int index, String sql, Object obj) {
-        final String whereSegment = (String) obj;
-        processSelectBody(select, whereSegment);
-        List<WithItem> withItemsList = select.getWithItemsList();
-        if (!CollectionUtils.isEmpty(withItemsList)) {
-            withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
-        }
-    }
-
-    /**
-     * update 语句处理
-     */
-    @Override
-    protected void processUpdate(Update update, int index, String sql, Object obj) {
-        final Expression sqlSegment = getUpdateOrDeleteExpression(update.getTable(), update.getWhere(), (String) obj);
-        if (null != sqlSegment) {
-            update.setWhere(sqlSegment);
-        }
-    }
-
-    /**
-     * delete 语句处理
-     */
-    @Override
-    protected void processDelete(Delete delete, int index, String sql, Object obj) {
-        final Expression sqlSegment = getUpdateOrDeleteExpression(delete.getTable(), delete.getWhere(), (String) obj);
-        if (null != sqlSegment) {
-            delete.setWhere(sqlSegment);
-        }
-    }
-
-    // ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========
-
-    protected Expression getUpdateOrDeleteExpression(final Table table, final Expression where, final String whereSegment) {
-        return andExpression(table, where, whereSegment);
-    }
-
-    /**
-     * 构建指定表的数据权限的 Expression 过滤条件
-     *
-     * @param table 表
-     * @return Expression 过滤条件
-     */
-    @Override
-    public Expression buildTableExpression(Table table, Expression where, String whereSegment) {
-        // 生成条件
-        Expression allExpression = null;
-        for (DataPermissionRule rule : ContextHolder.getRules()) {
-            // 判断表名是否匹配
-            String tableName = MyBatisUtils.getTableName(table);
-            if (!rule.getTableNames().contains(tableName)) {
-                continue;
-            }
-            // 如果有匹配的规则,说明可重写。
-            // 为什么不是有 allExpression 非空才重写呢?在生成 column = value 过滤条件时,会因为 value 不存在,导致未重写。
-            // 这样导致第一次无 value,被标记成无需重写;但是第二次有 value,此时会需要重写。
-            ContextHolder.setRewrite(true);
-
-            // 单条规则的条件
-            Expression oneExpress = rule.getExpression(tableName, table.getAlias());
-            if (oneExpress == null){
-                continue;
-            }
-            // 拼接到 allExpression 中
-            allExpression = allExpression == null ? oneExpress
-                    : new AndExpression(allExpression, oneExpress);
-        }
-
-        return allExpression;
-    }
-
-    /**
-     * 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
-     *
-     * @param ms MappedStatement
-     */
-    private void addMappedStatementCache(MappedStatement ms) {
-        if (ContextHolder.getRewrite()) {
-            return;
-        }
-        // 无重写,进行添加
-        mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
-    }
-
-    /**
-     * SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
-     *
-     * @author 芋道源码
-     */
-    static final class ContextHolder {
-
-        /**
-         * 该 {@link MappedStatement} 对应的规则
-         */
-        private static final ThreadLocal<List<DataPermissionRule>> RULES = ThreadLocal.withInitial(Collections::emptyList);
-        /**
-         * SQL 是否进行重写
-         */
-        private static final ThreadLocal<Boolean> REWRITE = ThreadLocal.withInitial(() -> Boolean.FALSE);
-
-        public static void init(List<DataPermissionRule> rules) {
-            RULES.set(rules);
-            REWRITE.set(false);
-        }
-
-        public static void clear() {
-            RULES.remove();
-            REWRITE.remove();
-        }
-
-        public static boolean getRewrite() {
-            return REWRITE.get();
-        }
-
-        public static void setRewrite(boolean rewrite) {
-            REWRITE.set(rewrite);
-        }
-
-        public static List<DataPermissionRule> getRules() {
-            return RULES.get();
-        }
-
-    }
-
-    /**
-     * {@link MappedStatement} 缓存
-     * 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
-     * 如果无效,则可以避免 SQL 的解析,加快速度
-     *
-     * @author 芋道源码
-     */
-    static final class MappedStatementCache {
-
-        /**
-         * 指定数据权限规则,对指定 MappedStatement 无需重写(不生效)的缓存
-         *
-         * value:{@link MappedStatement#getId()} 编号
-         */
-        @Getter
-        private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();
-
-        /**
-         * 判断是否无需重写
-         * ps:虽然有点中文式英语,但是容易读懂即可
-         *
-         * @param ms MappedStatement
-         * @param rules 数据权限规则数组
-         * @return 是否无需重写
-         */
-        public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
-            // 如果规则为空,说明无需重写
-            if (CollUtil.isEmpty(rules)) {
-                return true;
-            }
-            // 任一规则不在 noRewritableMap 中,则说明可能需要重写
-            for (DataPermissionRule rule : rules) {
-                Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
-                if (!CollUtil.contains(mappedStatementIds, ms.getId())) {
-                    return false;
-                }
-            }
-            return true;
-        }
-
-        /**
-         * 添加无需重写的 MappedStatement
-         *
-         * @param ms MappedStatement
-         * @param rules 数据权限规则数组
-         */
-        public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
-            for (DataPermissionRule rule : rules) {
-                Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
-                if (CollUtil.isNotEmpty(mappedStatementIds)) {
-                    mappedStatementIds.add(ms.getId());
-                } else {
-                    noRewritableMappedStatements.put(rule.getClass(), SetUtils.asSet(ms.getId()));
-                }
-            }
-        }
-
-        /**
-         * 清空缓存
-         * 目前主要提供给单元测试
-         */
-        public void clear() {
-            noRewritableMappedStatements.clear();
-        }
-
-    }
-
-}

+ 57 - 0
yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandler.java

@@ -0,0 +1,57 @@
+package cn.iocoder.yudao.framework.datapermission.core.db;
+
+import cn.hutool.core.collection.CollUtil;
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
+import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
+import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
+import lombok.RequiredArgsConstructor;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
+import net.sf.jsqlparser.schema.Table;
+
+import java.util.List;
+
+/**
+ * 基于 {@link DataPermissionRule} 的数据权限处理器
+ *
+ * 它的底层,是基于 MyBatis Plus 的 <a href="https://baomidou.com/plugins/data-permission/">数据权限插件</a>
+ * 核心原理:它会在 SQL 执行前拦截 SQL 语句,并根据用户权限动态添加权限相关的 SQL 片段。这样,只有用户有权限访问的数据才会被查询出来
+ *
+ * @author 芋道源码
+ */
+@RequiredArgsConstructor
+public class DataPermissionRuleHandler implements MultiDataPermissionHandler {
+
+    private final DataPermissionRuleFactory ruleFactory;
+
+    @Override
+    public Expression getSqlSegment(Table table, Expression where, String mappedStatementId) {
+        // 获得 Mapper 对应的数据权限的规则
+        List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(mappedStatementId);
+        if (CollUtil.isEmpty(rules)) {
+            return null;
+        }
+
+        // 生成条件
+        Expression allExpression = null;
+        for (DataPermissionRule rule : rules) {
+            // 判断表名是否匹配
+            String tableName = MyBatisUtils.getTableName(table);
+            if (!rule.getTableNames().contains(tableName)) {
+                continue;
+            }
+
+            // 单条规则的条件
+            Expression oneExpress = rule.getExpression(tableName, table.getAlias());
+            if (oneExpress == null) {
+                continue;
+            }
+            // 拼接到 allExpression 中
+            allExpression = allExpression == null ? oneExpress
+                    : new AndExpression(allExpression, oneExpress);
+        }
+        return allExpression;
+    }
+
+}

+ 0 - 190
yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest.java

@@ -1,190 +0,0 @@
-package cn.iocoder.yudao.framework.datapermission.core.db;
-
-import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
-import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
-import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
-import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
-import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
-import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
-import net.sf.jsqlparser.expression.Alias;
-import net.sf.jsqlparser.expression.Expression;
-import net.sf.jsqlparser.expression.LongValue;
-import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
-import net.sf.jsqlparser.schema.Column;
-import org.apache.ibatis.executor.Executor;
-import org.apache.ibatis.executor.statement.StatementHandler;
-import org.apache.ibatis.mapping.BoundSql;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.mockito.InjectMocks;
-import org.mockito.Mock;
-import org.mockito.MockedStatic;
-
-import java.sql.Connection;
-import java.util.*;
-
-import static java.util.Collections.singletonList;
-import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.Mockito.*;
-
-/**
- * {@link DataPermissionDatabaseInterceptor} 的单元测试
- * 主要测试 {@link DataPermissionDatabaseInterceptor#beforePrepare(StatementHandler, Connection, Integer)}
- * 和 {@link DataPermissionDatabaseInterceptor#beforeUpdate(Executor, MappedStatement, Object)}
- * 以及在这个过程中,ContextHolder 和 MappedStatementCache
- *
- * @author 芋道源码
- */
-public class DataPermissionDatabaseInterceptorTest extends BaseMockitoUnitTest {
-
-    @InjectMocks
-    private DataPermissionDatabaseInterceptor interceptor;
-
-    @Mock
-    private DataPermissionRuleFactory ruleFactory;
-
-    @BeforeEach
-    public void setUp() {
-        // 清理上下文
-        DataPermissionDatabaseInterceptor.ContextHolder.clear();
-        // 清空缓存
-        interceptor.getMappedStatementCache().clear();
-    }
-
-    @Test // 不存在规则,且不匹配
-    public void testBeforeQuery_withoutRule() {
-        try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
-            // 准备参数
-            MappedStatement mappedStatement = mock(MappedStatement.class);
-            BoundSql boundSql = mock(BoundSql.class);
-
-            // 调用
-            interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
-            // 断言
-            pluginUtilsMock.verify(() -> PluginUtils.mpBoundSql(boundSql), never());
-        }
-    }
-
-    @Test // 存在规则,且不匹配
-    public void testBeforeQuery_withMatchRule() {
-        try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
-            // 准备参数
-            MappedStatement mappedStatement = mock(MappedStatement.class);
-            BoundSql boundSql = mock(BoundSql.class);
-            // mock 方法(数据权限)
-            when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
-                    .thenReturn(singletonList(new DeptDataPermissionRule()));
-            // mock 方法(MPBoundSql)
-            PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
-            pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
-            // mock 方法(SQL)
-            String sql = "select * from t_user where id = 1";
-            when(mpBs.sql()).thenReturn(sql);
-            // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
-
-            // 调用
-            interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
-            // 断言
-            verify(mpBs, times(1)).sql(
-                    eq("SELECT * FROM t_user WHERE id = 1 AND t_user.dept_id = 100"));
-            // 断言缓存
-            assertTrue(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
-        }
-    }
-
-    @Test // 存在规则,但不匹配
-    public void testBeforeQuery_withoutMatchRule() {
-        try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
-            // 准备参数
-            MappedStatement mappedStatement = mock(MappedStatement.class);
-            BoundSql boundSql = mock(BoundSql.class);
-            // mock 方法(数据权限)
-            when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
-                    .thenReturn(singletonList(new DeptDataPermissionRule()));
-            // mock 方法(MPBoundSql)
-            PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
-            pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
-            // mock 方法(SQL)
-            String sql = "select * from t_role where id = 1";
-            when(mpBs.sql()).thenReturn(sql);
-            // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
-
-            // 调用
-            interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
-            // 断言
-            verify(mpBs, times(1)).sql(
-                    eq("SELECT * FROM t_role WHERE id = 1"));
-            // 断言缓存
-            assertFalse(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
-        }
-    }
-
-    @Test
-    public void testAddNoRewritable() {
-        // 准备参数
-        MappedStatement ms = mock(MappedStatement.class);
-        List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
-        // mock 方法
-        when(ms.getId()).thenReturn("selectById");
-
-        // 调用
-        interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
-        // 断言
-        Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements =
-                interceptor.getMappedStatementCache().getNoRewritableMappedStatements();
-        assertEquals(1, noRewritableMappedStatements.size());
-        assertEquals(SetUtils.asSet("selectById"), noRewritableMappedStatements.get(DeptDataPermissionRule.class));
-    }
-
-    @Test
-    public void testNoRewritable() {
-        // 准备参数
-        MappedStatement ms = mock(MappedStatement.class);
-        // mock 方法
-        when(ms.getId()).thenReturn("selectById");
-        // mock 数据
-        List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
-        interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
-
-        // 场景一,rules 为空
-        assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, null));
-        // 场景二,rules 非空,可重写
-        assertFalse(interceptor.getMappedStatementCache().noRewritable(ms, singletonList(new EmptyDataPermissionRule())));
-        // 场景三,rule 非空,不可重写
-        assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, rules));
-    }
-
-    private static class DeptDataPermissionRule implements DataPermissionRule {
-
-        private static final String COLUMN = "dept_id";
-
-        @Override
-        public Set<String> getTableNames() {
-            return SetUtils.asSet("t_user");
-        }
-
-        @Override
-        public Expression getExpression(String tableName, Alias tableAlias) {
-            Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
-            LongValue value = new LongValue(100L);
-            return new EqualsTo(column, value);
-        }
-
-    }
-
-    private static class EmptyDataPermissionRule implements DataPermissionRule {
-
-        @Override
-        public Set<String> getTableNames() {
-            return Collections.emptySet();
-        }
-
-        @Override
-        public Expression getExpression(String tableName, Alias tableAlias) {
-            return null;
-        }
-
-    }
-
-}

+ 12 - 6
yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java → yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandlerTest.java

@@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
 import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
 import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
 import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
+import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
 import net.sf.jsqlparser.expression.Alias;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.LongValue;
@@ -22,24 +23,30 @@ import java.util.Set;
 
 import static cn.iocoder.yudao.framework.common.util.collection.SetUtils.asSet;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
 
 /**
- * {@link DataPermissionDatabaseInterceptor} 的单元测试
+ * {@link DataPermissionRuleHandler} 的单元测试
  * 主要复用了 MyBatis Plus 的 TenantLineInnerInterceptorTest 的单元测试
  * 不过它的单元测试不是很规范,考虑到是复用的,所以暂时不进行修改~
  *
  * @author 芋道源码
  */
-public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest {
+public class DataPermissionRuleHandlerTest extends BaseMockitoUnitTest {
 
     @InjectMocks
-    private DataPermissionDatabaseInterceptor interceptor;
+    private DataPermissionRuleHandler handler;
 
     @Mock
     private DataPermissionRuleFactory ruleFactory;
 
+    private DataPermissionInterceptor interceptor;
+
     @BeforeEach
     public void setUp() {
+        interceptor = new DataPermissionInterceptor(handler);
+
         // 租户的数据权限规则
         DataPermissionRule tenantRule = new DataPermissionRule() {
 
@@ -78,8 +85,8 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
             }
 
         };
-        // 设置到上下文,保证
-        DataPermissionDatabaseInterceptor.ContextHolder.init(Arrays.asList(tenantRule, deptRule));
+        // 设置到上下文
+        when(ruleFactory.getDataPermissionRule(any())).thenReturn(Arrays.asList(tenantRule, deptRule));
     }
 
     @Test
@@ -448,7 +455,6 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
         assertEquals(targetSql, interceptor.parserSingle(sql, null));
     }
 
-
     // ========== 额外的测试 ==========
 
     @Test