diff --git a/sqle/pkg/driver/adaptor.go b/sqle/pkg/driver/adaptor.go index 5f8835283b..994a3ea00b 100644 --- a/sqle/pkg/driver/adaptor.go +++ b/sqle/pkg/driver/adaptor.go @@ -23,58 +23,44 @@ type Adaptor struct { dt Dialector - rules map[*driver.Rule]rawSQLRuleHandler - rulesWithSQLparser map[*driver.Rule]astSQLRuleHandler + rules []*driver.Rule + ruleToRawHandler map[string] /*rule name*/ rawSQLRuleHandler + ruleToASTHandler map[string] /*rule name*/ astSQLRuleHandler ao *adaptorOptions } type adaptorOptions struct { - dsn string - showDatabaseSQL string - - dsnMaker func(*driver.DSN) string sqlParser func(string) (interface{}, error) } -func newAdaptorOptions(d Dialector, dsn *driver.DSN, opts ...AdaptorOption) *adaptorOptions { - ao := &adaptorOptions{} - - _, ao.dsn = d.Dialect(dsn) - ao.showDatabaseSQL = d.ShowDatabaseSQL() - - for _, opt := range opts { - opt.apply(ao) - } - if ao.dsnMaker != nil { - ao.dsn = ao.dsnMaker(dsn) - } - return ao -} - type rawSQLRuleHandler func(ctx context.Context, rule *driver.Rule, rawSQL string) (string, error) type astSQLRuleHandler func(ctx context.Context, rule *driver.Rule, astSQL interface{}) (string, error) // NewAdaptor create a database plugin Adaptor with dialector. func NewAdaptor(dt Dialector) *Adaptor { return &Adaptor{ + ao: &adaptorOptions{}, + dt: dt, l: hclog.New(&hclog.LoggerOptions{ JSONFormat: true, Output: os.Stderr, Level: hclog.Trace, }), - rules: make(map[*driver.Rule]rawSQLRuleHandler), - rulesWithSQLparser: make(map[*driver.Rule]astSQLRuleHandler), + ruleToRawHandler: make(map[string]rawSQLRuleHandler), + ruleToASTHandler: make(map[string]astSQLRuleHandler), } } func (a *Adaptor) AddRule(r *driver.Rule, h rawSQLRuleHandler) { - a.rules[r] = h + a.rules = append(a.rules, r) + a.ruleToRawHandler[r.Name] = h } func (a *Adaptor) AddRuleWithSQLParser(r *driver.Rule, h astSQLRuleHandler) { - a.rulesWithSQLparser[r] = h + a.rules = append(a.rules, r) + a.ruleToASTHandler[r.Name] = h } func (a *Adaptor) Serve(opts ...AdaptorOption) { @@ -84,26 +70,25 @@ func (a *Adaptor) Serve(opts ...AdaptorOption) { } }() - rules := make([]*driver.Rule, 0, len(a.rules)) - for rule := range a.rules { - rules = append(rules, rule) - } - for rule := range a.rulesWithSQLparser { - rules = append(rules, rule) + for _, opt := range opts { + opt.apply(a.ao) } - if len(rules) == 0 { + if len(a.rules) == 0 { a.l.Info("no rule in plugin adaptor", "name", a.dt) } + if len(a.ruleToASTHandler) != 0 && a.ao.sqlParser == nil { + panic("Add rule by AddRuleWithSQLParser(), but no SQL parser provided.") + } + r := ®istererImpl{ dt: a.dt, - rules: rules, + rules: a.rules, } newDriver := func(cfg *driver.Config) driver.Driver { a.cfg = cfg - a.ao = newAdaptorOptions(a.dt, cfg.DSN, opts...) di := &driverImpl{a: a} @@ -111,8 +96,8 @@ func (a *Adaptor) Serve(opts ...AdaptorOption) { return di } - driverName, _ := a.dt.Dialect(cfg.DSN) - db, err := sql.Open(driverName, a.ao.dsn) + driverName, dsnDetail := a.dt.Dialect(cfg.DSN) + db, err := sql.Open(driverName, dsnDetail) if err != nil { panic(errors.Wrap(err, "open database failed when new driver")) } @@ -191,7 +176,6 @@ func (d *driverImpl) Close(ctx context.Context) { if err := d.db.Close(); err != nil { d.a.l.Error("failed to close database in driver adaptor", "err", err) } - return } func (d *driverImpl) Ping(ctx context.Context) error { @@ -248,7 +232,7 @@ func (d *driverImpl) Tx(ctx context.Context, sqls ...string) ([]_driver.Result, } func (d *driverImpl) Schemas(ctx context.Context) ([]string, error) { - rows, err := d.conn.QueryContext(ctx, d.a.ao.showDatabaseSQL) + rows, err := d.conn.QueryContext(ctx, d.a.dt.ShowDatabaseSQL()) if err != nil { return nil, errors.Wrap(err, "query database in driver adaptor") } @@ -302,29 +286,33 @@ func classifySQL(sql string) (sqlType string) { } func (d *driverImpl) Audit(ctx context.Context, sql string) (*driver.AuditResult, error) { - result := driver.NewInspectResults() + var err error + var ast interface{} + if d.a.ao.sqlParser != nil { + ast, err = d.a.ao.sqlParser(sql) + if err != nil { + return nil, errors.Wrap(err, "parse sql") + } + } - if d.a.ao.sqlParser == nil { - for r, h := range d.a.rules { - msg, err := h(ctx, r, sql) + result := driver.NewInspectResults() + for _, rule := range d.a.cfg.Rules { + handler, ok := d.a.ruleToRawHandler[rule.Name] + if ok { + msg, err := handler(ctx, rule, sql) if err != nil { return nil, errors.Wrapf(err, "audit SQL %s in driver adaptor", sql) } - - result.Add(r.Level, msg) - } - } else { - ast, err := d.a.ao.sqlParser(sql) - if err != nil { - return nil, errors.Wrapf(err, "parse SQL %s in driver adaptor", sql) - } - for r, h := range d.a.rulesWithSQLparser { - msg, err := h(ctx, r, ast) - if err != nil { - return nil, errors.Wrapf(err, "audit SQL %s with SQL parser in driver adaptor", sql) + result.Add(rule.Level, msg) + } else { + handler, ok := d.a.ruleToASTHandler[rule.Name] + if ok { + msg, err := handler(ctx, rule, ast) + if err != nil { + return nil, errors.Wrapf(err, "audit SQL %s in driver adaptor", sql) + } + result.Add(rule.Level, msg) } - - result.Add(r.Level, msg) } }