diff --git a/database/contracts.go b/database/contracts.go index 7ab8cd19..96d480f6 100644 --- a/database/contracts.go +++ b/database/contracts.go @@ -47,3 +47,9 @@ type TableNamer interface { type Scoper interface { Scope() any } + +// Constrainter implements the Constraint method, +// which returns the PK or unique key constraint name of the table. +type Constrainter interface { + Constraint() string // Constraint returns the PK/UNIQUE key constraint name +} diff --git a/database/db.go b/database/db.go index 296da23d..773d1c01 100644 --- a/database/db.go +++ b/database/db.go @@ -232,7 +232,14 @@ func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) { // MySQL treats UPDATE id = id as a no-op. clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0]) case driver.PostgreSQL: - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO NOTHING", table) + var constraint string + if constrainter, ok := into.(Constrainter); ok { + constraint = constrainter.Constraint() + } else { + constraint = "pk_" + table + } + + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", constraint) } return fmt.Sprintf( @@ -295,7 +302,14 @@ func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders in clause = "ON DUPLICATE KEY UPDATE" setFormat = `"%[1]s" = VALUES("%[1]s")` case driver.PostgreSQL: - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO UPDATE SET", table) + var constraint string + if constrainter, ok := subject.(Constrainter); ok { + constraint = constrainter.Constraint() + } else { + constraint = "pk_" + table + } + + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint) setFormat = `"%[1]s" = EXCLUDED."%[1]s"` }