Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support slice of custom types in expandSliceArgs #427

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 28 additions & 66 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,8 @@ type numberer interface {
}

func expandSliceArgs(query *string, args ...interface{}) {
valuerType := reflect.TypeOf((*driver.Valuer)(nil)).Elem()

for _, arg := range args {
mapper, ok := arg.(map[string]interface{})
if !ok {
Expand All @@ -905,76 +907,36 @@ func expandSliceArgs(query *string, args ...interface{}) {
value = v.ToInt64Slice()
}

switch v := value.(type) {
case []string:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint8:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint16:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint32:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint64:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int8:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int16:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int32:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int64:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []float32:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []float64:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
t := reflect.TypeOf(value)
if t.Kind() != reflect.Slice || t.Implements(valuerType) {
// Do not expand if the value is not a slice or implements driver.Valuer,
continue
}
elm := t.Elem()
// If the element of slice implements driver.Valuer or is a primitive value,
// expand the slice
isValue := elm.Implements(valuerType)
if !isValue {
switch elm.Kind() {
case reflect.String,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Float32, reflect.Float64:
isValue = true
}
default:
}
if !isValue {
continue
}

val := reflect.ValueOf(value)
l := val.Len()
for id := 0; id < l; id++ {
v := val.Index(id)
mapper[fmt.Sprintf("%s%d", key, id)] = v.Interface()
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}

if len(replacements) == 0 {
continue
}
Expand Down
84 changes: 70 additions & 14 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package gorp_test

import (
"database/sql/driver"
"strings"
"testing"
)

Expand All @@ -22,6 +24,21 @@ func (c customType2) ToInt64Slice() []int64 {
return []int64(c)
}

type valuerSlice []string

func (vs valuerSlice) Value() (driver.Value, error) {
return strings.Join(vs, ","), nil
}

func (vs *valuerSlice) Scan(val interface{}) error {
*vs = strings.Split(string(val.([]byte)), ",")
return nil
}

var _ driver.Valuer = valuerSlice([]string{})

type customID int64

func TestDbMap_Select_expandSliceArgs(t *testing.T) {
tests := []struct {
description string
Expand Down Expand Up @@ -83,23 +100,51 @@ AND field12 IN (:FieldIntList)
},
wantLen: 3,
},
{
description: "handle customID types",
query: `
SELECT 1 FROM crazy_table
WHERE field16 IN (:FieldCustomIDList)
`,
args: []interface{}{
map[string]interface{}{
"FieldCustomIDList": []customID{3, 4, 5},
},
},
wantLen: 2,
},
{
description: "handle types which are sql.Valuer",
query: `
SELECT 1 FROM crazy_table
WHERE field15 = :FieldCustomValuer
`,
args: []interface{}{
map[string]interface{}{
"FieldCustomValuer": valuerSlice([]string{"aaa", "bbb"}),
},
},
wantLen: 1,
},
}

type dataFormat struct {
Field1 int `db:"field1"`
Field2 string `db:"field2"`
Field3 uint `db:"field3"`
Field4 uint8 `db:"field4"`
Field5 uint16 `db:"field5"`
Field6 uint32 `db:"field6"`
Field7 uint64 `db:"field7"`
Field8 int `db:"field8"`
Field9 int8 `db:"field9"`
Field10 int16 `db:"field10"`
Field11 int32 `db:"field11"`
Field12 int64 `db:"field12"`
Field13 float32 `db:"field13"`
Field14 float64 `db:"field14"`
Field1 int `db:"field1"`
Field2 string `db:"field2"`
Field3 uint `db:"field3"`
Field4 uint8 `db:"field4"`
Field5 uint16 `db:"field5"`
Field6 uint32 `db:"field6"`
Field7 uint64 `db:"field7"`
Field8 int `db:"field8"`
Field9 int8 `db:"field9"`
Field10 int16 `db:"field10"`
Field11 int32 `db:"field11"`
Field12 int64 `db:"field12"`
Field13 float32 `db:"field13"`
Field14 float64 `db:"field14"`
Field15 valuerSlice `db:"field15"`
Field16 customID `db:"field16"`
}

dbmap := newDbMap()
Expand Down Expand Up @@ -161,6 +206,17 @@ AND field12 IN (:FieldIntList)
Field13: 3,
Field14: 3,
},
&dataFormat{
Field1: 126,
Field2: "h",
Field15: []string{"aaa", "bbb"},
Field16: customID(4),
},
&dataFormat{
Field1: 127,
Field2: "o",
Field16: customID(5),
},
)

if err != nil {
Expand Down