diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDmlExportController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDmlExportController.java index cbb263180..fa872990b 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDmlExportController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDmlExportController.java @@ -7,6 +7,7 @@ import java.time.LocalDateTime; import java.util.List; +import ai.chat2db.spi.model.Header; import com.alibaba.druid.DbType; import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.sql.SQLUtils.FormatOption; @@ -103,9 +104,9 @@ public void export(@Valid @RequestBody DataExportRequest request, HttpServletRes response.setCharacterEncoding("utf-8"); String fileName = URLEncoder.encode( - tableName + "_" + LocalDateTime.now().format(DatePattern.PURE_DATETIME_FORMATTER), - StandardCharsets.UTF_8) - .replaceAll("\\+", "%20"); + tableName + "_" + LocalDateTime.now().format(DatePattern.PURE_DATETIME_FORMATTER), + StandardCharsets.UTF_8) + .replaceAll("\\+", "%20"); if (exportType == ExportTypeEnum.CSV) { doExportCsv(sql, response, fileName); @@ -115,19 +116,19 @@ public void export(@Valid @RequestBody DataExportRequest request, HttpServletRes } private void doExportCsv(String sql, HttpServletResponse response, String fileName) - throws IOException { + throws IOException { response.setContentType("text/csv"); response.setHeader("Content-disposition", "attachment;filename*=utf-8''" + fileName + ".csv"); ExcelWrapper excelWrapper = new ExcelWrapper(); try { ExcelWriterBuilder excelWriterBuilder = EasyExcel.write(response.getOutputStream()) - .charset(StandardCharsets.UTF_8) - .excelType(ExcelTypeEnum.CSV); + .charset(StandardCharsets.UTF_8) + .excelType(ExcelTypeEnum.CSV); excelWrapper.setExcelWriterBuilder(excelWriterBuilder); SQLExecutor.getInstance().execute(Chat2DBContext.getConnection(), sql, headerList -> { excelWriterBuilder.head( - EasyCollectionUtils.toList(headerList, header -> Lists.newArrayList(header.getName()))); + EasyCollectionUtils.toList(headerList, header -> Lists.newArrayList(header.getName()))); excelWrapper.setExcelWriter(excelWriterBuilder.build()); excelWrapper.setWriteSheet(EasyExcel.writerSheet(0).build()); }, dataList -> { @@ -143,29 +144,27 @@ private void doExportCsv(String sql, HttpServletResponse response, String fileNa } private void doExportInsert(String sql, HttpServletResponse response, String fileName, DbType dbType, - String tableName) - throws IOException { + String tableName) + throws IOException { response.setContentType("text/sql"); response.setHeader("Content-disposition", "attachment;filename*=utf-8''" + fileName + ".sql"); try (PrintWriter printWriter = response.getWriter()) { InsertWrapper insertWrapper = new InsertWrapper(); SQLExecutor.getInstance().execute(Chat2DBContext.getConnection(), sql, - headerList -> insertWrapper.setHeaderList( - EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName()))) - , dataList -> { - SQLInsertStatement sqlInsertStatement = new SQLInsertStatement(); - sqlInsertStatement.setDbType(dbType); - sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName)); - sqlInsertStatement.getColumns().addAll(insertWrapper.getHeaderList()); - ValuesClause valuesClause = new ValuesClause(); - for (String s : dataList) { - valuesClause.addValue(s); - } - sqlInsertStatement.setValues(valuesClause); - - printWriter.println(SQLUtils.toSQLString(sqlInsertStatement, dbType, INSERT_FORMAT_OPTION) + ";"); - }, false); + headerList -> insertWrapper.setHeaderList(headerList) + , dataList -> { + SQLInsertStatement sqlInsertStatement = new SQLInsertStatement(); + sqlInsertStatement.setDbType(dbType); + sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName)); + List
headerList = insertWrapper.getHeaderList(); + sqlInsertStatement.getColumns().addAll(EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName()))); + ValuesClause valuesClause = SqlUtils.getValuesClause(dataList, headerList); + sqlInsertStatement.setValues(valuesClause); + + printWriter.println(SQLUtils.toSQLString(sqlInsertStatement, dbType, INSERT_FORMAT_OPTION) + ";"); + + }, false); } } @@ -174,7 +173,7 @@ private void doExportInsert(String sql, HttpServletResponse response, String fil @NoArgsConstructor @AllArgsConstructor public static class InsertWrapper { - private List headerList; + private List
headerList; } @Data diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/task/biz/TaskBizService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/task/biz/TaskBizService.java index d52e8cecf..c4297c74c 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/task/biz/TaskBizService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/task/biz/TaskBizService.java @@ -19,6 +19,7 @@ import ai.chat2db.server.web.api.controller.rdb.factory.ExportServiceFactory; import ai.chat2db.server.web.api.controller.rdb.request.DataExportRequest; import ai.chat2db.server.web.api.controller.rdb.vo.TableVO; +import ai.chat2db.spi.model.Header; import ai.chat2db.spi.model.Table; import ai.chat2db.spi.sql.Chat2DBContext; import ai.chat2db.spi.sql.ConnectInfo; @@ -273,20 +274,18 @@ private void doExportInsert(String sql, File file, DbType dbType, try (PrintWriter printWriter = new PrintWriter(file, StandardCharsets.UTF_8.name())) { RdbDmlExportController.InsertWrapper insertWrapper = new RdbDmlExportController.InsertWrapper(); SQLExecutor.getInstance().execute(Chat2DBContext.getConnection(), sql, - headerList -> insertWrapper.setHeaderList( - EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName()))) + headerList -> insertWrapper.setHeaderList(headerList) , dataList -> { SQLInsertStatement sqlInsertStatement = new SQLInsertStatement(); sqlInsertStatement.setDbType(dbType); sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName)); - sqlInsertStatement.getColumns().addAll(insertWrapper.getHeaderList()); - SQLInsertStatement.ValuesClause valuesClause = new SQLInsertStatement.ValuesClause(); - for (String s : dataList) { - valuesClause.addValue(s); - } + List
headerList = insertWrapper.getHeaderList(); + sqlInsertStatement.getColumns().addAll(EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName()))); + SQLInsertStatement.ValuesClause valuesClause = SqlUtils.getValuesClause(dataList, headerList); sqlInsertStatement.setValues(valuesClause); printWriter.println(SQLUtils.toSQLString(sqlInsertStatement, dbType, INSERT_FORMAT_OPTION) + ";"); + }, false); } } diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/jdbc/DefaultSqlBuilder.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/jdbc/DefaultSqlBuilder.java index 24f7b13c7..97c7aa0aa 100644 --- a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/jdbc/DefaultSqlBuilder.java +++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/jdbc/DefaultSqlBuilder.java @@ -1,11 +1,19 @@ package ai.chat2db.spi.jdbc; +import ai.chat2db.server.tools.common.util.EasyCollectionUtils; import ai.chat2db.spi.MetaData; import ai.chat2db.spi.SqlBuilder; import ai.chat2db.spi.enums.DmlType; import ai.chat2db.spi.model.*; import ai.chat2db.spi.sql.Chat2DBContext; +import ai.chat2db.spi.util.JdbcUtils; import ai.chat2db.spi.util.SqlUtils; +import com.alibaba.druid.DbType; +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.SQLExpr; +import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; +import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; +import com.alibaba.druid.sql.ast.statement.*; import com.google.common.collect.Lists; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.statement.Statement; @@ -49,6 +57,7 @@ public String pageLimit(String sql, int offset, int pageNo, int pageSize) { } public static String CREATE_DATABASE_SQL = "CREATE DATABASE IF NOT EXISTS `%s` DEFAULT CHARACTER SET %s COLLATE %s"; + private static final SQLUtils.FormatOption FORMAT_OPTION = new SQLUtils.FormatOption(true, false); @Override public String buildCreateDatabaseSql(Database database) { @@ -116,14 +125,14 @@ public String buildSqlByQuery(QueryResult queryResult) { String sql = ""; if ("UPDATE".equalsIgnoreCase(operation.getType())) { sql = getUpdateSql(tableName, headerList, row, odlRow, metaSchema, keyColumns, false); - if("MYSQL".equalsIgnoreCase(dbType)){ + if ("MYSQL".equalsIgnoreCase(dbType)) { sql = sql + " LIMIT 1"; } } else if ("CREATE".equalsIgnoreCase(operation.getType())) { sql = getInsertSql(tableName, headerList, row, metaSchema); } else if ("DELETE".equalsIgnoreCase(operation.getType())) { sql = getDeleteSql(tableName, headerList, odlRow, metaSchema, keyColumns); - if("MYSQL".equalsIgnoreCase(dbType)){ + if ("MYSQL".equalsIgnoreCase(dbType)) { sql = sql + " LIMIT 1"; } } else if ("UPDATE_COPY".equalsIgnoreCase(operation.getType())) { @@ -310,10 +319,12 @@ private List getPrimaryColumns(List
headerList) { private String getDeleteSql(String tableName, List
headerList, List row, MetaData metaSchema, List keyColumns) { - StringBuilder script = new StringBuilder(); - script.append("DELETE FROM ").append(tableName).append(""); - script.append(buildWhere(headerList, row, metaSchema, keyColumns)); - return script.toString(); + SQLDeleteStatement sqlDeleteStatement = new SQLDeleteStatement(); + sqlDeleteStatement.setTableSource(new SQLExprTableSource(tableName)); + sqlDeleteStatement.setWhere(buildWhereExpr(headerList, row, metaSchema, keyColumns)); + DbType dbType = JdbcUtils.parse2DruidDbType(Chat2DBContext.getConnectInfo().getDbType()); + String deleteSql = SQLUtils.toSQLString(sqlDeleteStatement, dbType, FORMAT_OPTION); + return deleteSql; } private String buildWhere(List
headerList, List row, MetaData metaSchema, List keyColumns) { @@ -357,34 +368,55 @@ private String buildWhere(List
headerList, List row, MetaData me return script.toString(); } + private SQLExpr buildWhereExpr(List
headerList, List row, MetaData metaSchema, List keyColumns) { + List conditions = new ArrayList<>(); + + if (CollectionUtils.isEmpty(keyColumns)) { + for (int i = 1; i < row.size(); i++) { + String oldValue = row.get(i); + Header header = headerList.get(i); + if (oldValue == null) { + conditions.add(SQLBinaryOpExpr.isNull(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())))); + } else { + conditions.add(SQLBinaryOpExpr.eq(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())), + SqlUtils.getSqlExpr(oldValue, header.getDataType()))); + } + } + } else { + for (int i = 1; i < row.size(); i++) { + String oldValue = row.get(i); + Header header = headerList.get(i); + String columnName = header.getName(); + if (keyColumns.contains(columnName)) { + if (oldValue == null) { + conditions.add(SQLBinaryOpExpr.isNull(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())))); + } else { + conditions.add(SQLBinaryOpExpr.eq(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())), + SqlUtils.getSqlExpr(oldValue, header.getDataType()))); + } + } + } + } + SQLExpr expr = null; + for (SQLBinaryOpExpr condition : conditions) { + expr = SQLBinaryOpExpr.and(expr, condition); + } + + return expr; + } + private String getInsertSql(String tableName, List
headerList, List row, MetaData metaSchema) { if (CollectionUtils.isEmpty(row) || ObjectUtils.allNull(row.toArray())) { return ""; } - StringBuilder script = new StringBuilder(); - script.append("INSERT INTO ").append(tableName) - .append(" ("); - for (int i = 1; i < row.size(); i++) { - Header header = headerList.get(i); - //String newValue = row.get(i); - //if (newValue != null) { - script.append(metaSchema.getMetaDataName(header.getName())) - .append(","); - // } - } - script.deleteCharAt(script.length() - 1); - script.append(") VALUES ("); - for (int i = 1; i < row.size(); i++) { - String newValue = row.get(i); - //if (newValue != null) { - Header header = headerList.get(i); - script.append(SqlUtils.getSqlValue(newValue, header.getDataType())) - .append(","); - //} - } - script.deleteCharAt(script.length() - 1); - script.append(")"); - return script.toString(); + DbType dbType = JdbcUtils.parse2DruidDbType(Chat2DBContext.getConnectInfo().getDbType()); + SQLInsertStatement sqlInsertStatement = new SQLInsertStatement(); + sqlInsertStatement.setDbType(dbType); + sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName)); + sqlInsertStatement.getColumns().addAll(EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())))); + SQLInsertStatement.ValuesClause valuesClause = SqlUtils.getValuesClause(row, headerList); + sqlInsertStatement.setValues(valuesClause); + return SQLUtils.toSQLString(sqlInsertStatement, dbType, FORMAT_OPTION); } @@ -395,6 +427,10 @@ private String getUpdateSql(String tableName, List
headerList, List headerList, List row, List
headerList) { + SQLInsertStatement.ValuesClause valuesClause = new SQLInsertStatement.ValuesClause(); + for (int i = 0; i < row.size(); i++) { + String s = row.get(i); + Header header = headerList.get(i); + valuesClause.addValue(getSqlExpr(s, header.getDataType())); + } + return valuesClause; + } + + public static SQLValuableExpr getSqlExpr(String value, String dataType) { + if (value == null) { + return new SQLNullExpr(); + } else if (DataTypeEnum.getByCode(dataType).equals(DataTypeEnum.DATETIME)) { + return new SQLTimestampExpr(value); + } else { + return new SQLCharExpr(value); + } + + + } + private static SQLTableSource getSQLExprTableSource(SQLTableSource sqlTableSource) { if (sqlTableSource instanceof SQLExprTableSource sqlExprTableSource) { return sqlExprTableSource; @@ -140,8 +164,8 @@ public static List parse(String sql, DbType dbType) { List sqls = sqlSplitter.split(sql); return sqls.stream().map(splitSqlString -> SQLParserUtils.removeComment(splitSqlString.getStr(), dbType)).collect(Collectors.toList()); } - }catch (Exception e){ - log.error("sqlSplitter error",e); + } catch (Exception e) { + log.error("sqlSplitter error", e); } try { if (DbType.mysql.equals(dbType) || @@ -152,8 +176,8 @@ public static List parse(String sql, DbType dbType) { sqlSplitProcessor.setDelimiter(";"); return split(sqlSplitProcessor, sql, dbType); } - }catch (Exception e){ - log.error("sqlSplitProcessor error",e); + } catch (Exception e) { + log.error("sqlSplitProcessor error", e); } // sql = removeDelimiter(sql); if (StringUtils.isBlank(sql)) { @@ -246,7 +270,7 @@ public static String getSqlValue(String value, String dataType) { if (value == null) { return null; } - if("".equals(value)){ + if ("".equals(value)) { return "''"; } if (DEFAULT_VALUE.equals(value)) {