Skip to content

Commit

Permalink
Merge pull request #108 from actiontech/issue-107
Browse files Browse the repository at this point in the history
fix: plugin adaptor do not use rules from SQLE when Audit
  • Loading branch information
sjjian authored Nov 25, 2021
2 parents 3d69edc + c661fc3 commit 5d24144
Showing 1 changed file with 44 additions and 56 deletions.
100 changes: 44 additions & 56 deletions sqle/pkg/driver/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,58 +23,44 @@ type Adaptor struct {

dt Dialector

rules map[*driver.Rule]rawSQLRuleHandler
rulesWithSQLparser map[*driver.Rule]astSQLRuleHandler
rules []*driver.Rule
ruleToRawHandler map[string] /*rule name*/ rawSQLRuleHandler
ruleToASTHandler map[string] /*rule name*/ astSQLRuleHandler

ao *adaptorOptions
}

type adaptorOptions struct {
dsn string
showDatabaseSQL string

dsnMaker func(*driver.DSN) string
sqlParser func(string) (interface{}, error)
}

func newAdaptorOptions(d Dialector, dsn *driver.DSN, opts ...AdaptorOption) *adaptorOptions {
ao := &adaptorOptions{}

_, ao.dsn = d.Dialect(dsn)
ao.showDatabaseSQL = d.ShowDatabaseSQL()

for _, opt := range opts {
opt.apply(ao)
}
if ao.dsnMaker != nil {
ao.dsn = ao.dsnMaker(dsn)
}
return ao
}

type rawSQLRuleHandler func(ctx context.Context, rule *driver.Rule, rawSQL string) (string, error)
type astSQLRuleHandler func(ctx context.Context, rule *driver.Rule, astSQL interface{}) (string, error)

// NewAdaptor create a database plugin Adaptor with dialector.
func NewAdaptor(dt Dialector) *Adaptor {
return &Adaptor{
ao: &adaptorOptions{},

dt: dt,
l: hclog.New(&hclog.LoggerOptions{
JSONFormat: true,
Output: os.Stderr,
Level: hclog.Trace,
}),
rules: make(map[*driver.Rule]rawSQLRuleHandler),
rulesWithSQLparser: make(map[*driver.Rule]astSQLRuleHandler),
ruleToRawHandler: make(map[string]rawSQLRuleHandler),
ruleToASTHandler: make(map[string]astSQLRuleHandler),
}
}

func (a *Adaptor) AddRule(r *driver.Rule, h rawSQLRuleHandler) {
a.rules[r] = h
a.rules = append(a.rules, r)
a.ruleToRawHandler[r.Name] = h
}

func (a *Adaptor) AddRuleWithSQLParser(r *driver.Rule, h astSQLRuleHandler) {
a.rulesWithSQLparser[r] = h
a.rules = append(a.rules, r)
a.ruleToASTHandler[r.Name] = h
}

func (a *Adaptor) Serve(opts ...AdaptorOption) {
Expand All @@ -84,35 +70,34 @@ func (a *Adaptor) Serve(opts ...AdaptorOption) {
}
}()

rules := make([]*driver.Rule, 0, len(a.rules))
for rule := range a.rules {
rules = append(rules, rule)
}
for rule := range a.rulesWithSQLparser {
rules = append(rules, rule)
for _, opt := range opts {
opt.apply(a.ao)
}

if len(rules) == 0 {
if len(a.rules) == 0 {
a.l.Info("no rule in plugin adaptor", "name", a.dt)
}

if len(a.ruleToASTHandler) != 0 && a.ao.sqlParser == nil {
panic("Add rule by AddRuleWithSQLParser(), but no SQL parser provided.")
}

r := &registererImpl{
dt: a.dt,
rules: rules,
rules: a.rules,
}

newDriver := func(cfg *driver.Config) driver.Driver {
a.cfg = cfg
a.ao = newAdaptorOptions(a.dt, cfg.DSN, opts...)

di := &driverImpl{a: a}

if cfg.DSN == nil {
return di
}

driverName, _ := a.dt.Dialect(cfg.DSN)
db, err := sql.Open(driverName, a.ao.dsn)
driverName, dsnDetail := a.dt.Dialect(cfg.DSN)
db, err := sql.Open(driverName, dsnDetail)
if err != nil {
panic(errors.Wrap(err, "open database failed when new driver"))
}
Expand Down Expand Up @@ -191,7 +176,6 @@ func (d *driverImpl) Close(ctx context.Context) {
if err := d.db.Close(); err != nil {
d.a.l.Error("failed to close database in driver adaptor", "err", err)
}
return
}

func (d *driverImpl) Ping(ctx context.Context) error {
Expand Down Expand Up @@ -248,7 +232,7 @@ func (d *driverImpl) Tx(ctx context.Context, sqls ...string) ([]_driver.Result,
}

func (d *driverImpl) Schemas(ctx context.Context) ([]string, error) {
rows, err := d.conn.QueryContext(ctx, d.a.ao.showDatabaseSQL)
rows, err := d.conn.QueryContext(ctx, d.a.dt.ShowDatabaseSQL())
if err != nil {
return nil, errors.Wrap(err, "query database in driver adaptor")
}
Expand Down Expand Up @@ -302,29 +286,33 @@ func classifySQL(sql string) (sqlType string) {
}

func (d *driverImpl) Audit(ctx context.Context, sql string) (*driver.AuditResult, error) {
result := driver.NewInspectResults()
var err error
var ast interface{}
if d.a.ao.sqlParser != nil {
ast, err = d.a.ao.sqlParser(sql)
if err != nil {
return nil, errors.Wrap(err, "parse sql")
}
}

if d.a.ao.sqlParser == nil {
for r, h := range d.a.rules {
msg, err := h(ctx, r, sql)
result := driver.NewInspectResults()
for _, rule := range d.a.cfg.Rules {
handler, ok := d.a.ruleToRawHandler[rule.Name]
if ok {
msg, err := handler(ctx, rule, sql)
if err != nil {
return nil, errors.Wrapf(err, "audit SQL %s in driver adaptor", sql)
}

result.Add(r.Level, msg)
}
} else {
ast, err := d.a.ao.sqlParser(sql)
if err != nil {
return nil, errors.Wrapf(err, "parse SQL %s in driver adaptor", sql)
}
for r, h := range d.a.rulesWithSQLparser {
msg, err := h(ctx, r, ast)
if err != nil {
return nil, errors.Wrapf(err, "audit SQL %s with SQL parser in driver adaptor", sql)
result.Add(rule.Level, msg)
} else {
handler, ok := d.a.ruleToASTHandler[rule.Name]
if ok {
msg, err := handler(ctx, rule, ast)
if err != nil {
return nil, errors.Wrapf(err, "audit SQL %s in driver adaptor", sql)
}
result.Add(rule.Level, msg)
}

result.Add(r.Level, msg)
}
}

Expand Down

0 comments on commit 5d24144

Please sign in to comment.