Selaa lähdekoodia

增加数据权限的 SQL 重写的上下文

YunaiV 3 vuotta sitten
vanhempi
commit
eda2b11dad

+ 2 - 2
yudao-framework/yudao-spring-boot-starter-data-permission/pom.xml

@@ -29,8 +29,8 @@
 
         <!-- Test 测试相关 -->
         <dependency>
-            <groupId>org.springframework.boot</groupId>
-            <artifactId>spring-boot-starter-test</artifactId>
+            <groupId>cn.iocoder.boot</groupId>
+            <artifactId>yudao-spring-boot-starter-test</artifactId>
             <scope>test</scope>
         </dependency>
     </dependencies>

+ 156 - 22
yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptor.java

@@ -1,11 +1,16 @@
 package cn.iocoder.yudao.framework.datapermission.core.interceptor;
 
-import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
+import cn.hutool.core.collection.CollUtil;
+import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
+import com.alibaba.ttl.TransmittableThreadLocal;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
 import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
+import lombok.RequiredArgsConstructor;
 import net.sf.jsqlparser.expression.*;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
@@ -24,33 +29,58 @@ import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
 import java.sql.Connection;
-import java.util.Collection;
-import java.util.Deque;
-import java.util.LinkedList;
-import java.util.List;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
 
+@RequiredArgsConstructor
 public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {
 
-//    private TenantLineHandler tenantLineHandler;
+    private final DataPermissionRuleFactory ruleFactory;
 
-    @Override
+    private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
+
+    @Override // SELECT 场景
     public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
-        // TODO 芋艿:这个判断,后续读懂下
-        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
+        // 获得 Mapper 对应的数据权限的规则
+        List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
+        if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
+            return;
+        }
+
         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
-        // TODO 芋艿:null=》DataScope
-        mpBs.sql(parserSingle(mpBs.sql(), null));
+        try {
+            // 初始化上下文
+            ContextHolder.init(rules);
+            // 处理 SQL
+            mpBs.sql(parserSingle(mpBs.sql(), null));
+        } finally {
+            addMappedStatementCache(ms);
+            ContextHolder.clear();
+        }
     }
 
-    @Override
+    @Override // 只处理 UPDATE / DELETE 场景
     public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
         PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
         MappedStatement ms = mpSh.mappedStatement();
         SqlCommandType sct = ms.getSqlCommandType();
         if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { // 无需处理 Insert 语句
-            if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
+            // 获得 Mapper 对应的数据权限的规则
+            List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
+            if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
+                return;
+            }
+
             PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
-            mpBs.sql(parserMulti(mpBs.sql(), null));
+            try {
+                // 初始化上下文
+                ContextHolder.init(rules);
+                // 处理 SQL
+                mpBs.sql(parserMulti(mpBs.sql(), null));
+            } finally {
+                addMappedStatementCache(ms);
+                ContextHolder.clear();
+            }
         }
     }
 
@@ -87,10 +117,6 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
     @Override
     protected void processUpdate(Update update, int index, String sql, Object obj) {
         final Table table = update.getTable();
-        if (ignoreTable(table.getName())) {
-            // 过滤退出执行
-            return;
-        }
         update.setWhere(this.andExpression(table, update.getWhere()));
     }
 
@@ -99,10 +125,6 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
      */
     @Override
     protected void processDelete(Delete delete, int index, String sql, Object obj) {
-        if (ignoreTable(delete.getTable().getName())) {
-            // 过滤退出执行
-            return;
-        }
         delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
     }
 
@@ -378,4 +400,116 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
         return new LongValue(1L);
     }
 
