Skip to content

Commit

Permalink
update ToSql
Browse files Browse the repository at this point in the history
  • Loading branch information
hwbrzzl committed Sep 24, 2024
1 parent 3467948 commit 1ddb78a
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 116 deletions.
4 changes: 2 additions & 2 deletions contracts/database/orm/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ type Event interface {
GetAttribute(key string) any
// GetOriginal returns the original attribute value for the given key.
GetOriginal(key string, def ...any) any
// IsDirty returns true if the given column is dirty.
IsDirty(columns ...string) bool
// IsClean returns true if the given column is clean.
IsClean(columns ...string) bool
// IsDirty returns true if the given column is dirty.
IsDirty(columns ...string) bool
// Query returns the query instance.
Query() Query
// SetAttribute sets the attribute value for the given key.
Expand Down
3 changes: 2 additions & 1 deletion contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ type Result struct {
type ToSql interface {
Count() string
Create(value any) string
Delete(value any, conds ...any) string
Delete(value ...any) string
Find(dest any, conds ...any) string
First(dest any) string
ForceDelete(value ...any) string
Get(dest any) string
Pluck(column string, dest any) string
Save(value any) string
Expand Down
192 changes: 96 additions & 96 deletions database/gorm/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,53 +28,12 @@ func NewEvent(query *QueryImpl, model, dest any) *Event {
}
}

func (e *Event) ColumnNames() map[string]string {
if e.columnNames != nil {
return e.columnNames
}

if e.model != nil {
return fetchColumnNames(e.model)
} else {
return fetchColumnNames(e.dest)
}
}

func (e *Event) Context() context.Context {
return e.query.ctx
}

func (e *Event) DestOfMap() map[string]any {
if e.dest == nil {
return nil
}
if e.destOfMap != nil {
return e.destOfMap
}

destOfMap := make(map[string]any)
if destMap, ok := e.dest.(map[string]any); ok {
for key, value := range destMap {
destOfMap[key] = value
destOfMap[str.Of(key).Snake().String()] = value
}
} else {
destType := reflect.TypeOf(e.dest)
if destType.Kind() == reflect.Pointer {
destType = destType.Elem()
}
if destType.Kind() == reflect.Struct {
destOfMap = structToMap(e.dest)
}
}

e.destOfMap = destOfMap

return e.destOfMap
}

func (e *Event) GetAttribute(key string) any {
destOfMap := e.DestOfMap()
destOfMap := e.getDestOfMap()
value, exist := destOfMap[e.toDBColumnName(key)]
if exist && e.validColumn(key) && e.validValue(key, value) {
return value
Expand All @@ -84,7 +43,7 @@ func (e *Event) GetAttribute(key string) any {
}

func (e *Event) GetOriginal(key string, def ...any) any {
modelOfMap := e.ModelOfMap()
modelOfMap := e.getModelOfMap()
value, exist := modelOfMap[e.toDBColumnName(key)]
if exist {
return value
Expand All @@ -97,8 +56,12 @@ func (e *Event) GetOriginal(key string, def ...any) any {
return nil
}

func (e *Event) IsClean(fields ...string) bool {
return !e.IsDirty(fields...)
}

func (e *Event) IsDirty(columns ...string) bool {
destOfMap := e.DestOfMap()
destOfMap := e.getDestOfMap()

if len(columns) == 0 {
for destColumn, destValue := range destOfMap {
Expand Down Expand Up @@ -128,24 +91,6 @@ func (e *Event) IsDirty(columns ...string) bool {
return false
}

func (e *Event) IsClean(fields ...string) bool {
return !e.IsDirty(fields...)
}

func (e *Event) ModelOfMap() map[string]any {
if e.modelOfMap != nil {
return e.modelOfMap
}

if e.model == nil {
return map[string]any{}
}

e.modelOfMap = structToMap(e.model)

return e.modelOfMap
}

func (e *Event) Query() orm.Query {
return NewQueryImpl(e.query.ctx, e.query.config, e.query.connection, e.query.instance.Session(&gorm.Session{NewDB: true}), nil)
}
Expand All @@ -155,7 +100,7 @@ func (e *Event) SetAttribute(key string, value any) {
return
}

destOfMap := e.DestOfMap()
destOfMap := e.getDestOfMap()
destOfMap[e.toDBColumnName(key)] = value

Check failure on line 104 in database/gorm/event.go

View workflow job for this annotation

GitHub Actions / lint / nilaway

error: Potential nil panic detected. Observed nil flow from source to dereference point:
e.destOfMap = destOfMap

Expand Down Expand Up @@ -194,7 +139,7 @@ func (e *Event) SetAttribute(key string, value any) {
}

func (e *Event) dirty(destColumn string, destValue any) bool {
modelOfMap := e.ModelOfMap()
modelOfMap := e.getModelOfMap()
dbDestColumn := e.toDBColumnName(destColumn)

if modelValue, exist := modelOfMap[dbDestColumn]; exist {
Expand All @@ -215,8 +160,63 @@ func (e *Event) equalColumnName(origin, source string) bool {
return originDbColumnName == sourceDbColumnName
}

func (e *Event) getColumnNames() map[string]string {
if e.columnNames == nil {
if e.model != nil {
e.columnNames = fetchColumnNames(e.model)
} else {
e.columnNames = fetchColumnNames(e.dest)
}
}

return e.columnNames
}

func (e *Event) getDestOfMap() map[string]any {
if e.dest == nil {
return nil
}
if e.destOfMap != nil {
return e.destOfMap
}

destOfMap := make(map[string]any)
if destMap, ok := e.dest.(map[string]any); ok {
for key, value := range destMap {
destOfMap[key] = value
destOfMap[str.Of(key).Snake().String()] = value
}
} else {
destType := reflect.TypeOf(e.dest)
if destType.Kind() == reflect.Pointer {
destType = destType.Elem()
}
if destType.Kind() == reflect.Struct {
destOfMap = structToMap(e.dest)
}
}

e.destOfMap = destOfMap

return e.destOfMap
}

func (e *Event) getModelOfMap() map[string]any {
if e.modelOfMap != nil {
return e.modelOfMap
}

if e.model == nil {
return map[string]any{}
}

e.modelOfMap = structToMap(e.model)

return e.modelOfMap
}

func (e *Event) toDBColumnName(name string) string {
dbColumnName, exist := e.ColumnNames()[name]
dbColumnName, exist := e.getColumnNames()[name]
if exist {
return dbColumnName
}
Expand Down Expand Up @@ -284,6 +284,37 @@ func (e *Event) validValue(name string, value any) bool {
return !valueValue.IsZero()
}

func fetchColumnNames(model any) map[string]string {
res := make(map[string]string)
modelType := reflect.TypeOf(model)
modelValue := reflect.ValueOf(model)
if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
modelValue = modelValue.Elem()
}

for i := 0; i < modelType.NumField(); i++ {
if !modelType.Field(i).IsExported() {
continue
}
fieldType := modelType.Field(i)
fieldValue := modelValue.Field(i)
if fieldValue.Kind() == reflect.Struct && fieldType.Anonymous {
subStructMap := fetchColumnNames(fieldValue.Interface())
for key, value := range subStructMap {
res[key] = value
}
continue
}

dbColumn := structNameToDbColumnName(modelType.Field(i).Name, modelType.Field(i).Tag.Get("gorm"))
res[modelType.Field(i).Name] = dbColumn
res[dbColumn] = dbColumn
}

return res
}

func structToMap(data any) map[string]any {
res := make(map[string]any)
modelType := reflect.TypeOf(data)
Expand Down Expand Up @@ -341,34 +372,3 @@ func structNameToDbColumnName(structName, tag string) string {

return str.Of(structName).Snake().String()
}

func fetchColumnNames(model any) map[string]string {
res := make(map[string]string)
modelType := reflect.TypeOf(model)
modelValue := reflect.ValueOf(model)
if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
modelValue = modelValue.Elem()
}

for i := 0; i < modelType.NumField(); i++ {
if !modelType.Field(i).IsExported() {
continue
}
fieldType := modelType.Field(i)
fieldValue := modelValue.Field(i)
if fieldValue.Kind() == reflect.Struct && fieldType.Anonymous {
subStructMap := fetchColumnNames(fieldValue.Interface())
for key, value := range subStructMap {
res[key] = value
}
continue
}

dbColumn := structNameToDbColumnName(modelType.Field(i).Name, modelType.Field(i).Tag.Get("gorm"))
res[modelType.Field(i).Name] = dbColumn
res[dbColumn] = dbColumn
}

return res
}
2 changes: 1 addition & 1 deletion database/gorm/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func (s *EventTestSuite) TestColumnNames() {
"admin_at": "admin_at",
"ManageAt": "manage_at",
"manage_at": "manage_at",
}, event.ColumnNames())
}, event.getColumnNames())
}
}

Expand Down
20 changes: 18 additions & 2 deletions database/gorm/to_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@ func (r *ToSql) Create(value any) string {
return r.sql(query.instance.Session(&gorm.Session{DryRun: true}).Create(value))
}

func (r *ToSql) Delete(value any, conds ...any) string {
func (r *ToSql) Delete(value ...any) string {
query := r.query.buildConditions()

return r.sql(query.instance.Session(&gorm.Session{DryRun: true}).Delete(value, conds...))
var dest any
if len(value) > 0 {
dest = value[0]
}

return r.sql(query.instance.Session(&gorm.Session{DryRun: true}).Delete(dest))
}

func (r *ToSql) Find(dest any, conds ...any) string {
Expand All @@ -47,6 +52,17 @@ func (r *ToSql) First(dest any) string {
return r.sql(query.instance.Session(&gorm.Session{DryRun: true}).First(dest))
}

func (r *ToSql) ForceDelete(value ...any) string {
query := r.query.buildConditions()

var dest any
if len(value) > 0 {
dest = value[0]
}

return r.sql(query.instance.Session(&gorm.Session{DryRun: true}).Unscoped().Delete(dest))
}

func (r *ToSql) Get(dest any) string {
query := r.query.buildConditions()

Expand Down
28 changes: 14 additions & 14 deletions database/gorm/to_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,30 +60,16 @@ func (s *ToSqlTestSuite) TestDelete() {
toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false)
s.Equal("UPDATE `users` SET `deleted_at`=? WHERE `id` = ? AND `users`.`deleted_at` IS NULL", toSql.Delete(&User{}))

toSql = NewToSql(s.query.(*QueryImpl), false)
s.Equal("UPDATE `users` SET `deleted_at`=? WHERE `users`.`id` = ? AND `users`.`deleted_at` IS NULL", toSql.Delete(&User{}, 1))

toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), false)
s.Equal("DELETE FROM `roles` WHERE `id` = ?", toSql.Delete(&Role{}))

toSql = NewToSql(s.query.(*QueryImpl), false)
s.Equal("DELETE FROM `roles` WHERE `roles`.`id` = ?", toSql.Delete(&Role{}, 1))

toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true)
sql := toSql.Delete(&User{})
s.Contains(sql, "UPDATE `users` SET `deleted_at`=")
s.Contains(sql, "WHERE `id` = 1 AND `users`.`deleted_at` IS NULL")

toSql = NewToSql(s.query.(*QueryImpl), true)
sql = toSql.Delete(&User{}, 1)
s.Contains(sql, "UPDATE `users` SET `deleted_at`=")
s.Contains(sql, "WHERE `users`.`id` = 1 AND `users`.`deleted_at` IS NULL")

toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true)
s.Equal("DELETE FROM `roles` WHERE `id` = 1", toSql.Delete(&Role{}))

toSql = NewToSql(s.query.(*QueryImpl), true)
s.Equal("DELETE FROM `roles` WHERE `roles`.`id` = 1", toSql.Delete(&Role{}, 1))
}

func (s *ToSqlTestSuite) TestFind() {
Expand All @@ -108,6 +94,20 @@ func (s *ToSqlTestSuite) TestFirst() {
s.Equal("SELECT * FROM `users` WHERE `id` = 1 AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", toSql.First(&User{}))
}

func (s *ToSqlTestSuite) TestForceDelete() {
toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false)
s.Equal("DELETE FROM `users` WHERE `id` = ?", toSql.ForceDelete(&User{}))

toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), false)
s.Equal("DELETE FROM `roles` WHERE `id` = ?", toSql.ForceDelete(&Role{}))

toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true)
s.Equal("DELETE FROM `users` WHERE `id` = 1", toSql.ForceDelete(&User{}))

toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true)
s.Equal("DELETE FROM `roles` WHERE `id` = 1", toSql.ForceDelete(&Role{}))
}

func (s *ToSqlTestSuite) TestGet() {
toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false)
s.Equal("SELECT * FROM `users` WHERE `id` = ? AND `users`.`deleted_at` IS NULL", toSql.Get([]User{}))
Expand Down

0 comments on commit 1ddb78a

Please sign in to comment.