/*
 *      Copyright (c) 2018-2028, Chill Zhuang All rights reserved.
 *
 *  Redistribution and use in source and binary forms, with or without
 *  modification, are permitted provided that the following conditions are met:
 *
 *  Redistributions of source code must retain the above copyright notice,
 *  this list of conditions and the following disclaimer.
 *  Redistributions in binary form must reproduce the above copyright
 *  notice, this list of conditions and the following disclaimer in the
 *  documentation and/or other materials provided with the distribution.
 *  Neither the name of the dreamlu.net developer nor the names of its
 *  contributors may be used to endorse or promote products derived from
 *  this software without specific prior written permission.
 *  Author: Chill 庄骞 (smallchill@163.com)
 */
package org.springblade.core.tenant;

import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.springblade.core.secure.utils.AuthUtil;
import org.springblade.core.tool.utils.CollectionUtil;
import org.springblade.core.tool.utils.StringPool;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Tenant interceptor.
 *
 * <p>Extends {@link TenantLineInnerInterceptor} so that statements issued by an
 * administrator (as reported by {@link AuthUtil#isAdministrator()}) are not restricted
 * by the tenant column, except for the tables listed in {@link #adminTenantTables}.
 *
 * @author Chill
 */
@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class BladeTenantInterceptor extends TenantLineInnerInterceptor {

    /**
     * Tenant handler. Kept in sync with the superclass copy by
     * {@link #setTenantLineHandler(TenantLineHandler)}.
     */
    private TenantLineHandler tenantLineHandler;

    /**
     * Tenant configuration properties.
     */
    private BladeTenantProperties tenantProperties;

    /**
     * Tables for which tenant filtering stays enabled even for an administrator.
     */
    private List<String> adminTenantTables = Arrays.asList("blade_top_menu", "blade_dict_biz");

    /**
     * Stores the handler both in the superclass and in this class's own field.
     *
     * @param tenantLineHandler the tenant handler to use
     */
    @Override
    public void setTenantLineHandler(TenantLineHandler tenantLineHandler) {
        super.setTenantLineHandler(tenantLineHandler);
        this.tenantLineHandler = tenantLineHandler;
    }

    /**
     * Appends the tenant column and its value to an INSERT statement.
     *
     * <p>When the "enhance" flag of {@link #tenantProperties} is not {@code true} the stock
     * MyBatis-Plus implementation is used instead.
     *
     * @param insert the parsed INSERT statement (modified in place)
     * @param index  statement index, forwarded to the superclass
     * @param sql    original SQL text, forwarded to the superclass
     * @param obj    extra argument, forwarded to the superclass
     */
    @Override
    protected void processInsert(Insert insert, int index, String sql, Object obj) {
        // Tenant enhancement disabled: fall back to the original logic.
        // Boolean.TRUE.equals(...) also tolerates an unset (null) flag should getEnhance()
        // return a boxed Boolean, instead of failing with a NullPointerException on unboxing.
        if (!Boolean.TRUE.equals(tenantProperties.getEnhance())) {
            super.processInsert(insert, index, sql, obj);
            return;
        }
        if (tenantLineHandler.ignoreTable(insert.getTable().getName())) {
            // Table is excluded from tenant handling
            return;
        }
        List<Column> columns = insert.getColumns();
        if (CollectionUtils.isEmpty(columns)) {
            // INSERT without an explicit column list is left untouched
            return;
        }
        String tenantIdColumn = tenantLineHandler.getTenantIdColumn();
        if (columns.stream().map(Column::getColumnName).anyMatch(i -> i.equals(tenantIdColumn))) {
            // INSERT that already names the tenant column is left untouched
            return;
        }
        columns.add(new Column(tenantIdColumn));

        // fixed gitee pulls/141 duplicate update
        List<Expression> duplicateUpdateColumns = insert.getDuplicateUpdateExpressionList();
        if (CollectionUtils.isNotEmpty(duplicateUpdateColumns)) {
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(new StringValue(tenantIdColumn));
            equalsTo.setRightExpression(tenantLineHandler.getTenantId());
            duplicateUpdateColumns.add(equalsTo);
        }

        Select select = insert.getSelect();
        if (select != null) {
            // INSERT ... SELECT
            this.processInsertSelect(select.getSelectBody());
        } else if (insert.getItemsList() != null) {
            // fixed github pull/295
            ItemsList itemsList = insert.getItemsList();
            if (itemsList instanceof MultiExpressionList) {
                // Multi-row VALUES: append the tenant id to every row
                ((MultiExpressionList) itemsList).getExpressionLists()
                    .forEach(el -> el.getExpressions().add(tenantLineHandler.getTenantId()));
            } else {
                ((ExpressionList) itemsList).getExpressions().add(tenantLineHandler.getTenantId());
            }
        } else {
            throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
        }
    }

    /**
     * Processes a {@link PlainSelect}: select items, sub-selects in WHERE, the FROM item and
     * the joins, then appends the tenant condition to the WHERE clause.
     *
     * @param plainSelect the SELECT to process (modified in place)
     */
    @Override
    protected void processPlainSelect(PlainSelect plainSelect) {
        // #3087 github
        List<SelectItem> selectItems = plainSelect.getSelectItems();
        if (CollectionUtils.isNotEmpty(selectItems)) {
            selectItems.forEach(this::processSelectItem);
        }

        // Sub-selects inside WHERE
        Expression where = plainSelect.getWhere();
        processWhereSubSelect(where);

        // FROM item
        FromItem fromItem = plainSelect.getFromItem();
        List<Table> list = processFromItem(fromItem);
        List<Table> mainTables = new ArrayList<>(list);

        // Joins
        List<Join> joins = plainSelect.getJoins();
        if (CollectionUtils.isNotEmpty(joins)) {
            mainTables = processJoins(mainTables, joins);
        }

        // With main tables present (and no administrator bypass), append the WHERE condition
        if (CollectionUtils.isNotEmpty(mainTables) && !doTenantFilters(mainTables)) {
            plainSelect.setWhere(builderExpression(where, mainTables));
        }
    }

    /**
     * Adds the tenant condition to the WHERE clause of an UPDATE statement, unless the
     * table is ignored by the handler or bypassed via {@link #doTenantFilter(String)}.
     *
     * @param update the parsed UPDATE statement (modified in place)
     * @param index  statement index (unused here)
     * @param sql    original SQL text (unused here)
     * @param obj    extra argument (unused here)
     */
    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        final Table table = update.getTable();
        if (tenantLineHandler.ignoreTable(table.getName())) {
            // Table is excluded from tenant handling
            return;
        }
        if (doTenantFilter(table.getName())) {
            // Administrator bypass
            return;
        }
        update.setWhere(this.andExpression(table, update.getWhere()));
    }

    /**
     * Adds the tenant condition to the WHERE clause of a DELETE statement, unless the
     * table is ignored by the handler or bypassed via {@link #doTenantFilter(String)}.
     *
     * @param delete the parsed DELETE statement (modified in place)
     * @param index  statement index (unused here)
     * @param sql    original SQL text (unused here)
     * @param obj    extra argument (unused here)
     */
    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        final Table table = delete.getTable();
        if (tenantLineHandler.ignoreTable(table.getName())) {
            // Table is excluded from tenant handling
            return;
        }
        if (doTenantFilter(table.getName())) {
            // Administrator bypass
            return;
        }
        delete.setWhere(this.andExpression(table, delete.getWhere()));
    }

    /**
     * Builds the WHERE expression used for DELETE and UPDATE statements.
     *
     * @param table the target table
     * @param where the existing WHERE expression, may be {@code null}
     * @return the tenant equality, AND-ed with {@code where} when it is present
     */
    @Override
    protected BinaryExpression andExpression(Table table, Expression where) {
        // Build the tenant equality
        EqualsTo equalsTo = new EqualsTo();
        Expression leftExpression = this.getAliasColumn(table);
        Expression rightExpression = tenantLineHandler.getTenantId();
        // Administrator bypass: use the same constant on both sides so the
        // condition is always true and nothing is filtered
        if (doTenantFilter(table.getName())) {
            leftExpression = rightExpression = new StringValue(StringPool.ONE);
        }
        equalsTo.setLeftExpression(leftExpression);
        equalsTo.setRightExpression(rightExpression);
        if (null != where) {
            if (where instanceof OrExpression) {
                // Parenthesise an OR so the AND does not change its meaning
                return new AndExpression(equalsTo, new Parenthesis(where));
            } else {
                return new AndExpression(equalsTo, where);
            }
        }
        return equalsTo;
    }

    /**
     * Builds the tenant condition for the given tables and combines it with the current
     * expression. Tables ignored by the handler, and tables bypassed for an administrator,
     * contribute no condition, which lets an administrator see the data of all tenants.
     *
     * @param currentExpression the existing expression, may be {@code null}
     * @param tables            the tables that need a tenant condition
     * @return the combined expression, or {@code currentExpression} unchanged when no
     *         table needs a condition
     */
    @Override
    protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
        // Nothing to process
        if (CollectionUtils.isEmpty(tables)) {
            return currentExpression;
        }
        // Tenant id
        Expression tenantId = tenantLineHandler.getTenantId();
        // One equality per table
        List<EqualsTo> equalsTos = tables.stream()
            // Tables ignored by the tenant handler
            .filter(x -> !tenantLineHandler.ignoreTable(x.getName()))
            // Tables bypassed for an administrator
            .filter(x -> !doTenantFilter(x.getName()))
            .map(item -> new EqualsTo(getAliasColumn(item), tenantId))
            .collect(Collectors.toList());

        if (CollectionUtils.isEmpty(equalsTos)) {
            return currentExpression;
        }

        // Expression to inject
        Expression injectExpression = equalsTos.get(0);
        // Several tables: chain the equalities with AND
        if (equalsTos.size() > 1) {
            for (int i = 1; i < equalsTos.size(); i++) {
                injectExpression = new AndExpression(injectExpression, equalsTos.get(i));
            }
        }

        if (currentExpression == null) {
            return injectExpression;
        }
        if (currentExpression instanceof OrExpression) {
            // Parenthesise an OR so the AND does not change its meaning
            return new AndExpression(new Parenthesis(currentExpression), injectExpression);
        } else {
            return new AndExpression(currentExpression, injectExpression);
        }
    }

    /**
     * Resolves the tables referenced by a FROM item.
     *
     * @param fromItem the FROM item of a select
     * @return the tables found; empty when the item is neither a table nor a sub-join
     */
    private List<Table> processFromItem(FromItem fromItem) {
        // Unwrap parenthesised FROM items
        while (fromItem instanceof ParenthesisFromItem) {
            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
        }

        List<Table> mainTables = new ArrayList<>();
        // No join involved
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            mainTables.add(fromTable);
        } else if (fromItem instanceof SubJoin) {
            // A SubJoin also needs the WHERE condition for its tables
            List<Table> tables = processSubJoin((SubJoin) fromItem);
            mainTables.addAll(tables);
        } else {
            // Any other kind of FROM item
            processOtherFromItem(fromItem);
        }
        return mainTables;
    }

    /**
     * Processes a sub-join.
     *
     * @param subJoin the sub-join
     * @return the main tables of the sub-join; empty when it has no join list
     */
    private List<Table> processSubJoin(SubJoin subJoin) {
        List<Table> mainTables = new ArrayList<>();
        if (subJoin.getJoinList() != null) {
            List<Table> list = processFromItem(subJoin.getLeft());
            mainTables.addAll(list);
            mainTables = processJoins(mainTables, subJoin.getJoinList());
        }
        return mainTables;
    }

    /**
     * Processes the joins of a select, injecting tenant conditions into their ON expressions.
     *
     * @param mainTables the tables resolved so far, may be {@code null}
     * @param joins      the joins to process
     * @return the main tables remaining after all joins were processed
     */
    private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
        // Final main table of the join expression
        Table mainTable = null;
        // Left table of the current join
        Table leftTable = null;
        if (mainTables == null) {
            mainTables = new ArrayList<>();
        } else if (mainTables.size() == 1) {
            mainTable = mainTables.get(0);
            leftTable = mainTable;
        }

        // For joins whose ON expressions are written at the end, remember the tables
        // of the preceding joins. LinkedList is required here: null entries are pushed.
        Deque<List<Table>> onTableDeque = new LinkedList<>();
        for (Join join : joins) {
            // Right-hand item of the join
            FromItem joinItem = join.getRightItem();

            // Tables of the current join; a SubJoin is treated like a single table
            List<Table> joinTables = null;
            if (joinItem instanceof Table) {
                joinTables = new ArrayList<>();
                joinTables.add((Table) joinItem);
            } else if (joinItem instanceof SubJoin) {
                joinTables = processSubJoin((SubJoin) joinItem);
            }

            if (joinTables != null) {

                // Implicit inner join: the tables are handled through the WHERE clause
                if (join.isSimple()) {
                    mainTables.addAll(joinTables);
                    continue;
                }

                // First table of the current join
                Table joinTable = joinTables.get(0);

                List<Table> onTables = null;
                // Right join: the joined table becomes the main table and the
                // condition for the left table goes into the ON expression
                if (join.isRight()) {
                    mainTable = joinTable;
                    if (leftTable != null) {
                        onTables = Collections.singletonList(leftTable);
                    }
                } else if (join.isLeft()) {
                    onTables = Collections.singletonList(joinTable);
                } else if (join.isInner()) {
                    if (mainTable == null) {
                        onTables = Collections.singletonList(joinTable);
                    } else {
                        onTables = Arrays.asList(mainTable, joinTable);
                    }
                    mainTable = null;
                }
                mainTables = new ArrayList<>();
                if (mainTable != null) {
                    mainTables.add(mainTable);
                }

                // ON expressions attached to this join
                Collection<Expression> originOnExpressions = join.getOnExpressions();
                // The usual case, a single ON expression: handle it right away
                if (originOnExpressions.size() == 1 && onTables != null) {
                    List<Expression> onExpressions = new LinkedList<>();
                    onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
                    join.setOnExpressions(onExpressions);
                    leftTable = joinTable;
                    continue;
                }
                // Push the tables; null is pushed when there is nothing to inject
                // so that the matching ON expression is left as is later
                onTableDeque.push(onTables);
                // Several trailing ON expressions: handle them together
                if (originOnExpressions.size() > 1) {
                    Collection<Expression> onExpressions = new LinkedList<>();
                    for (Expression originOnExpression : originOnExpressions) {
                        List<Table> currentTableList = onTableDeque.poll();
                        if (CollectionUtils.isEmpty(currentTableList)) {
                            onExpressions.add(originOnExpression);
                        } else {
                            onExpressions.add(builderExpression(originOnExpression, currentTableList));
                        }
                    }
                    join.setOnExpressions(onExpressions);
                }
                leftTable = joinTable;
            } else {
                processOtherFromItem(joinItem);
                leftTable = null;
            }
        }

        return mainTables;
    }

    /**
     * Tells whether tenant filtering is bypassed for the given table.
     *
     * @param tableName table name
     * @return {@code true} when the current user is an administrator and the table is not
     *         listed in {@link #adminTenantTables}; {@code true} means "do not filter"
     */
    public boolean doTenantFilter(String tableName) {
        return AuthUtil.isAdministrator() && !adminTenantTables.contains(tableName);
    }

    /**
     * Tells whether tenant filtering is bypassed for the given tables as a whole.
     *
     * @param tables tables of the statement
     * @return {@code true} when the current user is an administrator and none of the tables
     *         is listed in {@link #adminTenantTables}; {@code true} means "do not filter"
     */
    public boolean doTenantFilters(List<Table> tables) {
        List<String> tableNames = tables.stream().map(Table::getName).collect(Collectors.toList());
        return AuthUtil.isAdministrator() && !CollectionUtil.containsAny(adminTenantTables, tableNames);
    }

}