Skip to content

Commit

Permalink
feat: implement driver.Valuer and sql.Scanner (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
josestg authored Apr 18, 2023
1 parent d040942 commit 6ff492d
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
28 changes: 28 additions & 0 deletions objectid.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package objectid

import (
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -128,6 +129,33 @@ func (id *ID) UnmarshalJSON(data []byte) error {
return nil
}

// Value implements the driver.Valuer.
func (id ID) Value() (driver.Value, error) {
return driver.Value(id.String()), nil
}

// Scan implements the sql.Scanner
func (id *ID) Scan(src any) error {
if src == nil {
*id = Nil
return nil
}

var s string
switch t := src.(type) {
default:
return fmt.Errorf("objectid: scan: unsuported source type: %T", t)
case []byte:
s = string(t)
case string:
s = t
}

oid, err := Decode(s)
*id = oid
return err
}

// Decode decodes the string representation and returns the corresponding ID.
func Decode(s string) (ID, error) {
if len(s) != 24 {
Expand Down
69 changes: 69 additions & 0 deletions objectid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,72 @@ func TestID_UnmarshalJSON(t *testing.T) {
t.Errorf("Unexpected ID value: got %s, expected %s", obj.ID, expectedID)
}
}

func TestID_Value(t *testing.T) {
var id objectid.ID
value, err := id.Value()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

vid := value.(string)
if vid != id.String() {
t.Errorf("Unexpected ID value: got %s, expected %s", vid, id)
}
}

func TestID_Scan(t *testing.T) {

t.Run("invalid src type", func(t *testing.T) {
var id objectid.ID
err := id.Scan(1)
if err == nil {
t.Errorf("Unexpected error: %v", err)
}

if id != objectid.Nil {
t.Errorf("Unexpected ID value: got %s, expected %s", id, objectid.Nil)
}
})

t.Run("invalid format", func(t *testing.T) {
var id objectid.ID
err := id.Scan("xxx-xxx-xx")
if err == nil {
t.Errorf("Unexpected error: %v", err)
}

if id != objectid.Nil {
t.Errorf("Unexpected ID value: got %s, expected %s", id, objectid.Nil)
}
})

t.Run("from bytes", func(t *testing.T) {
var id objectid.ID

sid := objectid.New()
src := []byte(sid.String())
err := id.Scan(src)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if id != sid {
t.Errorf("Unexpected ID value: got %s, expected %s", id, sid)
}
})

t.Run("from string", func(t *testing.T) {
var id objectid.ID

sid := objectid.New()
err := id.Scan(sid.String())
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if id != sid {
t.Errorf("Unexpected ID value: got %s, expected %s", id, sid)
}
})
}

0 comments on commit 6ff492d

Please sign in to comment.