Browse Source

通过 MyBatis Plus 数据权限的单测

YunaiV 3 years ago
parent
commit
0f792c64e7

+ 2 - 1
yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptor.java

@@ -330,6 +330,7 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
 //                boolean needIgnore = ignoreTable(fromTable.getName());
 //                // 表名压栈,忽略的表压入 null,以便后续不处理
 //                tables.push(needIgnore ? null : fromTable);
+                tables.push(fromTable);
                 // 尾缀多个 on 表达式的时候统一处理
                 if (originOnExpressions.size() > 1) {
                     Collection<Expression> onExpressions = new LinkedList<>();
@@ -457,7 +458,7 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
      *
      * @author 芋道源码
      */
-    private static final class ContextHolder {
+    static final class ContextHolder {
 
         /**
          * 该 {@link MappedStatement} 对应的规则

+ 0 - 44
yudao-framework/yudao-spring-boot-starter-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptorTest.java

@@ -1,44 +0,0 @@
-package cn.iocoder.yudao.framework.datapermission.core.interceptor;
-
-import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
-import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
-import org.junit.jupiter.api.Test;
-import org.mockito.InjectMocks;
-import org.mockito.Mock;
-
-import static org.assertj.core.api.Assertions.assertThat;
-
-public class DataPermissionInterceptorTest extends BaseMockitoUnitTest {
-
-    @InjectMocks
-    private DataPermissionInterceptor interceptor;
-
-    @Mock
-    private DataPermissionRuleFactory ruleFactory;
-
-    @Test
-    public void selectSingle() {
-        // 单表
-        assertSql("select * from entity where id = ?",
-                "SELECT * FROM entity WHERE id = ? AND tenant_id = 1");
-
-        assertSql("select * from entity where id = ? or name = ?",
-                "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
-
-        assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
-                "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
-
-        /* not */
-        assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
-                "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1");
-    }
-
-    private void assertSql(String sql, String targetSql) {
-        assertThat(interceptor.parserSingle(sql, null)).isEqualTo(targetSql);
-    }
-
-    public static void main(String[] args) {
-        System.out.println("123");
-    }
-
-}

+ 266 - 0
yudao-framework/yudao-spring-boot-starter-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptorTest2.java

