-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MySQL/MariaDB: Use strict SQL mode instead of just ANSI_QUOTES #624
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,8 @@ func (d *Database) Open(logger *logging.Logger) (*icingadb.DB, error) { | |
|
||
config.DBName = d.Database | ||
config.Timeout = time.Minute | ||
config.Params = map[string]string{"sql_mode": "ANSI_QUOTES"} | ||
// Set strict SQL mode, i.e. trigger an error if an incorrect value is inserted into a column. | ||
config.Params = map[string]string{"sql_mode": "TRADITIONAL"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For 1.1.1 we should just set
|
||
|
||
tlsConfig, err := d.TlsOptions.MakeConfig(d.Host) | ||
if err != nil { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ type DB struct { | |
logger *logging.Logger | ||
tableSemaphores map[string]*semaphore.Weighted | ||
tableSemaphoresMu sync.Mutex | ||
quoter *Quoter | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering: why hide this inside a member instead of having the functions directly available on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didn't think of that tbh. |
||
} | ||
|
||
// Options define user configurable database options. | ||
|
@@ -81,6 +82,7 @@ func NewDb(db *sqlx.DB, logger *logging.Logger, options *Options) *DB { | |
logger: logger, | ||
Options: options, | ||
tableSemaphores: make(map[string]*semaphore.Weighted), | ||
quoter: NewQuoter(db), | ||
} | ||
} | ||
|
||
|
@@ -136,8 +138,8 @@ func (db *DB) BuildColumns(subject interface{}) []string { | |
// BuildDeleteStmt returns a DELETE statement for the given struct. | ||
func (db *DB) BuildDeleteStmt(from interface{}) string { | ||
return fmt.Sprintf( | ||
`DELETE FROM "%s" WHERE id IN (?)`, | ||
utils.TableName(from), | ||
"DELETE FROM %s WHERE id IN (?)", | ||
db.quoter.QuoteIdentifier(utils.TableName(from)), | ||
) | ||
} | ||
|
||
|
@@ -146,9 +148,9 @@ func (db *DB) BuildInsertStmt(into interface{}) (string, int) { | |
columns := db.BuildColumns(into) | ||
|
||
return fmt.Sprintf( | ||
`INSERT INTO "%s" ("%s") VALUES (%s)`, | ||
utils.TableName(into), | ||
strings.Join(columns, `", "`), | ||
"INSERT INTO %s (%s) VALUES (%s)", | ||
db.quoter.QuoteIdentifier(utils.TableName(into)), | ||
db.quoter.QuoteColumnList(columns), | ||
fmt.Sprintf(":%s", strings.Join(columns, ", :")), | ||
), len(columns) | ||
} | ||
|
@@ -163,15 +165,15 @@ func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) { | |
switch db.DriverName() { | ||
case driver.MySQL: | ||
// MySQL treats UPDATE id = id as a no-op. | ||
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0]) | ||
clause = fmt.Sprintf("ON DUPLICATE KEY UPDATE %[1]s = %[1]s", db.quoter.QuoteIdentifier(columns[0])) | ||
case driver.PostgreSQL: | ||
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO NOTHING", table) | ||
} | ||
|
||
return fmt.Sprintf( | ||
`INSERT INTO "%s" ("%s") VALUES (%s) %s`, | ||
table, | ||
strings.Join(columns, `", "`), | ||
"INSERT INTO %s (%s) VALUES (%s) %s", | ||
db.quoter.QuoteIdentifier(table), | ||
db.quoter.QuoteColumnList(columns), | ||
fmt.Sprintf(":%s", strings.Join(columns, ", :")), | ||
clause, | ||
), len(columns) | ||
|
@@ -181,14 +183,14 @@ func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) { | |
// and the column list from the specified columns struct. | ||
func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string { | ||
q := fmt.Sprintf( | ||
`SELECT "%s" FROM "%s"`, | ||
strings.Join(db.BuildColumns(columns), `", "`), | ||
utils.TableName(table), | ||
"SELECT %s FROM %s", | ||
db.quoter.QuoteColumnList(db.BuildColumns(columns)), | ||
db.quoter.QuoteIdentifier(utils.TableName(table)), | ||
) | ||
|
||
if scoper, ok := table.(contracts.Scoper); ok { | ||
where, _ := db.BuildWhere(scoper.Scope()) | ||
q += ` WHERE ` + where | ||
q += " WHERE " + where | ||
} | ||
|
||
return q | ||
|
@@ -197,16 +199,11 @@ func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string { | |
// BuildUpdateStmt returns an UPDATE statement for the given struct. | ||
func (db *DB) BuildUpdateStmt(update interface{}) (string, int) { | ||
columns := db.BuildColumns(update) | ||
set := make([]string, 0, len(columns)) | ||
|
||
for _, col := range columns { | ||
set = append(set, fmt.Sprintf(`"%s" = :%s`, col, col)) | ||
} | ||
|
||
return fmt.Sprintf( | ||
`UPDATE "%s" SET %s WHERE id = :id`, | ||
utils.TableName(update), | ||
strings.Join(set, ", "), | ||
"UPDATE %s SET %s WHERE id = :id", | ||
db.quoter.QuoteIdentifier(utils.TableName(update)), | ||
strings.Join(db.quoter.BuildAssignmentList(columns), ", "), | ||
), len(columns) + 1 // +1 because of WHERE id = :id | ||
} | ||
|
||
|
@@ -226,38 +223,32 @@ func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders in | |
switch db.DriverName() { | ||
case driver.MySQL: | ||
clause = "ON DUPLICATE KEY UPDATE" | ||
setFormat = `"%[1]s" = VALUES("%[1]s")` | ||
setFormat = fmt.Sprintf("%[1]s = VALUES(%[1]s)", db.quoter.QuoteIdentifier("%[1]s")) | ||
case driver.PostgreSQL: | ||
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO UPDATE SET", table) | ||
setFormat = `"%[1]s" = EXCLUDED."%[1]s"` | ||
setFormat = fmt.Sprintf("%[1]s = EXCLUDED.%[1]s", db.quoter.QuoteIdentifier("%[1]s")) | ||
} | ||
|
||
set := make([]string, 0, len(updateColumns)) | ||
|
||
for _, col := range updateColumns { | ||
set = append(set, fmt.Sprintf(setFormat, col)) | ||
} | ||
|
||
return fmt.Sprintf( | ||
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`, | ||
table, | ||
strings.Join(insertColumns, `", "`), | ||
fmt.Sprintf(":%s", strings.Join(insertColumns, ",:")), | ||
"INSERT INTO %s (%s) VALUES (%s) %s %s", | ||
db.quoter.QuoteIdentifier(table), | ||
db.quoter.QuoteColumnList(insertColumns), | ||
fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")), | ||
clause, | ||
strings.Join(set, ","), | ||
strings.Join(set, ", "), | ||
), len(insertColumns) | ||
} | ||
|
||
// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct | ||
// combined with the AND operator. | ||
func (db *DB) BuildWhere(subject interface{}) (string, int) { | ||
columns := db.BuildColumns(subject) | ||
where := make([]string, 0, len(columns)) | ||
for _, col := range columns { | ||
where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col)) | ||
} | ||
|
||
return strings.Join(where, ` AND `), len(columns) | ||
return strings.Join(db.quoter.BuildAssignmentList(columns), " AND "), len(columns) | ||
} | ||
|
||
// OnSuccess is a callback for successful (bulk) DML operations. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding a few tests for this file should be easy by initializing with |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package icingadb | ||
|
||
import ( | ||
"fmt" | ||
"github.com/icinga/icingadb/pkg/driver" | ||
"github.com/jmoiron/sqlx" | ||
"strings" | ||
) | ||
|
||
// Quoter provides utility functions for quoting table names and columns, | ||
// where the quote character depends on the database driver used. | ||
type Quoter struct { | ||
quoteCharacter string | ||
} | ||
|
||
// NewQuoter creates and returns a new Quoter | ||
// carrying the quote character appropriate for the given database connection. | ||
func NewQuoter(db *sqlx.DB) *Quoter { | ||
var qc string | ||
|
||
switch db.DriverName() { | ||
case driver.MySQL: | ||
qc = "`" | ||
case driver.PostgreSQL: | ||
qc = `"` | ||
default: | ||
panic("unknown driver " + db.DriverName()) | ||
} | ||
|
||
return &Quoter{quoteCharacter: qc} | ||
} | ||
|
||
// BuildAssignmentList quotes the specified columns into `column = :column` pairs for safe use in named query parts, | ||
// i.e. `UPDATE ... SET assignment_list` and `SELECT ... WHERE where_condition`. | ||
func (q *Quoter) BuildAssignmentList(columns []string) []string { | ||
assign := make([]string, 0, len(columns)) | ||
for _, col := range columns { | ||
assign = append(assign, fmt.Sprintf("%s = :%s", q.QuoteIdentifier(col), col)) | ||
} | ||
|
||
return assign | ||
} | ||
|
||
// QuoteColumnList quotes the given columns into a single comma concatenated string | ||
// so that they can be safely used as a column list for SELECT and INSERT statements. | ||
func (q *Quoter) QuoteColumnList(columns []string) string { | ||
return fmt.Sprintf("%[1]s%s%[1]s", q.quoteCharacter, strings.Join(columns, q.quoteCharacter+", "+q.quoteCharacter)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mixing numbered and unnumbered placeholders looks strange. At least I can't tell without double-checking which argument an unnumbered placeholder would refer to after some numbered ones. So maybe at least go with But it's probably easier to not use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I can do that. |
||
} | ||
|
||
// QuoteIdentifier quotes the given identifier so that it can be safely used as table or column name, | ||
// even if it is a reserved name where the quote character depends on the database driver used. | ||
func (q *Quoter) QuoteIdentifier(identifier string) string { | ||
return q.quoteCharacter + identifier + q.quoteCharacter | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this do differently compared to not setting
sql_mode
here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to explicitly enable strict mode instead of letting something else decide it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MariaDB's default
SQL_MODE
depends on the MariaDB version and, even with current MariaDB versions, does not equalTRADITIONAL
.Furthermore, MySQL's default
SQL_MODE
is also something else.Fortunately, both MySQL's as well as MariaDB's
TRADITIONAL
mode maps to the same mode list.Thus, explicitly setting the
SQL_MODE
toTRADITIONAL
sounds like a good idea.