Răsfoiți Sursa

初始化数据权限模块

YunaiV 3 ani în urmă
părinte
comite
814c2db65c

+ 1 - 1
yudao-dependencies/pom.xml

@@ -24,7 +24,7 @@
         <!-- DB 相关 -->
         <mysql.version>5.1.46</mysql.version>
         <druid.version>1.2.4</druid.version>
-        <mybatis-plus.version>3.4.2</mybatis-plus.version>
+        <mybatis-plus.version>3.4.3.4</mybatis-plus.version>
         <dynamic-datasource.version>3.3.2</dynamic-datasource.version>
         <redisson.version>3.16.3</redisson.version>
         <!-- Config 配置中心相关 -->

+ 1 - 1
yudao-framework/pom.xml

@@ -33,7 +33,7 @@
         <module>yudao-spring-boot-starter-biz-weixin</module>
         <module>yudao-spring-boot-starter-extension</module>
         <module>yudao-spring-boot-starter-tenant</module>
-        <module>yudao-spring-boot-starter-datascope</module>
+        <module>yudao-spring-boot-starter-data-permission</module>
     </modules>
 
     <artifactId>yudao-framework</artifactId>

+ 1 - 29
yudao-framework/yudao-spring-boot-starter-datascope/pom.xml → yudao-framework/yudao-spring-boot-starter-data-permission/pom.xml

@@ -8,7 +8,7 @@
         <version>${revision}</version>
     </parent>
     <modelVersion>4.0.0</modelVersion>
-    <artifactId>yudao-spring-boot-starter-datascope</artifactId>
+    <artifactId>yudao-spring-boot-starter-data-permission</artifactId>
     <packaging>jar</packaging>
 
     <name>${artifactId}</name>
@@ -21,40 +21,12 @@
             <artifactId>yudao-common</artifactId>
         </dependency>
 
-        <!-- Web 相关 -->
-        <dependency>
-            <groupId>org.springframework.boot</groupId>
-            <artifactId>spring-boot-starter-web</artifactId>
-        </dependency>
-
-        <dependency>
-            <groupId>cn.iocoder.boot</groupId>
-            <artifactId>yudao-spring-boot-starter-security</artifactId>
-        </dependency>
-
         <!-- DB 相关 -->
         <dependency>
             <groupId>cn.iocoder.boot</groupId>
             <artifactId>yudao-spring-boot-starter-mybatis</artifactId>
         </dependency>
 
-        <dependency>
-            <groupId>cn.iocoder.boot</groupId>
-            <artifactId>yudao-spring-boot-starter-redis</artifactId>
-        </dependency>
-
-        <!-- Job 定时任务相关 -->
-        <dependency>
-            <groupId>cn.iocoder.boot</groupId>
-            <artifactId>yudao-spring-boot-starter-job</artifactId>
-        </dependency>
-
-        <!-- 消息队列相关 -->
-        <dependency>
-            <groupId>cn.iocoder.boot</groupId>
-            <artifactId>yudao-spring-boot-starter-mq</artifactId>
-        </dependency>
-
         <!-- Test 测试相关 -->
         <dependency>
             <groupId>org.springframework.boot</groupId>

+ 400 - 0
yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datascope/core/interceptor/DataPermissionInterceptor.java

