diff --git a/migrator.go b/migrator.go index 3f80515..360597d 100644 --- a/migrator.go +++ b/migrator.go @@ -215,9 +215,20 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { rows.Close() { + _, schemaName, tableName := splitFullQualifiedName(stmt.Table) + + query := "SELECT COLUMN_NAME, DATA_TYPE, COLUMN_DEFAULT, IS_NULLABLE, CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_PRECISION_RADIX, NUMERIC_SCALE, DATETIME_PRECISION FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?" + + queryParameters := []interface{}{m.CurrentDatabase(), tableName} + + if schemaName != "" { + query += " AND TABLE_SCHEMA = ?" + queryParameters = append(queryParameters, schemaName) + } + var ( - columnTypeSQL = "SELECT COLUMN_NAME, DATA_TYPE, COLUMN_DEFAULT, IS_NULLABLE, CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_PRECISION_RADIX, NUMERIC_SCALE, DATETIME_PRECISION FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?" - columns, rowErr = m.DB.Raw(columnTypeSQL, m.CurrentDatabase(), stmt.Table).Rows() + columnTypeSQL = query + columns, rowErr = m.DB.Raw(columnTypeSQL, queryParameters...).Rows() ) if rowErr != nil { @@ -272,7 +283,17 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { } { - columnTypeRows, err := m.DB.Raw("SELECT c.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS t JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c ON c.CONSTRAINT_NAME=t.CONSTRAINT_NAME WHERE t.CONSTRAINT_TYPE IN ('PRIMARY KEY', 'UNIQUE') AND c.TABLE_CATALOG = ? AND c.TABLE_NAME = ?", m.CurrentDatabase(), stmt.Table).Rows() + _, schemaName, tableName := splitFullQualifiedName(stmt.Table) + query := "SELECT c.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS t JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c ON c.CONSTRAINT_NAME=t.CONSTRAINT_NAME WHERE t.CONSTRAINT_TYPE IN ('PRIMARY KEY', 'UNIQUE') AND c.TABLE_CATALOG = ? AND c.TABLE_NAME = ?" + + queryParameters := []interface{}{m.CurrentDatabase(), tableName} + + if schemaName != "" { + query += " AND c.TABLE_SCHEMA = ?" + queryParameters = append(queryParameters, schemaName) + } + + columnTypeRows, err := m.DB.Raw(query, queryParameters...).Rows() if err != nil { return err }