Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MySQL/MariaDB: Use strict SQL mode instead of just ANSI_QUOTES #624

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/config/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ func (d *Database) Open(logger *logging.Logger) (*icingadb.DB, error) {

config.DBName = d.Database
config.Timeout = time.Minute
config.Params = map[string]string{"sql_mode": "ANSI_QUOTES"}
// Set strict SQL mode, i.e. trigger an error if an incorrect value is inserted into a column.
config.Params = map[string]string{"sql_mode": "TRADITIONAL"}
Comment on lines +64 to +65
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do differently compared to not setting sql_mode here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to explicitly enable strict mode instead of letting something else decide it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MariaDB's default SQL_MODE depends on the MariaDB version and, even with current MariaDB versions, does not equal TRADITIONAL.

MariaDB [icingadb]> SELECT @@SQL_MODE, @@GLOBAL.SQL_MODE\G
*************************** 1. row ***************************
       @@SQL_MODE: STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION
@@GLOBAL.SQL_MODE: STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION
1 row in set (0.000 sec)

MariaDB [icingadb]> SET sql_mode = 'TRADITIONAL';
Query OK, 0 rows affected (0.001 sec)

MariaDB [icingadb]> SELECT @@SQL_MODE, @@GLOBAL.SQL_MODE\G
*************************** 1. row ***************************
       @@SQL_MODE: STRICT_TRANS_TABLES,STRICT_ALL_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,TRADITIONAL,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION
@@GLOBAL.SQL_MODE: STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION
1 row in set (0.000 sec)

Furthermore, MySQL's default SQL_MODE is also something else.

Fortunately, both MySQL's as well as MariaDB's TRADITIONAL mode maps to the same mode list.
Thus, explicitly setting the SQL_MODE to TRADITIONAL sounds like a good idea.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For 1.1.1 we should just set TRADITIONAL,ANSI_QUOTES here. This is a bugfix release. PoC:

MariaDB [idb]> SET SQL_MODE='TRADITIONAL';
Query OK, 0 rows affected (0,003 sec)

MariaDB [idb]> SET SQL_MODE='TRADITIONAL,ANSI_QUOTES';
Query OK, 0 rows affected (0,003 sec)

MariaDB [idb]> select @@SQL_MODE\G
*************************** 1. row ***************************
@@SQL_MODE: ANSI_QUOTES,STRICT_TRANS_TABLES,STRICT_ALL_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,TRADITIONAL,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION
1 row in set (0,003 sec)

MariaDB [idb]> SET SQL_MODE='TRADITIONAL,ANSI_QUOTES,LOLCAT';
ERROR 1231 (42000): Variable 'sql_mode' can't be set to the value of 'LOLCAT'
MariaDB [idb]>


tlsConfig, err := d.TlsOptions.MakeConfig(d.Host)
if err != nil {
Expand Down
61 changes: 26 additions & 35 deletions pkg/icingadb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type DB struct {
logger *logging.Logger
tableSemaphores map[string]*semaphore.Weighted
tableSemaphoresMu sync.Mutex
quoter *Quoter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering: why hide this inside a member instead of having the functions directly available on DB (maybe by embedding the struct/pointer)? Would make the calls a bit nicer and the functions wouldn't look out of place there.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't think of that tbh.

}

// Options define user configurable database options.
Expand Down Expand Up @@ -81,6 +82,7 @@ func NewDb(db *sqlx.DB, logger *logging.Logger, options *Options) *DB {
logger: logger,
Options: options,
tableSemaphores: make(map[string]*semaphore.Weighted),
quoter: NewQuoter(db),
}
}

Expand Down Expand Up @@ -136,8 +138,8 @@ func (db *DB) BuildColumns(subject interface{}) []string {
// BuildDeleteStmt returns a DELETE statement for the given struct.
func (db *DB) BuildDeleteStmt(from interface{}) string {
return fmt.Sprintf(
`DELETE FROM "%s" WHERE id IN (?)`,
utils.TableName(from),
"DELETE FROM %s WHERE id IN (?)",
db.quoter.QuoteIdentifier(utils.TableName(from)),
)
}

Expand All @@ -146,9 +148,9 @@ func (db *DB) BuildInsertStmt(into interface{}) (string, int) {
columns := db.BuildColumns(into)

return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s)`,
utils.TableName(into),
strings.Join(columns, `", "`),
"INSERT INTO %s (%s) VALUES (%s)",
db.quoter.QuoteIdentifier(utils.TableName(into)),
db.quoter.QuoteColumnList(columns),
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
), len(columns)
}
Expand All @@ -163,15 +165,15 @@ func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) {
switch db.DriverName() {
case driver.MySQL:
// MySQL treats UPDATE id = id as a no-op.
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0])
clause = fmt.Sprintf("ON DUPLICATE KEY UPDATE %[1]s = %[1]s", db.quoter.QuoteIdentifier(columns[0]))
case driver.PostgreSQL:
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO NOTHING", table)
}

return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s) %s`,
table,
strings.Join(columns, `", "`),
"INSERT INTO %s (%s) VALUES (%s) %s",
db.quoter.QuoteIdentifier(table),
db.quoter.QuoteColumnList(columns),
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
clause,
), len(columns)
Expand All @@ -181,14 +183,14 @@ func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) {
// and the column list from the specified columns struct.
func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string {
q := fmt.Sprintf(
`SELECT "%s" FROM "%s"`,
strings.Join(db.BuildColumns(columns), `", "`),
utils.TableName(table),
"SELECT %s FROM %s",
db.quoter.QuoteColumnList(db.BuildColumns(columns)),
db.quoter.QuoteIdentifier(utils.TableName(table)),
)

if scoper, ok := table.(contracts.Scoper); ok {
where, _ := db.BuildWhere(scoper.Scope())
q += ` WHERE ` + where
q += " WHERE " + where
}

return q
Expand All @@ -197,16 +199,11 @@ func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string {
// BuildUpdateStmt returns an UPDATE statement for the given struct.
func (db *DB) BuildUpdateStmt(update interface{}) (string, int) {
columns := db.BuildColumns(update)
set := make([]string, 0, len(columns))

for _, col := range columns {
set = append(set, fmt.Sprintf(`"%s" = :%s`, col, col))
}

return fmt.Sprintf(
`UPDATE "%s" SET %s WHERE id = :id`,
utils.TableName(update),
strings.Join(set, ", "),
"UPDATE %s SET %s WHERE id = :id",
db.quoter.QuoteIdentifier(utils.TableName(update)),
strings.Join(db.quoter.BuildAssignmentList(columns), ", "),
), len(columns) + 1 // +1 because of WHERE id = :id
}

Expand All @@ -226,38 +223,32 @@ func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders in
switch db.DriverName() {
case driver.MySQL:
clause = "ON DUPLICATE KEY UPDATE"
setFormat = `"%[1]s" = VALUES("%[1]s")`
setFormat = fmt.Sprintf("%[1]s = VALUES(%[1]s)", db.quoter.QuoteIdentifier("%[1]s"))
case driver.PostgreSQL:
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO UPDATE SET", table)
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
setFormat = fmt.Sprintf("%[1]s = EXCLUDED.%[1]s", db.quoter.QuoteIdentifier("%[1]s"))
}

set := make([]string, 0, len(updateColumns))

for _, col := range updateColumns {
set = append(set, fmt.Sprintf(setFormat, col))
}

return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
table,
strings.Join(insertColumns, `", "`),
fmt.Sprintf(":%s", strings.Join(insertColumns, ",:")),
"INSERT INTO %s (%s) VALUES (%s) %s %s",
db.quoter.QuoteIdentifier(table),
db.quoter.QuoteColumnList(insertColumns),
fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")),
clause,
strings.Join(set, ","),
strings.Join(set, ", "),
), len(insertColumns)
}

// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct
// combined with the AND operator.
func (db *DB) BuildWhere(subject interface{}) (string, int) {
columns := db.BuildColumns(subject)
where := make([]string, 0, len(columns))
for _, col := range columns {
where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col))
}