@@ -0,0 +1,400 @@
+package cn.iocoder.yudao.framework.datascope.core.interceptor;
+
+import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
+import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
+import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
+import net.sf.jsqlparser.expression.*;
+import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
+import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
+import net.sf.jsqlparser.expression.operators.relational.*;
+import net.sf.jsqlparser.schema.Column;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.delete.Delete;
+import net.sf.jsqlparser.statement.select.*;
+import net.sf.jsqlparser.statement.update.Update;
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.executor.statement.StatementHandler;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
+import org.apache.ibatis.session.ResultHandler;
+import org.apache.ibatis.session.RowBounds;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.util.Collection;
+import java.util.Deque;
+import java.util.LinkedList;
+import java.util.List;
+
+public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {
+
+//    private TenantLineHandler tenantLineHandler;
+
+    @Override
+    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
+        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
+        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
+        mpBs.sql(parserSingle(mpBs.sql(), null));
+    }
+
+    @Override
+    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
+        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
+        MappedStatement ms = mpSh.mappedStatement();
+        SqlCommandType sct = ms.getSqlCommandType();
+        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { // 无需处理 Insert 语句
+            if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
+            PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
+            mpBs.sql(parserMulti(mpBs.sql(), null));
+        }
+    }
+
+    @Override
+    protected void processSelect(Select select, int index, String sql, Object obj) {
+        processSelectBody(select.getSelectBody());
+        List<WithItem> withItemsList = select.getWithItemsList();
+        if (!CollectionUtils.isEmpty(withItemsList)) {
+            withItemsList.forEach(this::processSelectBody);
+        }
+    }
+
+    protected void processSelectBody(SelectBody selectBody) {
+        if (selectBody == null) {
+            return;
+        }
+        if (selectBody instanceof PlainSelect) {
+            processPlainSelect((PlainSelect) selectBody);
+        } else if (selectBody instanceof WithItem) {
+            WithItem withItem = (WithItem) selectBody;
+            processSelectBody(withItem.getSubSelect().getSelectBody());
+        } else {
+            SetOperationList operationList = (SetOperationList) selectBody;
+            List<SelectBody> selectBodys = operationList.getSelects();
+            if (CollectionUtils.isNotEmpty(selectBodys)) {
+                selectBodys.forEach(this::processSelectBody);
+            }
+        }
+    }
+
+    /**
+     * update 语句处理
+     */
+    @Override
+    protected void processUpdate(Update update, int index, String sql, Object obj) {
+        final Table table = update.getTable();
+        if (ignoreTable(table.getName())) {
+            // 过滤退出执行
+            return;
+        }
+        update.setWhere(this.andExpression(table, update.getWhere()));
+    }
+
+    /**
+     * delete 语句处理
+     */
+    @Override
+    protected void processDelete(Delete delete, int index, String sql, Object obj) {
+        if (ignoreTable(delete.getTable().getName())) {
+            // 过滤退出执行
+            return;
+        }
+        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
+    }
+
+    /**
+     * delete update 语句 where 处理
+     */
+    protected BinaryExpression andExpression(Table table, Expression where) {
+        //获得where条件表达式
+        EqualsTo equalsTo = new EqualsTo();
+        equalsTo.setLeftExpression(this.getAliasColumn(table));
+        equalsTo.setRightExpression(getTenantId());
+        if (null != where) {
+            if (where instanceof OrExpression) {
+                return new AndExpression(equalsTo, new Parenthesis(where));
+            } else {
+                return new AndExpression(equalsTo, where);
+            }
+        }
+        return equalsTo;
+    }
+
+    /**
+     * 追加 SelectItem
+     *
+     * @param selectItems SelectItem
+     */
+    protected void appendSelectItem(List<SelectItem> selectItems) {
+        if (CollectionUtils.isEmpty(selectItems)) return;
+        if (selectItems.size() == 1) {
+            SelectItem item = selectItems.get(0);
+            if (item instanceof AllColumns || item instanceof AllTableColumns) return;
+        }
+        selectItems.add(new SelectExpressionItem(new Column(getTenantIdColumn())));
+    }
+
+    /**
+     * 处理 PlainSelect
+     */
+    protected void processPlainSelect(PlainSelect plainSelect) {
+        FromItem fromItem = plainSelect.getFromItem();
+        Expression where = plainSelect.getWhere();
+        processWhereSubSelect(where);
+        if (fromItem instanceof Table) {
+            Table fromTable = (Table) fromItem;
+            if (!ignoreTable(fromTable.getName())) {
+                //#1186 github
+                plainSelect.setWhere(builderExpression(where, fromTable));
+            }
+        } else {
+            processFromItem(fromItem);
+        }
+        //#3087 github
+        List<SelectItem> selectItems = plainSelect.getSelectItems();
+        if (CollectionUtils.isNotEmpty(selectItems)) {
+            selectItems.forEach(this::processSelectItem);
+        }
+        List<Join> joins = plainSelect.getJoins();
+        if (CollectionUtils.isNotEmpty(joins)) {
+            processJoins(joins);
+        }
+    }
+
+    /**
+     * 处理where条件内的子查询
+     * <p>
+     * 支持如下:
+     * 1. in
+     * 2. =
+     * 3. >
+     * 4. <
+     * 5. >=
+     * 6. <=
+     * 7. <>
+     * 8. EXISTS
+     * 9. NOT EXISTS
+     * <p>
+     * 前提条件:
+     * 1. 子查询必须放在小括号中
+     * 2. 子查询一般放在比较操作符的右边
+     *
+     * @param where where 条件
+     */
+    protected void processWhereSubSelect(Expression where) {
+        if (where == null) {
+            return;
+        }
+        if (where instanceof FromItem) {
+            processFromItem((FromItem) where);
+            return;
+        }
+        if (where.toString().indexOf("SELECT") > 0) {
+            // 有子查询
+            if (where instanceof BinaryExpression) {
+                // 比较符号 , and , or , 等等
+                BinaryExpression expression = (BinaryExpression) where;
+                processWhereSubSelect(expression.getLeftExpression());
+                processWhereSubSelect(expression.getRightExpression());
+            } else if (where instanceof InExpression) {
+                // in
+                InExpression expression = (InExpression) where;
+                ItemsList itemsList = expression.getRightItemsList();
+                if (itemsList instanceof SubSelect) {
+                    processSelectBody(((SubSelect) itemsList).getSelectBody());
+                }
+            } else if (where instanceof ExistsExpression) {
+                // exists
+                ExistsExpression expression = (ExistsExpression) where;
+                processWhereSubSelect(expression.getRightExpression());
+            } else if (where instanceof NotExpression) {
+                // not exists
+                NotExpression expression = (NotExpression) where;
+                processWhereSubSelect(expression.getExpression());
+            } else if (where instanceof Parenthesis) {
+                Parenthesis expression = (Parenthesis) where;
+                processWhereSubSelect(expression.getExpression());
+            }
+        }
+    }
+
+    protected void processSelectItem(SelectItem selectItem) {
+        if (selectItem instanceof SelectExpressionItem) {
+            SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
+            if (selectExpressionItem.getExpression() instanceof SubSelect) {
+                processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
+            } else if (selectExpressionItem.getExpression() instanceof Function) {
+                processFunction((Function) selectExpressionItem.getExpression());
+            }
+        }
+    }
+
+    /**
+     * 处理函数
+     * <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
+     * <p> fixed gitee pulls/141</p>
+     *
+     * @param function
+     */
+    protected void processFunction(Function function) {
+        ExpressionList parameters = function.getParameters();
+        if (parameters != null) {
+            parameters.getExpressions().forEach(expression -> {
+                if (expression instanceof SubSelect) {
+                    processSelectBody(((SubSelect) expression).getSelectBody());
+                } else if (expression instanceof Function) {
+                    processFunction((Function) expression);
+                }
+            });
+        }
+    }
+
+    /**
+     * 处理子查询等
+     */
+    protected void processFromItem(FromItem fromItem) {
+        if (fromItem instanceof SubJoin) {
+            SubJoin subJoin = (SubJoin) fromItem;
+            if (subJoin.getJoinList() != null) {
+                processJoins(subJoin.getJoinList());
+            }
+            if (subJoin.getLeft() != null) {
+                processFromItem(subJoin.getLeft());
+            }
+        } else if (fromItem instanceof SubSelect) {
+            SubSelect subSelect = (SubSelect) fromItem;
+            if (subSelect.getSelectBody() != null) {
+                processSelectBody(subSelect.getSelectBody());
+            }
+        } else if (fromItem instanceof ValuesList) {
+            logger.debug("Perform a subquery, if you do not give us feedback");
+        } else if (fromItem instanceof LateralSubSelect) {
+            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
+            if (lateralSubSelect.getSubSelect() != null) {
+                SubSelect subSelect = lateralSubSelect.getSubSelect();
+                if (subSelect.getSelectBody() != null) {
+                    processSelectBody(subSelect.getSelectBody());
+                }
+            }
+        }
+    }
+
+    /**
+     * 处理 joins
+     *
+     * @param joins join 集合
+     */
+    private void processJoins(List<Join> joins) {
+        //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
+        Deque<Table> tables = new LinkedList<>();
+        for (Join join : joins) {
+            // 处理 on 表达式
+            FromItem fromItem = join.getRightItem();
+            if (fromItem instanceof Table) {
+                Table fromTable = (Table) fromItem;
+                // 获取 join 尾缀的 on 表达式列表
+                Collection<Expression> originOnExpressions = join.getOnExpressions();
+                // 正常 join on 表达式只有一个,立刻处理
+                if (originOnExpressions.size() == 1) {
+                    processJoin(join);
+                    continue;
+                }
+                // 当前表是否忽略
+                boolean needIgnore = ignoreTable(fromTable.getName());
+                // 表名压栈,忽略的表压入 null,以便后续不处理
+                tables.push(needIgnore ? null : fromTable);
+                // 尾缀多个 on 表达式的时候统一处理
+                if (originOnExpressions.size() > 1) {
+                    Collection<Expression> onExpressions = new LinkedList<>();
+                    for (Expression originOnExpression : originOnExpressions) {
+                        Table currentTable = tables.poll();
+                        if (currentTable == null) {
+                            onExpressions.add(originOnExpression);
+                        } else {
+                            onExpressions.add(builderExpression(originOnExpression, currentTable));
+                        }
+                    }
+                    join.setOnExpressions(onExpressions);
+                }
+            } else {
+                // 处理右边连接的子表达式
+                processFromItem(fromItem);
+            }
+        }
+    }
+
+    /**
+     * 处理联接语句
+     */
+    protected void processJoin(Join join) {
+        if (join.getRightItem() instanceof Table) {
+            Table fromTable = (Table) join.getRightItem();
+            if (ignoreTable(fromTable.getName())) {
+                // 过滤退出执行
+                return;
+            }
+            // 走到这里说明 on 表达式肯定只有一个
+            Collection<Expression> originOnExpressions = join.getOnExpressions();
+            List<Expression> onExpressions = new LinkedList<>();
+            onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable));
+            join.setOnExpressions(onExpressions);
+        }
+    }
+
+    /**
+     * 处理条件
+     */
+    protected Expression builderExpression(Expression currentExpression, Table table) {
+        EqualsTo equalsTo = new EqualsTo();
+        equalsTo.setLeftExpression(this.getAliasColumn(table));
+        equalsTo.setRightExpression(getTenantId());
+        if (currentExpression == null) {
+            return equalsTo;
+        }
+        if (currentExpression instanceof OrExpression) {
+            return new AndExpression(new Parenthesis(currentExpression), equalsTo);
+        } else {
+            return new AndExpression(currentExpression, equalsTo);
+        }
+    }
+
+    /**
+     * 租户字段别名设置
+     * <p>tenantId 或 tableAlias.tenantId</p>
+     *
+     * @param table 表对象
+     * @return 字段
+     */
+    protected Column getAliasColumn(Table table) {
+        StringBuilder column = new StringBuilder();
+        if (table.getAlias() != null) {
+            column.append(table.getAlias().getName()).append(StringPool.DOT);
+        }
+        column.append(getTenantIdColumn());
+        return new Column(column.toString());
+    }
+
+//    @Override
+//    public void setProperties(Properties properties) {
+//        PropertyMapper.newInstance(properties).whenNotBlank("tenantLineHandler",
+//                ClassUtils::newInstance, this::setTenantLineHandler);
+//    }
+
+    // TODO 芋艿:未实现
+
+    private boolean ignoreTable(String tableName) {
+        return false;
+    }
+
+    private String getTenantIdColumn() {
+        return "dept_id";
+    }
+
+    private Expression getTenantId() {
+        return new LongValue(1L);
+    }
+
+}

