From d040942d0387d5bcedc6303d0fa11de2fd93a5b8 Mon Sep 17 00:00:00 2001 From: Jose Sitanggang Date: Sat, 11 Mar 2023 19:57:38 +0700 Subject: [PATCH] feat: implement an JSON and Text marshal and unmarshal --- objectid.go | 36 ++++++++++++++++++ objectid_test.go | 97 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) diff --git a/objectid.go b/objectid.go index cfe7972..3f55b67 100644 --- a/objectid.go +++ b/objectid.go @@ -6,9 +6,11 @@ import ( "crypto/rand" "encoding/binary" "encoding/hex" + "encoding/json" "errors" "fmt" "io" + "strings" "sync" "sync/atomic" "time" @@ -92,6 +94,40 @@ func (id ID) Count() uint32 { // IsZero returns true if the ObjectID is the Nil value. func (id ID) IsZero() bool { return id == Nil } +// MarshalText implements the encoding.TextMarshaler interface. +// This is useful when using the ID as a map key during JSON marshalling. +func (id ID) MarshalText() ([]byte, error) { + return []byte(id.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// This is useful when using the ID as a map key during JSON unmarshalling. +func (id *ID) UnmarshalText(b []byte) error { + decoded, err := Decode(string(b)) + if err != nil { + return err + } + *id = decoded + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (id ID) MarshalJSON() ([]byte, error) { + return json.Marshal(id.String()) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (id *ID) UnmarshalJSON(data []byte) error { + // remove the surrounding quotes from the JSON string + str := strings.Trim(string(data), "\"") + decoded, err := Decode(str) + if err != nil { + return err + } + *id = decoded + return nil +} + // Decode decodes the string representation and returns the corresponding ID. func Decode(s string) (ID, error) { if len(s) != 24 { diff --git a/objectid_test.go b/objectid_test.go index cc2a505..2f7e21b 100644 --- a/objectid_test.go +++ b/objectid_test.go @@ -1,6 +1,8 @@ package objectid_test import ( + "bytes" + "encoding/json" "github.com/pkg-id/objectid" "strings" "testing" @@ -75,3 +77,98 @@ func TestDecode(t *testing.T) { t.Fatalf("expec id must be a zero value") } } + +func TestID_MarshalText(t *testing.T) { + // Test valid input + id := objectid.New() + expectedOutput := []byte(id.String()) + output, err := id.MarshalText() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(expectedOutput, output) { + t.Errorf("Expected %v but got %v", expectedOutput, output) + } + + // Test invalid input + id = objectid.ID{} + expectedOutput = []byte(id.String()) + output, err = id.MarshalText() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(expectedOutput, output) { + t.Errorf("Expected %v but got %v", expectedOutput, output) + } +} + +func TestID_UnmarshalText(t *testing.T) { + // Test valid input + id := objectid.New() + b := []byte(id.String()) + expectedOutput := &id + err := expectedOutput.UnmarshalText(b) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if *expectedOutput != id { + t.Errorf("Expected %v but got %v", id, *expectedOutput) + } + + // Test invalid input + b = []byte("invalid") + expectedOutput = &id + err = expectedOutput.UnmarshalText(b) + if err == nil { + t.Errorf("Expected error but got nil") + } +} + +func TestID_MarshalJSON(t *testing.T) { + // Test valid input + id := objectid.New() + expectedOutput, err := json.Marshal(id) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + output, err := id.MarshalJSON() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(expectedOutput, output) { + t.Errorf("Expected %v but got %v", expectedOutput, output) + } + + // Test invalid input + id = objectid.ID{} + expectedOutput, err = json.Marshal(id) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + output, err = id.MarshalJSON() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(expectedOutput, output) { + t.Errorf("Expected %v but got %v", expectedOutput, output) + } +} + +func TestID_UnmarshalJSON(t *testing.T) { + type Data struct { + ID objectid.ID `json:"id"` + } + + rawJSON := `{"id":"640c5fe5d243553cda8dde1b"}` + + var obj Data + err := json.Unmarshal([]byte(rawJSON), &obj) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expectedID := "640c5fe5d243553cda8dde1b" + if obj.ID.String() != expectedID { + t.Errorf("Unexpected ID value: got %s, expected %s", obj.ID, expectedID) + } +}