Skip to content

Commit

Permalink
feat: implement an JSON and Text marshal and unmarshal
Browse files Browse the repository at this point in the history
  • Loading branch information
josestg committed Mar 11, 2023
1 parent c7e2a20 commit d040942
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 0 deletions.
36 changes: 36 additions & 0 deletions objectid.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -92,6 +94,40 @@ func (id ID) Count() uint32 {
// IsZero returns true if the ObjectID is the Nil value.
func (id ID) IsZero() bool { return id == Nil }

// MarshalText implements the encoding.TextMarshaler interface.
// This is useful when using the ID as a map key during JSON marshalling.
func (id ID) MarshalText() ([]byte, error) {
return []byte(id.String()), nil
}

// UnmarshalText implements the encoding.TextUnmarshaler interface.
// This is useful when using the ID as a map key during JSON unmarshalling.
func (id *ID) UnmarshalText(b []byte) error {
decoded, err := Decode(string(b))
if err != nil {
return err
}
*id = decoded
return nil
}

// MarshalJSON implements the json.Marshaler interface.
func (id ID) MarshalJSON() ([]byte, error) {
return json.Marshal(id.String())
}

// UnmarshalJSON implements json.Unmarshaler.
func (id *ID) UnmarshalJSON(data []byte) error {
// remove the surrounding quotes from the JSON string
str := strings.Trim(string(data), "\"")
decoded, err := Decode(str)
if err != nil {
return err
}
*id = decoded
return nil
}

// Decode decodes the string representation and returns the corresponding ID.
func Decode(s string) (ID, error) {
if len(s) != 24 {
Expand Down
97 changes: 97 additions & 0 deletions objectid_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package objectid_test

import (
"bytes"
"encoding/json"
"github.com/pkg-id/objectid"
"strings"
"testing"
Expand Down Expand Up @@ -75,3 +77,98 @@ func TestDecode(t *testing.T) {
t.Fatalf("expec id must be a zero value")
}
}

func TestID_MarshalText(t *testing.T) {
// Test valid input
id := objectid.New()
expectedOutput := []byte(id.String())
output, err := id.MarshalText()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !bytes.Equal(expectedOutput, output) {
t.Errorf("Expected %v but got %v", expectedOutput, output)
}

// Test invalid input
id = objectid.ID{}
expectedOutput = []byte(id.String())
output, err = id.MarshalText()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !bytes.Equal(expectedOutput, output) {
t.Errorf("Expected %v but got %v", expectedOutput, output)
}
}

func TestID_UnmarshalText(t *testing.T) {
// Test valid input
id := objectid.New()
b := []byte(id.String())
expectedOutput := &id
err := expectedOutput.UnmarshalText(b)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if *expectedOutput != id {
t.Errorf("Expected %v but got %v", id, *expectedOutput)
}

// Test invalid input
b = []byte("invalid")
expectedOutput = &id
err = expectedOutput.UnmarshalText(b)
if err == nil {
t.Errorf("Expected error but got nil")
}
}

func TestID_MarshalJSON(t *testing.T) {
// Test valid input
id := objectid.New()
expectedOutput, err := json.Marshal(id)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
output, err := id.MarshalJSON()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !bytes.Equal(expectedOutput, output) {
t.Errorf("Expected %v but got %v", expectedOutput, output)
}

// Test invalid input
id = objectid.ID{}
expectedOutput, err = json.Marshal(id)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
output, err = id.MarshalJSON()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !bytes.Equal(expectedOutput, output) {
t.Errorf("Expected %v but got %v", expectedOutput, output)
}
}

func TestID_UnmarshalJSON(t *testing.T) {
type Data struct {
ID objectid.ID `json:"id"`
}

rawJSON := `{"id":"640c5fe5d243553cda8dde1b"}`

var obj Data
err := json.Unmarshal([]byte(rawJSON), &obj)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

expectedID := "640c5fe5d243553cda8dde1b"
if obj.ID.String() != expectedID {
t.Errorf("Unexpected ID value: got %s, expected %s", obj.ID, expectedID)
}
}

0 comments on commit d040942

Please sign in to comment.