+ 1 - 0
yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datascope/core/package-info.java

@@ -0,0 +1 @@
+package cn.iocoder.yudao.framework.datascope.core;

+ 0 - 0
yudao-framework/yudao-spring-boot-starter-datascope/src/main/java/cn/iocoder/yudao/framework/datascope/package-info.java → yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datascope/package-info.java


+ 32 - 0
yudao-framework/yudao-spring-boot-starter-data-permission/src/test/java/cn/iocoder/yudao/framework/datascope/core/interceptor/DataPermissionInterceptorTest.java

@@ -0,0 +1,32 @@
+package cn.iocoder.yudao.framework.datascope.core.interceptor;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class DataPermissionInterceptorTest {
+
+    private final DataPermissionInterceptor interceptor = new DataPermissionInterceptor();
+
+    @Test
+    void selectSingle() {
+        // 单表
+        assertSql("select * from entity where id = ?",
+                "SELECT * FROM entity WHERE id = ? AND tenant_id = 1");
+
+        assertSql("select * from entity where id = ? or name = ?",
+                "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
+
+        assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
+                "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
+
+        /* not */
+        assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
+                "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1");
+    }
+
+    private void assertSql(String sql, String targetSql) {
+        assertThat(interceptor.parserSingle(sql, null)).isEqualTo(targetSql);
+    }
+
+}

+ 1 - 1
yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/mapper/BaseMapperX.java

@@ -33,7 +33,7 @@ public interface BaseMapperX<T> extends BaseMapper<T> {
     }
 
     default Integer selectCount(String field, Object value) {
-        return selectCount(new QueryWrapper<T>().eq(field, value));
+        return selectCount(new QueryWrapper<T>().eq(field, value)).intValue();
     }
 
     default List<T> selectList() {