Skip to content

Commit

Permalink
解决表名或字段名包裹导致无法获取索引信息和索引字段校验问题.
Browse files Browse the repository at this point in the history
  • Loading branch information
nieqiurong committed Apr 4, 2024
1 parent 90faa3e commit a84e466
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 18 deletions.
1 change: 1 addition & 0 deletions mybatis-plus-extension/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ dependencies {
testImplementation "com.google.guava:guava:33.0.0-jre"
testImplementation "io.github.classgraph:classgraph:4.8.165"
testImplementation "${lib.h2}"
testImplementation "${lib.mysql}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import lombok.Data;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
Expand Down Expand Up @@ -214,23 +215,25 @@ private void validJoins(List<Join> joins, Table table, Connection connection) {
private void validUseIndex(Table table, String columnName, Connection connection) {
//是否使用索引
boolean useIndexFlag = false;

String tableInfo = table.getName();
//表存在的索引
String dbName = null;
String tableName;
String[] tableArray = tableInfo.split("\\.");
if (tableArray.length == 1) {
tableName = tableArray[0];
} else {
dbName = tableArray[0];
tableName = tableArray[1];
}
List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
for (IndexInfo indexInfo : indexInfos) {
if (null != columnName && columnName.equalsIgnoreCase(indexInfo.getColumnName())) {
useIndexFlag = true;
break;
if (StringUtils.isNotBlank(columnName)) {
String tableInfo = table.getName();
//表存在的索引
String dbName = null;
String tableName;
String[] tableArray = tableInfo.split("\\.");
if (tableArray.length == 1) {
tableName = tableArray[0];
} else {
dbName = tableArray[0];
tableName = tableArray[1];
}
columnName = SqlParserUtils.removeWrapperSymbol(columnName);
List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
for (IndexInfo indexInfo : indexInfos) {
if (indexInfo.getColumnName().equalsIgnoreCase(columnName)) {
useIndexFlag = true;
break;
}
}
}
if (!useIndexFlag) {
Expand Down Expand Up @@ -323,7 +326,7 @@ public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName
DatabaseMetaData metadata = conn.getMetaData();
String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
rs = metadata.getIndexInfo(catalog, schema, tableName, false, true);
rs = metadata.getIndexInfo(catalog, schema, SqlParserUtils.removeWrapperSymbol(tableName), false, true);
indexInfos = new ArrayList<>();
while (rs.next()) {
//索引中的列序列号等于1,才有效
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.baomidou.mybatisplus.extension.toolkit;

import com.baomidou.mybatisplus.core.toolkit.StringUtils;

/**
* SQL 解析工具类
*
Expand All @@ -32,4 +34,23 @@ public class SqlParserUtils {
public static String getOriginalCountSql(String originalSql) {
return String.format("SELECT COUNT(*) FROM (%s) TOTAL", originalSql);
}

/**
* 去除表或字段包裹符号
*
* @param tableOrColumn 表名或字段名
* @return str
* @since 3.5.6
*/
public static String removeWrapperSymbol(String tableOrColumn) {
if (StringUtils.isBlank(tableOrColumn)) {
return null;
}
if (tableOrColumn.startsWith("`") || tableOrColumn.startsWith("\"")
|| tableOrColumn.startsWith("[") || tableOrColumn.startsWith("<")) {
return tableOrColumn.substring(1, tableOrColumn.length() - 1);
}
return tableOrColumn;
}

}
Original file line number Diff line number Diff line change
@@ -1,16 +1,44 @@
package com.baomidou.mybatisplus.extension.plugins.inner;

import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
import com.mysql.cj.jdbc.MysqlDataSource;
import org.apache.ibatis.jdbc.SqlRunner;
import org.h2.jdbcx.JdbcDataSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.sql.Connection;
import java.sql.SQLException;

/**
* @author miemie
* @since 2022-04-11
*/
class IllegalSQLInnerInterceptorTest {

private final IllegalSQLInnerInterceptor interceptor = new IllegalSQLInnerInterceptor();
//
// 待研究为啥H2读不到索引信息
// private static Connection connection;
//
// @BeforeAll
// public static void beforeAll() throws SQLException {
// var jdbcDataSource = new JdbcDataSource();
// jdbcDataSource.setURL("jdbc:h2:mem:test;MODE=mysql;DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE");
// connection = jdbcDataSource.getConnection("sa","");
// var sql = """
// CREATE TABLE t_demo (
// `a` int DEFAULT NULL,
// `b` int DEFAULT NULL,
// `c` int DEFAULT NULL,
// KEY `ab_index` (`a`,`b`)
// );
// """;
// SqlRunner sqlRunner = new SqlRunner(connection);
// sqlRunner.run(sql);
// }

@Test
void test() {
Expand All @@ -20,6 +48,55 @@ void test() {
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete from t_user set age = 18", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age != 1", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age = 1 or name = 'test'", null));
// Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where a = 1 and b = 2", connection));
}

@Test
// @Disabled
void testMysql(){
/*
* CREATE TABLE `t_demo` (
`a` int DEFAULT NULL,
`b` int DEFAULT NULL,
`c` int DEFAULT NULL,
KEY `ab_index` (`a`,`b`)
);
CREATE TABLE `test` (
`a` int DEFAULT NULL,
`b` int DEFAULT NULL,
`c` int DEFAULT NULL,
KEY `ab_index` (`a`,`b`)
) ;
*/
var dataSource = new MysqlDataSource();
dataSource.setUrl("jdbc:mysql://127.0.0.1:3306/test?serverTimezone=Asia/Shanghai");
dataSource.setUser("root");
dataSource.setPassword("123456");

Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from t_demo where `a` = 1 and `b` = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from t_demo where a = 1 and `b` = 2", dataSource.getConnection()));

Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where `a` = 1 and `b` = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a = 1 and `b` = 2", dataSource.getConnection()));

Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where c = 3", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_demo where c = 3", dataSource.getConnection()));

Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.`t_demo` where c = 3", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.t_demo where c = 3", dataSource.getConnection()));

Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `t_demo` a INNER JOIN `test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `t_demo` a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));

Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM test.`t_demo` a INNER JOIN test.`test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM test.`t_demo` a INNER JOIN test.`test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));

Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM t_demo a INNER JOIN `test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM t_demo a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));

Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `t_demo` a LEFT JOIN `test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `t_demo` a LEFT JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.baomidou.mybatisplus.test;

import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class SqlParserUtilsTest {

@Test
void testRemoveWrapperSymbol() {
//用SQLServer的人喜欢写这种
Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("[Demo]"), "Demo");
Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("Demo"), "Demo");
//mysql比较常见
Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("`Demo`"), "Demo");
//用关键字表的
Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("\"Demo\""), "Demo");
//这种少
Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("<Demo>"), "Demo");
}

}

0 comments on commit a84e466

Please sign in to comment.