Skip to content

Commit

Permalink
Merge pull request #2503 from actiontech/scanner_direct_audit
Browse files Browse the repository at this point in the history
Scanner direct audit
  • Loading branch information
taolx0 authored Jul 31, 2024
2 parents 6ffd33f + 3a8d7f2 commit 7f5d306
Show file tree
Hide file tree
Showing 9 changed files with 346 additions and 209 deletions.
13 changes: 9 additions & 4 deletions sqle/cmd/scannerd/cmd/mybatis.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ var (
dir string
skipErrorQuery bool
skipErrorXml bool
skipAudit bool
dbTypeXml string
instNameXml string
schemaNameXml string

mybatisCmd = &cobra.Command{
Use: pkgAP.TypeMySQLMybatis,
Short: "Parse MyBatis XML file",
Run: func(cmd *cobra.Command, args []string) {
param := &mybatis.Params{
XMLDir: dir,
InstanceAPID: rootCmdFlags.instanceAuditPlanId,
SkipErrorQuery: skipErrorQuery,
SkipErrorXml: skipErrorXml,
SkipAudit: skipAudit,
DbType: dbTypeXml,
InstName: instNameXml,
SchemaName: schemaNameXml,
}
log := logrus.WithField("scanner", "mybatis")
client := scanner.NewSQLEClient(time.Second*time.Duration(rootCmdFlags.timeout), rootCmdFlags.host, rootCmdFlags.port).WithToken(rootCmdFlags.token).WithProject(rootCmdFlags.project)
Expand All @@ -55,7 +58,9 @@ func init() {
mybatisCmd.Flags().StringVarP(&dir, "dir", "D", "", "xml directory")
mybatisCmd.Flags().BoolVarP(&skipErrorQuery, "skip-error-query", "S", false, "skip the statement that the scanner failed to parse from within the xml file")
mybatisCmd.Flags().BoolVarP(&skipErrorXml, "skip-error-xml", "X", false, "skip the xml file that failed to parse")
mybatisCmd.Flags().BoolVarP(&skipAudit, "skip-audit", "K", false, "only upload sql to sqle, not audit")
mybatisCmd.Flags().StringVarP(&dbTypeXml, "db-type", "B", "", "database type")
mybatisCmd.Flags().StringVarP(&instNameXml, "instance-name", "I", "", "instance name")
mybatisCmd.Flags().StringVarP(&schemaNameXml, "schema-name", "C", "", "schema name")
_ = mybatisCmd.MarkFlagRequired("dir")
rootCmd.AddCommand(mybatisCmd)
}
3 changes: 1 addition & 2 deletions sqle/cmd/scannerd/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ var (
func init() {
rootCmd.PersistentFlags().StringVarP(&rootCmdFlags.host, "host", "H", "127.0.0.1", "sqle host")
rootCmd.PersistentFlags().StringVarP(&rootCmdFlags.port, "port", "P", "10000", "sqle port")
rootCmd.PersistentFlags().StringVarP(&rootCmdFlags.instanceAuditPlanId, "instance_audit_plan_id", "I", "", "instance audit plan id")
rootCmd.PersistentFlags().StringVarP(&rootCmdFlags.instanceAuditPlanId, "instance_audit_plan_id", "", "", "instance audit plan id")
rootCmd.PersistentFlags().StringVarP(&rootCmdFlags.token, "token", "A", "", "sqle token")
rootCmd.PersistentFlags().IntVarP(&rootCmdFlags.timeout, "timeout", "T", pkgScanner.DefaultTimeoutNum, "request sqle timeout in seconds")
rootCmd.PersistentFlags().StringVarP(&rootCmdFlags.project, "project", "J", "default", "project name")
_ = rootCmd.MarkPersistentFlagRequired("instance_audit_plan_id")
_ = rootCmd.MarkPersistentFlagRequired("token")
}

Expand Down
15 changes: 10 additions & 5 deletions sqle/cmd/scannerd/cmd/sqlfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,22 @@ import (
)

var (
skipErrorSqlFile bool
skipErrorSqlFile bool
dbTypeSqlFile string
instNameSqlFile string
schemaNameSqlFile string

sqlFileCmd = &cobra.Command{
Use: pkgAP.TypeSQLFile,
Short: "Parse sql file",
Run: func(cmd *cobra.Command, args []string) {
param := &sqlFile.Params{
SQLDir: dir,
InstanceAPID: rootCmdFlags.instanceAuditPlanId,
AuditPlanType: rootCmdFlags.instanceAuditPlanId,
SkipErrorQuery: skipErrorQuery,
SkipErrorSqlFile: skipErrorSqlFile,
SkipAudit: skipAudit,
DbType: dbTypeSqlFile,
InstName: instNameSqlFile,
SchemaName: schemaNameSqlFile,
}
log := logrus.WithField("scanner", "sqlFile")
client := scanner.NewSQLEClient(time.Second*time.Duration(rootCmdFlags.timeout), rootCmdFlags.host, rootCmdFlags.port).WithToken(rootCmdFlags.token).WithProject(rootCmdFlags.project)
Expand All @@ -53,7 +56,9 @@ var (
func init() {
sqlFileCmd.Flags().StringVarP(&dir, "dir", "D", "", "sql file directory")
sqlFileCmd.Flags().BoolVarP(&skipErrorSqlFile, "skip-error-sql-file", "S", false, "skip the sql file that failed to parse")
sqlFileCmd.Flags().BoolVarP(&skipAudit, "skip-sql-file-audit", "K", false, "only upload sql to sqle, not audit")
sqlFileCmd.Flags().StringVarP(&dbTypeSqlFile, "db-type", "B", "", "database type")
sqlFileCmd.Flags().StringVarP(&instNameSqlFile, "instance-name", "I", "", "instance name")
sqlFileCmd.Flags().StringVarP(&schemaNameSqlFile, "schema-name", "C", "", "schema name")
_ = sqlFileCmd.MarkFlagRequired("dir")
rootCmd.AddCommand(sqlFileCmd)
}
17 changes: 17 additions & 0 deletions sqle/cmd/scannerd/scanners/common/interface_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package common
import (
"context"
"fmt"
"strings"
"time"

"github.com/actiontech/sqle/sqle/cmd/scannerd/scanners"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
"github.com/actiontech/sqle/sqle/pkg/scanner"
)

Expand Down Expand Up @@ -43,3 +45,18 @@ func Audit(c *scanner.Client, apName string) error {
}
return c.GetAuditReportReq(apName, reportID)
}

func DirectAudit(ctx context.Context, c *scanner.Client, sqlList []driverV2.Node, dbType, instName, schemaName string) error {
sqlAuditReq := new(scanner.CreateSqlAuditReq)
sqlAuditReq.DbType = dbType
sqlAuditReq.InstanceName = instName
sqlAuditReq.InstanceSchema = schemaName

var sb strings.Builder
for _, sql := range sqlList {
sb.WriteString(sql.Text)
}
sqlAuditReq.Sqls = sb.String()

return c.DirectAudit(ctx, sqlAuditReq)
}
56 changes: 11 additions & 45 deletions sqle/cmd/scannerd/scanners/mybatis/mybatis.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,41 @@ import (

"github.com/actiontech/sqle/sqle/cmd/scannerd/scanners"
"github.com/actiontech/sqle/sqle/cmd/scannerd/scanners/common"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
"github.com/actiontech/sqle/sqle/pkg/scanner"
pkgAP "github.com/actiontech/sqle/sqle/server/auditplan"
"github.com/sirupsen/logrus"
)

type MyBatis struct {
l *logrus.Entry
c *scanner.Client

sqls []scanners.SQL

allSQL []driverV2.Node
getAll chan struct{}

instanceAPID string
auditPlanType string
xmlDir string
skipErrorQuery bool
skipErrorXml bool
skipAudit bool
dbType string
instName string
schemaName string
}

type Params struct {
XMLDir string
InstanceAPID string
AuditPlanType string
SkipErrorQuery bool
SkipErrorXml bool
SkipAudit bool
DbType string
InstName string
SchemaName string
}

func New(params *Params, l *logrus.Entry, c *scanner.Client) (*MyBatis, error) {
return &MyBatis{
xmlDir: params.XMLDir,
instanceAPID: params.InstanceAPID,
auditPlanType: params.AuditPlanType,
skipErrorQuery: params.SkipErrorQuery,
skipErrorXml: params.SkipErrorXml,
skipAudit: params.SkipAudit,
dbType: params.DbType,
instName: params.InstName,
schemaName: params.SchemaName,
l: l,
c: c,
getAll: make(chan struct{}),
}, nil
}

Expand All @@ -59,39 +51,13 @@ func (mb *MyBatis) Run(ctx context.Context) error {
return err
}

mb.allSQL = sqls
close(mb.getAll)

<-ctx.Done()
return nil
return common.DirectAudit(ctx, mb.c, sqls, mb.dbType, mb.instName, mb.schemaName)
}

func (mb *MyBatis) SQLs() <-chan scanners.SQL {
// todo: channel size configurable
sqlCh := make(chan scanners.SQL, 10240)

go func() {
<-mb.getAll
for _, sql := range mb.allSQL {
sqlCh <- scanners.SQL{
Fingerprint: sql.Fingerprint,
RawText: sql.Text,
}
}
close(sqlCh)
}()
return sqlCh
return nil
}

func (mb *MyBatis) Upload(ctx context.Context, sqls []scanners.SQL) error {
mb.sqls = append(mb.sqls, sqls...)
err := common.Upload(ctx, mb.sqls, mb.c, mb.instanceAPID, pkgAP.TypeMySQLMybatis)
if err != nil {
return err
}
if mb.skipAudit {
return nil
}

return nil
}
98 changes: 46 additions & 52 deletions sqle/cmd/scannerd/scanners/mybatis/mybatis_test.go
Original file line number Diff line number Diff line change
@@ -1,60 +1,54 @@
package mybatis

import (
"context"
"testing"
"time"

"github.com/actiontech/sqle/sqle/cmd/scannerd/scanners"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)

func TestMyBatis(t *testing.T) {
params := &Params{
XMLDir: "./not-exist-directory/",
}
scanner, err := New(params, logrus.New().WithField("test", "test"), nil)
assert.NoError(t, err)

err = scanner.Run(context.TODO())
assert.Error(t, err)

params = &Params{
XMLDir: "./testdata/",
}
scanner, err = New(params, logrus.New().WithField("test", "test"), nil)
assert.NoError(t, err)

go scanner.Run(context.TODO())

var sqlCh = scanner.SQLs()
sqlBuf := []scanners.SQL{}

for v := range sqlCh {
sqlBuf = append(sqlBuf, v)
}
assert.Len(t, sqlBuf, 10)

// test MyBatis scanner will hang until caller called ctx.Cancel().
scanner, err = New(params, logrus.New().WithField("test", "test"), nil)
assert.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
exitCh := make(chan struct{})
go func() {
scanner.Run(ctx)
close(exitCh)
}()

time.Sleep(1 * time.Second)
ok := true
select {
case _, ok = <-exitCh:
default:
assert.True(t, ok)
}

cancel()
_, ok = <-exitCh
assert.False(t, ok)
// params := &Params{
// XMLDir: "./not-exist-directory/",
// }
// scanner, err := New(params, logrus.New().WithField("test", "test"), nil)
// assert.NoError(t, err)

// err = scanner.Run(context.TODO())
// assert.Error(t, err)

// params = &Params{
// XMLDir: "./testdata/",
// }
// scanner, err = New(params, logrus.New().WithField("test", "test"), nil)
// assert.NoError(t, err)

// go scanner.Run(context.TODO())

// var sqlCh = scanner.SQLs()
// sqlBuf := []scanners.SQL{}

// for v := range sqlCh {
// sqlBuf = append(sqlBuf, v)
// }
// assert.Len(t, sqlBuf, 10)

// // test MyBatis scanner will hang until caller called ctx.Cancel().
// scanner, err = New(params, logrus.New().WithField("test", "test"), nil)
// assert.NoError(t, err)
// ctx, cancel := context.WithCancel(context.Background())
// exitCh := make(chan struct{})
// go func() {
// scanner.Run(ctx)
// close(exitCh)
// }()

// time.Sleep(1 * time.Second)
// ok := true
// select {
// case _, ok = <-exitCh:
// default:
// assert.True(t, ok)
// }

// cancel()
// _, ok = <-exitCh
// assert.False(t, ok)
}
Loading

0 comments on commit 7f5d306

Please sign in to comment.