@@ -0,0 +1,266 @@
+package cn.iocoder.yudao.framework.datapermission.core.interceptor;
+
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
+import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
+import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
+import net.sf.jsqlparser.expression.Alias;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
+import net.sf.jsqlparser.schema.Column;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+
+import java.util.Collections;
+import java.util.Set;
+
+import static cn.iocoder.yudao.framework.common.util.collection.SetUtils.asSet;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * {@link DataPermissionInterceptor} 的单元测试
+ * 主要复用了 MyBatis Plus 的 TenantLineInnerInterceptorTest 的单元测试
+ * 不过它的单元测试不是很规范,考虑到是复用的,所以暂时不进行修改~
+ *
+ * @author 芋道源码
+ */
+public class DataPermissionInterceptorTest2 extends BaseMockitoUnitTest {
+
+    @InjectMocks
+    private DataPermissionInterceptor interceptor;
+
+    @Mock
+    private DataPermissionRuleFactory ruleFactory;
+
+    @BeforeEach
+    public void setUp() {
+        // 租户的数据权限规则
+        DataPermissionRule tenantRule = new DataPermissionRule() {
+
+            private static final String COLUMN = "tenant_id";
+
+            @Override
+            public Set<String> getTableNames() {
+                return asSet("entity", "entity1", "entity2", "t1", "t2");
+            }
+
+            @Override
+            public Expression getExpression(String tableName, Alias tableAlias) {
+                Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
+                LongValue value = new LongValue(1L);
+                return new EqualsTo(column, value);
+            }
+
+        };
+        // 设置到上下文,保证
+        DataPermissionInterceptor.ContextHolder.init(Collections.singletonList(tenantRule));
+    }
+
+    @Test
+    void delete() {
+        assertSql("delete from entity where id = ?",
+                "DELETE FROM entity WHERE id = ? AND tenant_id = 1");
+    }
+
+    @Test
+    void update() {
+        assertSql("update entity set name = ? where id = ?",
+                "UPDATE entity SET name = ? WHERE id = ? AND tenant_id = 1");
+    }
+
+    @Test
+    void selectSingle() {
+        // 单表
+        assertSql("select * from entity where id = ?",
+                "SELECT * FROM entity WHERE id = ? AND tenant_id = 1");
+
+        assertSql("select * from entity where id = ? or name = ?",
+                "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
+
+        assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
+                "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
+
+        /* not */
+        assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
+                "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1");
+    }
+
+    @Test
+    void selectSubSelectIn() {
+        /* in */
+        assertSql("SELECT * FROM entity e WHERE e.id IN (select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE e.id IN (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+        // 在最前
+        assertSql("SELECT * FROM entity e WHERE e.id IN " +
+                        "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+                "SELECT * FROM entity e WHERE e.id IN " +
+                        "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+        // 在最后
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
+                        "(select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
+                        "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+        // 在中间
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
+                        "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+                "SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
+                        "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectSubSelectEq() {
+        /* = */
+        assertSql("SELECT * FROM entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectSubSelectInnerNotEq() {
+        /* inner not = */
+        assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?))",
+                "SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1)) AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?) and e.id = ?)",
+                "SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ?) AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectSubSelectExists() {
+        /* EXISTS */
+        assertSql("SELECT * FROM entity e WHERE EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+
+
+        /* NOT EXISTS */
+        assertSql("SELECT * FROM entity e WHERE NOT EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE NOT EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectSubSelect() {
+        /* >= */
+        assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+
+        /* <= */
+        assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+
+        /* <> */
+        assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)",
+                "SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectFromSelect() {
+        assertSql("SELECT * FROM (select e.id from entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?))",
+                "SELECT * FROM (SELECT e.id FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1)");
+    }
+
+    @Test
+    void selectBodySubSelect() {
+        assertSql("select t1.col1,(select t2.col2 from t2 t2 where t1.col1=t2.col1) from t1 t1",
+                "SELECT t1.col1, (SELECT t2.col2 FROM t2 t2 WHERE t1.col1 = t2.col1 AND t2.tenant_id = 1) FROM t1 t1 WHERE t1.tenant_id = 1");
+    }
+
+    @Test
+    void selectLeftJoin() {
+        // left join
+        assertSql("SELECT * FROM entity e " +
+                        "left join entity1 e1 on e1.id = e.id " +
+                        "WHERE e.id = ? OR e.name = ?",
+                "SELECT * FROM entity e " +
+                        "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e " +
+                        "left join entity1 e1 on e1.id = e.id " +
+                        "WHERE (e.id = ? OR e.name = ?)",
+                "SELECT * FROM entity e " +
+                        "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectRightJoin() {
+        // right join
+        assertSql("SELECT * FROM entity e " +
+                        "right join entity1 e1 on e1.id = e.id",
+                "SELECT * FROM entity e " +
+                        "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e " +
+                        "right join entity1 e1 on e1.id = e.id " +
+                        "WHERE e.id = ? OR e.name = ?",
+                "SELECT * FROM entity e " +
+                        "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectLeftJoinMultipleTrailingOn() {
+        // 多个 on 尾缀的
+        assertSql("SELECT * FROM entity e " +
+                        "LEFT JOIN entity1 e1 " +
+                        "LEFT JOIN entity2 e2 ON e2.id = e1.id " +
+                        "ON e1.id = e.id " +
+                        "WHERE (e.id = ? OR e.NAME = ?)",
+                "SELECT * FROM entity e " +
+                        "LEFT JOIN entity1 e1 " +
+                        "LEFT JOIN entity2 e2 ON e2.id = e1.id AND e2.tenant_id = 1 " +
+                        "ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e " +
+                        "LEFT JOIN entity1 e1 " +
+                        "LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
+                        "ON e1.id = e.id " +
+                        "WHERE (e.id = ? OR e.NAME = ?)",
+                "SELECT * FROM entity e " +
+                        "LEFT JOIN entity1 e1 " +
+                        "LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
+                        "ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
+    }
+
+    @Test
+    void selectInnerJoin() {
+        // inner join
+        assertSql("SELECT * FROM entity e " +
+                        "inner join entity1 e1 on e1.id = e.id " +
+                        "WHERE e.id = ? OR e.name = ?",
+                "SELECT * FROM entity e " +
+                        "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e " +
+                        "inner join entity1 e1 on e1.id = e.id " +
+                        "WHERE (e.id = ? OR e.name = ?)",
+                "SELECT * FROM entity e " +
+                        "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                        "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+
+        // 垃圾 inner join todo
+//        assertSql("SELECT * FROM entity,entity1 " +
+//                "WHERE entity.id = entity1.id",
+//            "SELECT * FROM entity e " +
+//                "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+//                "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+    }
+
+
+    @Test
+    void selectWithAs() {
+        assertSql("with with_as_A as (select * from entity) select * from with_as_A",
+                "WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A");
+    }
+
+    private void assertSql(String sql, String targetSql) {
+        assertEquals(targetSql, interceptor.parserSingle(sql, null));
+    }
+
+}

+ 14 - 0
yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java

@@ -7,6 +7,8 @@ import com.baomidou.mybatisplus.core.metadata.OrderItem;
 import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
 import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
+import net.sf.jsqlparser.expression.Alias;
+import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
 
 import java.util.ArrayList;
@@ -67,4 +69,16 @@ public class MyBatisUtils {
         return tableName;
     }
 
+    /**
+     * 构建 Column 对象
+     *
+     * @param tableName 表名
+     * @param tableAlias 别名
+     * @param column 字段名
+     * @return Column 对象
+     */
+    public static Column buildColumn(String tableName, Alias tableAlias, String column) {
+        return new Column(tableAlias != null ? tableAlias.getName() + "." + column : column);
+    }
+
 }