+
+    /**
+     * 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
+     *
+     * @param ms MappedStatement
+     */
+    private void addMappedStatementCache(MappedStatement ms) {
+        if (ContextHolder.getRewrite()) {
+            return;
+        }
+        // 有重写,进行添加
+        mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
+    }
+
+    /**
+     * SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
+     *
+     * @author 芋道源码
+     */
+    private static final class ContextHolder {
+
+        /**
+         * 该 {@link MappedStatement} 对应的规则
+         */
+        private static final ThreadLocal<List<DataPermissionRule>> RULES = new TransmittableThreadLocal<>();
+        /**
+         * SQL 是否进行重写
+         */
+        private static final ThreadLocal<Boolean> REWRITE = new TransmittableThreadLocal<>();
+
+        public static void init(List<DataPermissionRule> rules) {
+            RULES.set(rules);
+            REWRITE.set(false);
+        }
+
+        public static void clear() {
+            RULES.remove();
+            REWRITE.remove();
+        }
+
+        public static boolean getRewrite() {
+            return REWRITE.get();
+        }
+
+        public static void setRewrite(boolean rewrite) {
+            REWRITE.set(rewrite);
+        }
+
+        public static List<DataPermissionRule> getRules() {
+            return RULES.get();
+        }
+
+    }
+
+    /**
+     * {@link MappedStatement} 缓存
+     * 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
+     * 如果无效,则可以避免 SQL 的解析,加快速度
+     *
+     * @author 芋道源码
+     */
+    private static final class MappedStatementCache {
+
+        /**
+         * 无需重写的映射
+         *
+         * value:{@link MappedStatement#getId()} 编号
+         */
+        private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();
+
+        /**
+         * 判断是否无需重写
+         * ps:虽然有点中文式英语,但是容易读懂即可
+         *
+         * @param ms MappedStatement
+         * @param rules 数据权限规则数组
+         * @return 是否无需重写
+         */
+        public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
+            // 如果规则为空,说明无需重写
+            if (CollUtil.isEmpty(rules)) {
+                return true;
+            }
+            // 任一规则不在 noRewritableMap 中,则说明可能需要重写
+            for (DataPermissionRule rule : rules) {
+                Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
+                if (!CollUtil.contains(mappedStatementIds, ms.getId())) { // 不存在,则说明可能要重写
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        /**
+         * 添加无需重写的 MappedStatement
+         *
+         * @param ms MappedStatement
+         * @param rules 数据权限规则数组
+         */
+        public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
+            for (DataPermissionRule rule : rules) {
+                Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
+                if (CollUtil.isNotEmpty(mappedStatementIds)) {
+                    mappedStatementIds.add(ms.getId());
+                } else {
+                    noRewritableMappedStatements.put(rule.getClass(), SetUtils.asSet(ms.getId()));
+                }
+            }
+        }
+
+    }
+
 }

+ 22 - 3
yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/rule/DataPermissionRuleFactory.java

@@ -1,9 +1,28 @@
 package cn.iocoder.yudao.framework.datapermission.core.rule;
 
+import java.util.List;
+
 /**
- * {@link DataPermissionRule} 工厂接口,提供如下能力:
- * 1. {@link DataPermissionRule} 的容器
- * 2. TODO 芋艿:
+ * {@link DataPermissionRule} 工厂接口
+ * 作为 {@link DataPermissionRule} 的容器,提供管理能力
+ *
+ * @author 芋道源码
  */
 public interface DataPermissionRuleFactory {
+
+    /**
+     * 获得所有数据权限规则数组
+     *
+     * @return 数据权限规则数组
+     */
+    List<DataPermissionRule> getDataPermissionRules();
+
+    /**
+     * 获得指定 Mapper 的数据权限规则数组
+     *
+     * @param mappedStatementId 指定 Mapper 的编号
+     * @return 数据权限规则数组
+     */
+    List<DataPermissionRule> getDataPermissionRule(String mappedStatementId);
+
 }

+ 10 - 2
yudao-framework/yudao-spring-boot-starter-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptorTest.java

@@ -1,12 +1,20 @@
 package cn.iocoder.yudao.framework.datapermission.core.interceptor;
 
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
+import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
 import org.junit.jupiter.api.Test;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
-public class DataPermissionInterceptorTest {
+public class DataPermissionInterceptorTest extends BaseMockitoUnitTest {
 
-    private final DataPermissionInterceptor interceptor = new DataPermissionInterceptor();
+    @InjectMocks
+    private DataPermissionInterceptor interceptor;
+
+    @Mock
+    private DataPermissionRuleFactory ruleFactory;
 
     @Test
     public void selectSingle() {

+ 1 - 1
yudao-framework/yudao-spring-boot-starter-mq/src/main/java/cn/iocoder/yudao/framework/mq/core/pubsub/AbstractChannelMessage.java

@@ -15,7 +15,7 @@ public abstract class AbstractChannelMessage extends AbstractRedisMessage {
      *
      * @return Channel
      */
-    @JsonIgnore // 避免序列化
+    @JsonIgnore // 避免序列化。原因是,Redis 发布 Channel 消息的时候,已经会指定。
     public abstract String getChannel();
 
 }