tables = processSubJoin((SubJoin) fromItem);
- mainTables.addAll(tables);
- } else {
- // 处理下 fromItem
- processOtherFromItem(fromItem);
- }
- return mainTables;
- }
-
- /**
- * 处理where条件内的子查询
- *
- * 支持如下:
- * 1. in
- * 2. =
- * 3. >
- * 4. <
- * 5. >=
- * 6. <=
- * 7. <>
- * 8. EXISTS
- * 9. NOT EXISTS
- *
- * 前提条件:
- * 1. 子查询必须放在小括号中
- * 2. 子查询一般放在比较操作符的右边
- *
- * @param where where 条件
- */
- protected void processWhereSubSelect(Expression where) {
- if (where == null) {
- return;
- }
- if (where instanceof FromItem) {
- processOtherFromItem((FromItem) where);
- return;
- }
- if (where.toString().indexOf("SELECT") > 0) {
- // 有子查询
- if (where instanceof BinaryExpression) {
- // 比较符号 , and , or , 等等
- BinaryExpression expression = (BinaryExpression) where;
- processWhereSubSelect(expression.getLeftExpression());
- processWhereSubSelect(expression.getRightExpression());
- } else if (where instanceof InExpression) {
- // in
- InExpression expression = (InExpression) where;
- Expression inExpression = expression.getRightExpression();
- if (inExpression instanceof SubSelect) {
- processSelectBody(((SubSelect) inExpression).getSelectBody());
- }
- } else if (where instanceof ExistsExpression) {
- // exists
- ExistsExpression expression = (ExistsExpression) where;
- processWhereSubSelect(expression.getRightExpression());
- } else if (where instanceof NotExpression) {
- // not exists
- NotExpression expression = (NotExpression) where;
- processWhereSubSelect(expression.getExpression());
- } else if (where instanceof Parenthesis) {
- Parenthesis expression = (Parenthesis) where;
- processWhereSubSelect(expression.getExpression());
- }
- }
- }
-
- protected void processSelectItem(SelectItem selectItem) {
- if (selectItem instanceof SelectExpressionItem) {
- SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
- if (selectExpressionItem.getExpression() instanceof SubSelect) {
- processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
- } else if (selectExpressionItem.getExpression() instanceof Function) {
- processFunction((Function) selectExpressionItem.getExpression());
- }
- }
- }
-
- /**
- * 处理函数
- *
支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)
- *
fixed gitee pulls/141
- *
- * @param function
- */
- protected void processFunction(Function function) {
- ExpressionList parameters = function.getParameters();
- if (parameters != null) {
- parameters.getExpressions().forEach(expression -> {
- if (expression instanceof SubSelect) {
- processSelectBody(((SubSelect) expression).getSelectBody());
- } else if (expression instanceof Function) {
- processFunction((Function) expression);
- }
- });
- }
- }
-
- /**
- * 处理子查询等
- */
- protected void processOtherFromItem(FromItem fromItem) {
- // 去除括号
- while (fromItem instanceof ParenthesisFromItem) {
- fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
- }
-
- if (fromItem instanceof SubSelect) {
- SubSelect subSelect = (SubSelect) fromItem;
- if (subSelect.getSelectBody() != null) {
- processSelectBody(subSelect.getSelectBody());
- }
- } else if (fromItem instanceof ValuesList) {
- logger.debug("Perform a subQuery, if you do not give us feedback");
- } else if (fromItem instanceof LateralSubSelect) {
- LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
- if (lateralSubSelect.getSubSelect() != null) {
- SubSelect subSelect = lateralSubSelect.getSubSelect();
- if (subSelect.getSelectBody() != null) {
- processSelectBody(subSelect.getSelectBody());
- }
- }
- }
- }
-
- /**
- * 处理 sub join
- *
- * @param subJoin subJoin
- * @return Table subJoin 中的主表
- */
- private List processSubJoin(SubJoin subJoin) {
- List mainTables = new ArrayList<>();
- if (subJoin.getJoinList() != null) {
- List list = processFromItem(subJoin.getLeft());
- mainTables.addAll(list);
- mainTables = processJoins(mainTables, subJoin.getJoinList());
- }
- return mainTables;
- }
-
- /**
- * 处理 joins
- *
- * @param mainTables 可以为 null
- * @param joins join 集合
- * @return List 右连接查询的 Table 列表
- */
- private List processJoins(List mainTables, List joins) {
- // join 表达式中最终的主表
- Table mainTable = null;
- // 当前 join 的左表
- Table leftTable = null;
-
- if (mainTables == null) {
- mainTables = new ArrayList<>();
- } else if (mainTables.size() == 1) {
- mainTable = mainTables.get(0);
- leftTable = mainTable;
- }
-
- //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
- Deque> onTableDeque = new LinkedList<>();
- for (Join join : joins) {
- // 处理 on 表达式
- FromItem joinItem = join.getRightItem();
-
- // 获取当前 join 的表,subJoint 可以看作是一张表
- List joinTables = null;
- if (joinItem instanceof Table) {
- joinTables = new ArrayList<>();
- joinTables.add((Table) joinItem);
- } else if (joinItem instanceof SubJoin) {
- joinTables = processSubJoin((SubJoin) joinItem);
- }
-
- if (joinTables != null) {
-
- // 如果是隐式内连接
- if (join.isSimple()) {
- mainTables.addAll(joinTables);
- continue;
- }
-
- // 当前表是否忽略
- Table joinTable = joinTables.get(0);
-
- List onTables = null;
- // 如果不要忽略,且是右连接,则记录下当前表
- if (join.isRight()) {
- mainTable = joinTable;
- if (leftTable != null) {
- onTables = Collections.singletonList(leftTable);
- }
- } else if (join.isLeft()) {
- onTables = Collections.singletonList(joinTable);
- } else if (join.isInner()) {
- if (mainTable == null) {
- onTables = Collections.singletonList(joinTable);
- } else {
- onTables = Arrays.asList(mainTable, joinTable);
- }
- mainTable = null;
- }
-
- mainTables = new ArrayList<>();
- if (mainTable != null) {
- mainTables.add(mainTable);
- }
-
- // 获取 join 尾缀的 on 表达式列表
- Collection originOnExpressions = join.getOnExpressions();
- // 正常 join on 表达式只有一个,立刻处理
- if (originOnExpressions.size() == 1 && onTables != null) {
- List onExpressions = new LinkedList<>();
- onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
- join.setOnExpressions(onExpressions);
- leftTable = joinTable;
- continue;
- }
- // 表名压栈,忽略的表压入 null,以便后续不处理
- onTableDeque.push(onTables);
- // 尾缀多个 on 表达式的时候统一处理
- if (originOnExpressions.size() > 1) {
- Collection onExpressions = new LinkedList<>();
- for (Expression originOnExpression : originOnExpressions) {
- List currentTableList = onTableDeque.poll();
- if (CollectionUtils.isEmpty(currentTableList)) {
- onExpressions.add(originOnExpression);
- } else {
- onExpressions.add(builderExpression(originOnExpression, currentTableList));
- }
- }
- join.setOnExpressions(onExpressions);
- }
- leftTable = joinTable;
- } else {
- processOtherFromItem(joinItem);
- leftTable = null;
- }
- }
-
- return mainTables;
- }
-
- // ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========
-
- /**
- * 处理条件
- *
- * @param currentExpression 当前 where 条件
- * @param table 单个表
- */
- protected Expression builderExpression(Expression currentExpression, Table table) {
- return this.builderExpression(currentExpression, Collections.singletonList(table));
- }
-
- /**
- * 处理条件
- *
- * @param currentExpression 当前 where 条件
- * @param tables 多个表
- */
- protected Expression builderExpression(Expression currentExpression, List tables) {
- // 没有表需要处理直接返回
- if (CollectionUtils.isEmpty(tables)) {
- return currentExpression;
- }
-
- // 第一步,获得 Table 对应的数据权限条件
- Expression dataPermissionExpression = null;
- for (Table table : tables) {
- // 构建每个表的权限 Expression 条件
- Expression expression = buildDataPermissionExpression(table);
- if (expression == null) {
- continue;
- }
- // 合并到 dataPermissionExpression 中
- dataPermissionExpression = dataPermissionExpression == null ? expression
- : new AndExpression(dataPermissionExpression, expression);
- }
-
- // 第二步,合并多个 Expression 条件
- if (dataPermissionExpression == null) {
- return currentExpression;
- }
- if (currentExpression == null) {
- return dataPermissionExpression;
- }
- // ① 如果表达式为 Or,则需要 (currentExpression) AND dataPermissionExpression
- if (currentExpression instanceof OrExpression) {
- return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
- }
- // ② 如果表达式为 And,则直接返回 where AND dataPermissionExpression
- return new AndExpression(currentExpression, dataPermissionExpression);
- }
-
- /**
- * 构建指定表的数据权限的 Expression 过滤条件
- *
- * @param table 表
- * @return Expression 过滤条件
- */
- private Expression buildDataPermissionExpression(Table table) {
- // 生成条件
- Expression allExpression = null;
- for (DataPermissionRule rule : ContextHolder.getRules()) {
- // 判断表名是否匹配
- String tableName = MyBatisUtils.getTableName(table);
- if (!rule.getTableNames().contains(tableName)) {
- continue;
- }
- // 如果有匹配的规则,说明可重写。
- // 为什么不是有 allExpression 非空才重写呢?在生成 column = value 过滤条件时,会因为 value 不存在,导致未重写。
- // 这样导致第一次无 value,被标记成无需重写;但是第二次有 value,此时会需要重写。
- ContextHolder.setRewrite(true);
-
- // 单条规则的条件
- Expression oneExpress = rule.getExpression(tableName, table.getAlias());
- if (oneExpress == null){
- continue;
- }
- // 拼接到 allExpression 中
- allExpression = allExpression == null ? oneExpress
- : new AndExpression(allExpression, oneExpress);
- }
-
- return allExpression;
- }
-
- /**
- * 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
- *
- * @param ms MappedStatement
- */
- private void addMappedStatementCache(MappedStatement ms) {
- if (ContextHolder.getRewrite()) {
- return;
- }
- // 无重写,进行添加
- mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
- }
-
- /**
- * SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
- *
- * @author 芋道源码
- */
- static final class ContextHolder {
-
- /**
- * 该 {@link MappedStatement} 对应的规则
- */
- private static final ThreadLocal> RULES = ThreadLocal.withInitial(Collections::emptyList);
- /**
- * SQL 是否进行重写
- */
- private static final ThreadLocal REWRITE = ThreadLocal.withInitial(() -> Boolean.FALSE);
-
- public static void init(List rules) {
- RULES.set(rules);
- REWRITE.set(false);
- }
-
- public static void clear() {
- RULES.remove();
- REWRITE.remove();
- }
-
- public static boolean getRewrite() {
- return REWRITE.get();
- }
-
- public static void setRewrite(boolean rewrite) {
- REWRITE.set(rewrite);
- }
-
- public static List getRules() {
- return RULES.get();
- }
-
- }
-
- /**
- * {@link MappedStatement} 缓存
- * 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
- * 如果无效,则可以避免 SQL 的解析,加快速度
- *
- * @author 芋道源码
- */
- static final class MappedStatementCache {
-
- /**
- * 指定数据权限规则,对指定 MappedStatement 无需重写(不生效)的缓存
- *
- * value:{@link MappedStatement#getId()} 编号
- */
- @Getter
- private final Map, Set> noRewritableMappedStatements = new ConcurrentHashMap<>();
-
- /**
- * 判断是否无需重写
- * ps:虽然有点中文式英语,但是容易读懂即可
- *
- * @param ms MappedStatement
- * @param rules 数据权限规则数组
- * @return 是否无需重写
- */
- public boolean noRewritable(MappedStatement ms, List rules) {
- // 如果规则为空,说明无需重写
- if (CollUtil.isEmpty(rules)) {
- return true;
- }
- // 任一规则不在 noRewritableMap 中,则说明可能需要重写
- for (DataPermissionRule rule : rules) {
- Set mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
- if (!CollUtil.contains(mappedStatementIds, ms.getId())) {
- return false;
- }
- }
- return true;
- }
-
- /**
- * 添加无需重写的 MappedStatement
- *
- * @param ms MappedStatement
- * @param rules 数据权限规则数组
- */
- public void addNoRewritable(MappedStatement ms, List rules) {
- for (DataPermissionRule rule : rules) {
- Set mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
- if (CollUtil.isNotEmpty(mappedStatementIds)) {
- mappedStatementIds.add(ms.getId());
- } else {
- noRewritableMappedStatements.put(rule.getClass(), SetUtils.asSet(ms.getId()));
- }
- }
- }
-
- /**
- * 清空缓存
- * 目前主要提供给单元测试
- */
- public void clear() {
- noRewritableMappedStatements.clear();
- }
-
- }
-
-}
diff --git a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandler.java b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandler.java
new file mode 100644
index 000000000..a2778734b
--- /dev/null
+++ b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandler.java
@@ -0,0 +1,57 @@
+package cn.iocoder.yudao.framework.datapermission.core.db;
+
+import cn.hutool.core.collection.CollUtil;
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
+import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
+import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
+import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
+import lombok.RequiredArgsConstructor;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
+import net.sf.jsqlparser.schema.Table;
+
+import java.util.List;
+
+/**
+ * 基于 {@link DataPermissionRule} 的数据权限处理器
+ *
+ * 它的底层,是基于 MyBatis Plus 的 数据权限插件
+ * 核心原理:它会在 SQL 执行前拦截 SQL 语句,并根据用户权限动态添加权限相关的 SQL 片段。这样,只有用户有权限访问的数据才会被查询出来
+ *
+ * @author 芋道源码
+ */
+@RequiredArgsConstructor
+public class DataPermissionRuleHandler implements MultiDataPermissionHandler {
+
+ private final DataPermissionRuleFactory ruleFactory;
+
+ @Override
+ public Expression getSqlSegment(Table table, Expression where, String mappedStatementId) {
+ // 获得 Mapper 对应的数据权限的规则
+ List rules = ruleFactory.getDataPermissionRule(mappedStatementId);
+ if (CollUtil.isEmpty(rules)) {
+ return null;
+ }
+
+ // 生成条件
+ Expression allExpression = null;
+ for (DataPermissionRule rule : rules) {
+ // 判断表名是否匹配
+ String tableName = MyBatisUtils.getTableName(table);
+ if (!rule.getTableNames().contains(tableName)) {
+ continue;
+ }
+
+ // 单条规则的条件
+ Expression oneExpress = rule.getExpression(tableName, table.getAlias());
+ if (oneExpress == null) {
+ continue;
+ }
+ // 拼接到 allExpression 中
+ allExpression = allExpression == null ? oneExpress
+ : new AndExpression(allExpression, oneExpress);
+ }
+ return allExpression;
+ }
+
+}
diff --git a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/rule/dept/DeptDataPermissionRule.java b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/rule/dept/DeptDataPermissionRule.java
index d6041d387..af1a5a6fb 100644
--- a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/rule/dept/DeptDataPermissionRule.java
+++ b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/rule/dept/DeptDataPermissionRule.java
@@ -119,7 +119,7 @@ public class DeptDataPermissionRule implements DataPermissionRule {
// 情况二,即不能查看部门,又不能查看自己,则说明 100% 无权限
if (CollUtil.isEmpty(deptDataPermission.getDeptIds())
- && Boolean.FALSE.equals(deptDataPermission.getSelf())) {
+ && Boolean.FALSE.equals(deptDataPermission.getSelf())) {
return new EqualsTo(null, null); // WHERE null = null,可以保证返回的数据为空
}
@@ -156,7 +156,8 @@ public class DeptDataPermissionRule implements DataPermissionRule {
}
// 拼接条件
return new InExpression(MyBatisUtils.buildColumn(tableName, tableAlias, columnName),
- new ExpressionList(CollectionUtils.convertList(deptIds, LongValue::new)));
+ // Parenthesis 的目的,是提供 (1,2,3) 的 () 左右括号
+ new Parenthesis(new ExpressionList<>(CollectionUtils.convertList(deptIds, LongValue::new))));
}
private Expression buildUserExpression(String tableName, Alias tableAlias, Boolean self, Long userId) {
@@ -180,7 +181,7 @@ public class DeptDataPermissionRule implements DataPermissionRule {
public void addDeptColumn(Class extends BaseDO> entityClass, String columnName) {
String tableName = TableInfoHelper.getTableInfo(entityClass).getTableName();
- addDeptColumn(tableName, columnName);
+ addDeptColumn(tableName, columnName);
}
public void addDeptColumn(String tableName, String columnName) {
diff --git a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest.java b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest.java
deleted file mode 100644
index 145360789..000000000
--- a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest.java
+++ /dev/null
@@ -1,190 +0,0 @@
-package cn.iocoder.yudao.framework.datapermission.core.db;
-
-import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
-import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
-import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
-import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
-import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
-import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
-import net.sf.jsqlparser.expression.Alias;
-import net.sf.jsqlparser.expression.Expression;
-import net.sf.jsqlparser.expression.LongValue;
-import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
-import net.sf.jsqlparser.schema.Column;
-import org.apache.ibatis.executor.Executor;
-import org.apache.ibatis.executor.statement.StatementHandler;
-import org.apache.ibatis.mapping.BoundSql;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.mockito.InjectMocks;
-import org.mockito.Mock;
-import org.mockito.MockedStatic;
-
-import java.sql.Connection;
-import java.util.*;
-
-import static java.util.Collections.singletonList;
-import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.Mockito.*;
-
-/**
- * {@link DataPermissionDatabaseInterceptor} 的单元测试
- * 主要测试 {@link DataPermissionDatabaseInterceptor#beforePrepare(StatementHandler, Connection, Integer)}
- * 和 {@link DataPermissionDatabaseInterceptor#beforeUpdate(Executor, MappedStatement, Object)}
- * 以及在这个过程中,ContextHolder 和 MappedStatementCache
- *
- * @author 芋道源码
- */
-public class DataPermissionDatabaseInterceptorTest extends BaseMockitoUnitTest {
-
- @InjectMocks
- private DataPermissionDatabaseInterceptor interceptor;
-
- @Mock
- private DataPermissionRuleFactory ruleFactory;
-
- @BeforeEach
- public void setUp() {
- // 清理上下文
- DataPermissionDatabaseInterceptor.ContextHolder.clear();
- // 清空缓存
- interceptor.getMappedStatementCache().clear();
- }
-
- @Test // 不存在规则,且不匹配
- public void testBeforeQuery_withoutRule() {
- try (MockedStatic pluginUtilsMock = mockStatic(PluginUtils.class)) {
- // 准备参数
- MappedStatement mappedStatement = mock(MappedStatement.class);
- BoundSql boundSql = mock(BoundSql.class);
-
- // 调用
- interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
- // 断言
- pluginUtilsMock.verify(() -> PluginUtils.mpBoundSql(boundSql), never());
- }
- }
-
- @Test // 存在规则,且不匹配
- public void testBeforeQuery_withMatchRule() {
- try (MockedStatic pluginUtilsMock = mockStatic(PluginUtils.class)) {
- // 准备参数
- MappedStatement mappedStatement = mock(MappedStatement.class);
- BoundSql boundSql = mock(BoundSql.class);
- // mock 方法(数据权限)
- when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
- .thenReturn(singletonList(new DeptDataPermissionRule()));
- // mock 方法(MPBoundSql)
- PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
- pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
- // mock 方法(SQL)
- String sql = "select * from t_user where id = 1";
- when(mpBs.sql()).thenReturn(sql);
- // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
-
- // 调用
- interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
- // 断言
- verify(mpBs, times(1)).sql(
- eq("SELECT * FROM t_user WHERE id = 1 AND t_user.dept_id = 100"));
- // 断言缓存
- assertTrue(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
- }
- }
-
- @Test // 存在规则,但不匹配
- public void testBeforeQuery_withoutMatchRule() {
- try (MockedStatic pluginUtilsMock = mockStatic(PluginUtils.class)) {
- // 准备参数
- MappedStatement mappedStatement = mock(MappedStatement.class);
- BoundSql boundSql = mock(BoundSql.class);
- // mock 方法(数据权限)
- when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
- .thenReturn(singletonList(new DeptDataPermissionRule()));
- // mock 方法(MPBoundSql)
- PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
- pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
- // mock 方法(SQL)
- String sql = "select * from t_role where id = 1";
- when(mpBs.sql()).thenReturn(sql);
- // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
-
- // 调用
- interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
- // 断言
- verify(mpBs, times(1)).sql(
- eq("SELECT * FROM t_role WHERE id = 1"));
- // 断言缓存
- assertFalse(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
- }
- }
-
- @Test
- public void testAddNoRewritable() {
- // 准备参数
- MappedStatement ms = mock(MappedStatement.class);
- List rules = singletonList(new DeptDataPermissionRule());
- // mock 方法
- when(ms.getId()).thenReturn("selectById");
-
- // 调用
- interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
- // 断言
- Map, Set> noRewritableMappedStatements =
- interceptor.getMappedStatementCache().getNoRewritableMappedStatements();
- assertEquals(1, noRewritableMappedStatements.size());
- assertEquals(SetUtils.asSet("selectById"), noRewritableMappedStatements.get(DeptDataPermissionRule.class));
- }
-
- @Test
- public void testNoRewritable() {
- // 准备参数
- MappedStatement ms = mock(MappedStatement.class);
- // mock 方法
- when(ms.getId()).thenReturn("selectById");
- // mock 数据
- List rules = singletonList(new DeptDataPermissionRule());
- interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
-
- // 场景一,rules 为空
- assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, null));
- // 场景二,rules 非空,可重写
- assertFalse(interceptor.getMappedStatementCache().noRewritable(ms, singletonList(new EmptyDataPermissionRule())));
- // 场景三,rule 非空,不可重写
- assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, rules));
- }
-
- private static class DeptDataPermissionRule implements DataPermissionRule {
-
- private static final String COLUMN = "dept_id";
-
- @Override
- public Set getTableNames() {
- return SetUtils.asSet("t_user");
- }
-
- @Override
- public Expression getExpression(String tableName, Alias tableAlias) {
- Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
- LongValue value = new LongValue(100L);
- return new EqualsTo(column, value);
- }
-
- }
-
- private static class EmptyDataPermissionRule implements DataPermissionRule {
-
- @Override
- public Set getTableNames() {
- return Collections.emptySet();
- }
-
- @Override
- public Expression getExpression(String tableName, Alias tableAlias) {
- return null;
- }
-
- }
-
-}
diff --git a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandlerTest.java
similarity index 96%
rename from yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java
rename to yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandlerTest.java
index b8cad13cf..0b4ba791a 100644
--- a/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionDatabaseInterceptorTest2.java
+++ b/yudao-framework/yudao-spring-boot-starter-biz-data-permission/src/test/java/cn/iocoder/yudao/framework/datapermission/core/db/DataPermissionRuleHandlerTest.java
@@ -4,9 +4,11 @@ import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
+import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
@@ -21,24 +23,30 @@ import java.util.Set;
import static cn.iocoder.yudao.framework.common.util.collection.SetUtils.asSet;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
/**
- * {@link DataPermissionDatabaseInterceptor} 的单元测试
+ * {@link DataPermissionRuleHandler} 的单元测试
* 主要复用了 MyBatis Plus 的 TenantLineInnerInterceptorTest 的单元测试
* 不过它的单元测试不是很规范,考虑到是复用的,所以暂时不进行修改~
*
* @author 芋道源码
*/
-public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest {
+public class DataPermissionRuleHandlerTest extends BaseMockitoUnitTest {
@InjectMocks
- private DataPermissionDatabaseInterceptor interceptor;
+ private DataPermissionRuleHandler handler;
@Mock
private DataPermissionRuleFactory ruleFactory;
+ private DataPermissionInterceptor interceptor;
+
@BeforeEach
public void setUp() {
+ interceptor = new DataPermissionInterceptor(handler);
+
// 租户的数据权限规则
DataPermissionRule tenantRule = new DataPermissionRule() {
@@ -71,14 +79,14 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
- ExpressionList values = new ExpressionList(new LongValue(10L),
+ ExpressionList values = new ExpressionList<>(new LongValue(10L),
new LongValue(20L));
- return new InExpression(column, values);
+ return new InExpression(column, new Parenthesis((values)));
}
};
- // 设置到上下文,保证
- DataPermissionDatabaseInterceptor.ContextHolder.init(Arrays.asList(tenantRule, deptRule));
+ // 设置到上下文
+ when(ruleFactory.getDataPermissionRule(any())).thenReturn(Arrays.asList(tenantRule, deptRule));
}
@Test
@@ -262,7 +270,7 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
"right join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
- "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
+ "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
@@ -447,7 +455,6 @@ public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest
assertEquals(targetSql, interceptor.parserSingle(sql, null));
}
-
// ========== 额外的测试 ==========
@Test
diff --git a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/mapper/BaseMapperX.java b/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/mapper/BaseMapperX.java
index e7767c6f1..8dca318de 100644
--- a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/mapper/BaseMapperX.java
+++ b/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/mapper/BaseMapperX.java
@@ -185,8 +185,8 @@ public interface BaseMapperX extends MPJBaseMapper {
return Db.updateBatchById(entities, size);
}
- default Boolean insertOrUpdate(T entity) {
- return Db.saveOrUpdate(entity);
+ default boolean insertOrUpdate(T entity) {
+ return Db.saveOrUpdate(entity);
}
default Boolean insertOrUpdateBatch(Collection collection) {
diff --git a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/type/JsonLongSetTypeHandler.java b/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/type/JsonLongSetTypeHandler.java
deleted file mode 100644
index 052c7232e..000000000
--- a/yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/type/JsonLongSetTypeHandler.java
+++ /dev/null
@@ -1,31 +0,0 @@
-package cn.iocoder.yudao.framework.mybatis.core.type;
-
-import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
-import com.baomidou.mybatisplus.extension.handlers.AbstractJsonTypeHandler;
-import com.fasterxml.jackson.core.type.TypeReference;
-
-import java.util.Set;
-
-/**
- * 参考 {@link com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler} 实现
- * 在我们将字符串反序列化为 Set 并且泛型为 Long 时,如果每个元素的数值太小,会被处理成 Integer 类型,导致可能存在隐性的 BUG。
- *
- * 例如说哦,SysUserDO 的 postIds 属性
- *
- * @author 芋道源码
- */
-public class JsonLongSetTypeHandler extends AbstractJsonTypeHandler