return strings.Join(where, ` AND `), len(columns)
return strings.Join(db.quoter.BuildAssignmentList(columns), " AND "), len(columns)
}

// OnSuccess is a callback for successful (bulk) DML operations.
Expand Down
54 changes: 54 additions & 0 deletions pkg/icingadb/quoter.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a few tests for this file should be easy by initializing with &Quoter{"`"}. Testing NewQuoter() would probably be annoying without a database connection.

Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package icingadb

import (
"fmt"
"github.com/icinga/icingadb/pkg/driver"
"github.com/jmoiron/sqlx"
"strings"
)

// Quoter provides utility functions for quoting table names and columns,
// where the quote character depends on the database driver used.
type Quoter struct {
quoteCharacter string
}

// NewQuoter creates and returns a new Quoter
// carrying the quote character appropriate for the given database connection.
func NewQuoter(db *sqlx.DB) *Quoter {
var qc string

switch db.DriverName() {
case driver.MySQL:
qc = "`"
case driver.PostgreSQL:
qc = `"`
default:
panic("unknown driver " + db.DriverName())
}

return &Quoter{quoteCharacter: qc}
}

// BuildAssignmentList quotes the specified columns into `column = :column` pairs for safe use in named query parts,
// i.e. `UPDATE ... SET assignment_list` and `SELECT ... WHERE where_condition`.
func (q *Quoter) BuildAssignmentList(columns []string) []string {
assign := make([]string, 0, len(columns))
for _, col := range columns {
assign = append(assign, fmt.Sprintf("%s = :%s", q.QuoteIdentifier(col), col))
}

return assign
}

// QuoteColumnList quotes the given columns into a single comma concatenated string
// so that they can be safely used as a column list for SELECT and INSERT statements.
func (q *Quoter) QuoteColumnList(columns []string) string {
return fmt.Sprintf("%[1]s%s%[1]s", q.quoteCharacter, strings.Join(columns, q.quoteCharacter+", "+q.quoteCharacter))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixing numbered and unnumbered placeholders looks strange. At least I can't tell without double-checking which argument an unnumbered placeholder would refer to after some numbered ones.

So maybe at least go with %[1]s%[2]s%[1]s here.

But it's probably easier to not use printf here at all, q.quoteCharacter + strings.Join(columns, q.quoteCharacter+", "+q.quoteCharacter)) + q.quoteCharacter is shorter, even though it repeats the name quoteCharacter yet another time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can do that.

}

// QuoteIdentifier quotes the given identifier so that it can be safely used as table or column name,
// even if it is a reserved name where the quote character depends on the database driver used.
func (q *Quoter) QuoteIdentifier(identifier string) string {
return q.quoteCharacter + identifier + q.quoteCharacter
}