From 7e86cdc8d8abac4f6391525b9f92cde721329de8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrei=20B=C4=83ncioiu?= Date: Thu, 28 Mar 2024 20:31:38 +0200 Subject: [PATCH] Sketch ABI codec and serializer. --- abi/codec/constants.go | 5 + abi/codec/defaultCodec.go | 261 ++++++++ abi/codec/defaultCodecForCompositeTypes.go | 78 +++ abi/codec/defaultCodecForCustomTypes.go | 95 +++ abi/codec/defaultCodecForSimpleValues.go | 309 +++++++++ abi/codec/defaultCodec_test.go | 726 +++++++++++++++++++++ abi/codec/shared.go | 31 + abi/go.mod | 14 + abi/go.sum | 12 + abi/serializer/constants.go | 3 + abi/serializer/interface.go | 8 + abi/serializer/parts.go | 99 +++ abi/serializer/serializer.go | 224 +++++++ abi/serializer/serializer_test.go | 321 +++++++++ abi/values/values.go | 108 +++ 15 files changed, 2294 insertions(+) create mode 100644 abi/codec/constants.go create mode 100644 abi/codec/defaultCodec.go create mode 100644 abi/codec/defaultCodecForCompositeTypes.go create mode 100644 abi/codec/defaultCodecForCustomTypes.go create mode 100644 abi/codec/defaultCodecForSimpleValues.go create mode 100644 abi/codec/defaultCodec_test.go create mode 100644 abi/codec/shared.go create mode 100644 abi/go.mod create mode 100644 abi/go.sum create mode 100644 abi/serializer/constants.go create mode 100644 abi/serializer/interface.go create mode 100644 abi/serializer/parts.go create mode 100644 abi/serializer/serializer.go create mode 100644 abi/serializer/serializer_test.go create mode 100644 abi/values/values.go diff --git a/abi/codec/constants.go b/abi/codec/constants.go new file mode 100644 index 00000000..6a5c6bf0 --- /dev/null +++ b/abi/codec/constants.go @@ -0,0 +1,5 @@ +package codec + +const pubKeyLength = 32 +const trueAsByte = uint8(1) +const falseAsByte = uint8(0) diff --git a/abi/codec/defaultCodec.go b/abi/codec/defaultCodec.go new file mode 100644 index 00000000..0a7e7149 --- /dev/null +++ b/abi/codec/defaultCodec.go @@ -0,0 +1,261 @@ +package codec + +import ( + "bytes" + "fmt" + "io" + "math" + + "github.com/multiversx/mx-sdk-go/abi/values" +) + +// defaultCodec is the default codec for encoding and decoding values. +// +// See: +// - https://docs.multiversx.com/developers/data/simple-values +// - https://docs.multiversx.com/developers/data/composite-values +// - https://docs.multiversx.com/developers/data/custom-types +type defaultCodec struct { +} + +// NewDefaultCodec creates a new default codec. +func NewDefaultCodec() *defaultCodec { + return &defaultCodec{} +} + +func (c *defaultCodec) EncodeNested(value any) ([]byte, error) { + buffer := bytes.NewBuffer(nil) + err := c.doEncodeNested(buffer, value) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +func (c *defaultCodec) doEncodeNested(writer io.Writer, value any) error { + switch value.(type) { + case values.BoolValue: + return c.encodeNestedBool(writer, value.(values.BoolValue)) + case values.U8Value: + return c.encodeNestedNumber(writer, value.(values.U8Value).Value, 1) + case values.U16Value: + return c.encodeNestedNumber(writer, value.(values.U16Value).Value, 2) + case values.U32Value: + return c.encodeNestedNumber(writer, value.(values.U32Value).Value, 4) + case values.U64Value: + return c.encodeNestedNumber(writer, value.(values.U64Value).Value, 8) + case values.I8Value: + return c.encodeNestedNumber(writer, value.(values.I8Value).Value, 1) + case values.I16Value: + return c.encodeNestedNumber(writer, value.(values.I16Value).Value, 2) + case values.I32Value: + return c.encodeNestedNumber(writer, value.(values.I32Value).Value, 4) + case values.I64Value: + return c.encodeNestedNumber(writer, value.(values.I64Value).Value, 8) + case values.BigIntValue: + return c.encodeNestedBigNumber(writer, value.(values.BigIntValue).Value) + case values.AddressValue: + return c.encodeNestedAddress(writer, value.(values.AddressValue)) + case values.StringValue: + return c.encodeNestedString(writer, value.(values.StringValue)) + case values.BytesValue: + return c.encodeNestedBytes(writer, value.(values.BytesValue)) + case values.StructValue: + return c.encodeNestedStruct(writer, value.(values.StructValue)) + case values.EnumValue: + return c.encodeNestedEnum(writer, value.(values.EnumValue)) + case values.OptionValue: + return c.encodeNestedOption(writer, value.(values.OptionValue)) + case values.InputListValue: + return c.encodeNestedList(writer, value.(values.InputListValue)) + default: + return fmt.Errorf("unsupported type for nested encoding: %T", value) + } +} + +func (c *defaultCodec) EncodeTopLevel(value any) ([]byte, error) { + buffer := bytes.NewBuffer(nil) + err := c.doEncodeTopLevel(buffer, value) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +func (c *defaultCodec) doEncodeTopLevel(writer io.Writer, value any) error { + switch value.(type) { + case values.BoolValue: + return c.encodeTopLevelBool(writer, value.(values.BoolValue)) + case values.U8Value: + return c.encodeTopLevelUnsignedNumber(writer, uint64(value.(values.U8Value).Value)) + case values.U16Value: + return c.encodeTopLevelUnsignedNumber(writer, uint64(value.(values.U16Value).Value)) + case values.U32Value: + return c.encodeTopLevelUnsignedNumber(writer, uint64(value.(values.U32Value).Value)) + case values.U64Value: + return c.encodeTopLevelUnsignedNumber(writer, value.(values.U64Value).Value) + case values.I8Value: + return c.encodeTopLevelSignedNumber(writer, int64(value.(values.I8Value).Value)) + case values.I16Value: + return c.encodeTopLevelSignedNumber(writer, int64(value.(values.I16Value).Value)) + case values.I32Value: + return c.encodeTopLevelSignedNumber(writer, int64(value.(values.I32Value).Value)) + case values.I64Value: + return c.encodeTopLevelSignedNumber(writer, value.(values.I64Value).Value) + case values.BigIntValue: + return c.encodeTopLevelBigNumber(writer, value.(values.BigIntValue).Value) + case values.AddressValue: + return c.encodeTopLevelAddress(writer, value.(values.AddressValue)) + case values.StructValue: + return c.encodeTopLevelStruct(writer, value.(values.StructValue)) + case values.EnumValue: + return c.encodeTopLevelEnum(writer, value.(values.EnumValue)) + default: + return fmt.Errorf("unsupported type for top-level encoding: %T", value) + } +} + +func (c *defaultCodec) DecodeNested(data []byte, value any) error { + reader := bytes.NewReader(data) + err := c.doDecodeNested(reader, value) + if err != nil { + return fmt.Errorf("cannot decode (nested) %T, because of: %w", value, err) + } + + return nil +} + +func (c *defaultCodec) doDecodeNested(reader io.Reader, value any) error { + switch value.(type) { + case *values.BoolValue: + return c.decodeNestedBool(reader, value.(*values.BoolValue)) + case *values.U8Value: + return c.decodeNestedNumber(reader, &value.(*values.U8Value).Value, 1) + case *values.U16Value: + return c.decodeNestedNumber(reader, &value.(*values.U16Value).Value, 2) + case *values.U32Value: + return c.decodeNestedNumber(reader, &value.(*values.U32Value).Value, 4) + case *values.U64Value: + return c.decodeNestedNumber(reader, &value.(*values.U64Value).Value, 8) + case *values.I8Value: + return c.decodeNestedNumber(reader, &value.(*values.I8Value).Value, 1) + case *values.I16Value: + return c.decodeNestedNumber(reader, &value.(*values.I16Value).Value, 2) + case *values.I32Value: + return c.decodeNestedNumber(reader, &value.(*values.I32Value).Value, 4) + case *values.I64Value: + return c.decodeNestedNumber(reader, &value.(*values.I64Value).Value, 8) + case *values.BigIntValue: + n, err := c.decodeNestedBigNumber(reader) + if err != nil { + return err + } + + value.(*values.BigIntValue).Value = n + return nil + case *values.AddressValue: + return c.decodeNestedAddress(reader, value.(*values.AddressValue)) + case *values.StringValue: + return c.decodeNestedString(reader, value.(*values.StringValue)) + case *values.BytesValue: + return c.decodeNestedBytes(reader, value.(*values.BytesValue)) + case *values.StructValue: + return c.decodeNestedStruct(reader, value.(*values.StructValue)) + case *values.EnumValue: + return c.decodeNestedEnum(reader, value.(*values.EnumValue)) + case *values.OptionValue: + return c.decodeNestedOption(reader, value.(*values.OptionValue)) + case *values.OutputListValue: + return c.decodeNestedList(reader, value.(*values.OutputListValue)) + default: + return fmt.Errorf("unsupported type for nested decoding: %T", value) + } +} + +func (c *defaultCodec) DecodeTopLevel(data []byte, value any) error { + err := c.doDecodeTopLevel(data, value) + if err != nil { + return fmt.Errorf("cannot decode (top-level) %T, because of: %w", value, err) + } + + return nil +} + +func (c *defaultCodec) doDecodeTopLevel(data []byte, value any) error { + switch value.(type) { + case *values.BoolValue: + return c.decodeTopLevelBool(data, value.(*values.BoolValue)) + case *values.U8Value: + n, err := c.decodeTopLevelUnsignedNumber(data, math.MaxUint8) + if err != nil { + return err + } + + value.(*values.U8Value).Value = uint8(n) + case *values.U16Value: + n, err := c.decodeTopLevelUnsignedNumber(data, math.MaxUint16) + if err != nil { + return err + } + + value.(*values.U16Value).Value = uint16(n) + case *values.U32Value: + n, err := c.decodeTopLevelUnsignedNumber(data, math.MaxUint32) + if err != nil { + return err + } + + value.(*values.U32Value).Value = uint32(n) + case *values.U64Value: + n, err := c.decodeTopLevelUnsignedNumber(data, math.MaxUint64) + if err != nil { + return err + } + + value.(*values.U64Value).Value = uint64(n) + case *values.I8Value: + n, err := c.decodeTopLevelSignedNumber(data, math.MaxInt8) + if err != nil { + return err + } + + value.(*values.I8Value).Value = int8(n) + case *values.I16Value: + n, err := c.decodeTopLevelSignedNumber(data, math.MaxInt16) + if err != nil { + return err + } + + value.(*values.I16Value).Value = int16(n) + case *values.I32Value: + n, err := c.decodeTopLevelSignedNumber(data, math.MaxInt32) + if err != nil { + return err + } + + value.(*values.I32Value).Value = int32(n) + + case *values.I64Value: + n, err := c.decodeTopLevelSignedNumber(data, math.MaxInt64) + if err != nil { + return err + } + + value.(*values.I64Value).Value = int64(n) + case *values.BigIntValue: + n := c.decodeTopLevelBigNumber(data) + value.(*values.BigIntValue).Value = n + case *values.AddressValue: + return c.decodeTopLevelAddress(data, value.(*values.AddressValue)) + case *values.StructValue: + return c.decodeTopLevelStruct(data, value.(*values.StructValue)) + case *values.EnumValue: + return c.decodeTopLevelEnum(data, value.(*values.EnumValue)) + default: + return fmt.Errorf("unsupported type for top-level decoding: %T", value) + } + + return nil +} diff --git a/abi/codec/defaultCodecForCompositeTypes.go b/abi/codec/defaultCodecForCompositeTypes.go new file mode 100644 index 00000000..19d0d3ab --- /dev/null +++ b/abi/codec/defaultCodecForCompositeTypes.go @@ -0,0 +1,78 @@ +package codec + +import ( + "errors" + "io" + + "github.com/multiversx/mx-sdk-go/abi/values" +) + +func (c *defaultCodec) encodeNestedOption(writer io.Writer, value values.OptionValue) error { + if value.Value == nil { + _, err := writer.Write([]byte{0}) + return err + } + + _, err := writer.Write([]byte{1}) + if err != nil { + return err + } + + return c.doEncodeNested(writer, value.Value) +} + +func (c *defaultCodec) decodeNestedOption(reader io.Reader, value *values.OptionValue) error { + bytes, err := readBytesExactly(reader, 1) + if err != nil { + return err + } + + if bytes[0] == 0 { + value.Value = nil + return nil + } + + return c.doDecodeNested(reader, value.Value) +} + +func (c *defaultCodec) encodeNestedList(writer io.Writer, value values.InputListValue) error { + err := c.encodeLength(writer, uint32(len(value.Items))) + if err != nil { + return err + } + + for _, item := range value.Items { + err := c.doEncodeNested(writer, item) + if err != nil { + return err + } + } + + return nil +} + +func (c *defaultCodec) decodeNestedList(reader io.Reader, value *values.OutputListValue) error { + if value.ItemCreator == nil { + return errors.New("cannot deserialize list: item creator is nil") + } + + length, err := c.decodeLength(reader) + if err != nil { + return err + } + + value.Items = make([]any, 0, length) + + for i := uint32(0); i < length; i++ { + newItem := value.ItemCreator() + + err := c.doDecodeNested(reader, newItem) + if err != nil { + return err + } + + value.Items = append(value.Items, newItem) + } + + return nil +} diff --git a/abi/codec/defaultCodecForCustomTypes.go b/abi/codec/defaultCodecForCustomTypes.go new file mode 100644 index 00000000..a531450d --- /dev/null +++ b/abi/codec/defaultCodecForCustomTypes.go @@ -0,0 +1,95 @@ +package codec + +import ( + "bytes" + "fmt" + "io" + + "github.com/multiversx/mx-sdk-go/abi/values" +) + +// https://docs.multiversx.com/developers/data/custom-types +func (c *defaultCodec) encodeNestedStruct(writer io.Writer, value values.StructValue) error { + for _, field := range value.Fields { + err := c.doEncodeNested(writer, field.Value) + if err != nil { + return fmt.Errorf("cannot encode field '%s' of struct, because of: %w", field.Name, err) + } + } + + return nil +} + +func (c *defaultCodec) encodeTopLevelStruct(writer io.Writer, value values.StructValue) error { + return c.encodeNestedStruct(writer, value) +} + +func (c *defaultCodec) decodeNestedStruct(reader io.Reader, value *values.StructValue) error { + for _, field := range value.Fields { + err := c.doDecodeNested(reader, field.Value) + if err != nil { + return fmt.Errorf("cannot decode field '%s' of struct, because of: %w", field.Name, err) + } + } + + return nil +} + +func (c *defaultCodec) decodeTopLevelStruct(data []byte, value *values.StructValue) error { + reader := bytes.NewReader(data) + return c.decodeNestedStruct(reader, value) +} + +func (c *defaultCodec) encodeNestedEnum(writer io.Writer, value values.EnumValue) error { + err := c.doEncodeNested(writer, values.U8Value{Value: value.Discriminant}) + if err != nil { + return err + } + + for _, field := range value.Fields { + err := c.doEncodeNested(writer, field.Value) + if err != nil { + return fmt.Errorf("cannot encode field '%s' of enum, because of: %w", field.Name, err) + } + } + + return nil +} + +func (c *defaultCodec) encodeTopLevelEnum(writer io.Writer, value values.EnumValue) error { + if value.Discriminant == 0 && len(value.Fields) == 0 { + // Write nothing + return nil + } + + return c.encodeNestedEnum(writer, value) +} + +func (c *defaultCodec) decodeNestedEnum(reader io.Reader, value *values.EnumValue) error { + discriminant := &values.U8Value{} + err := c.doDecodeNested(reader, discriminant) + if err != nil { + return err + } + + value.Discriminant = discriminant.Value + + for _, field := range value.Fields { + err := c.doDecodeNested(reader, field.Value) + if err != nil { + return fmt.Errorf("cannot decode field '%s' of enum, because of: %w", field.Name, err) + } + } + + return nil +} + +func (c *defaultCodec) decodeTopLevelEnum(data []byte, value *values.EnumValue) error { + if len(data) == 0 { + value.Discriminant = 0 + return nil + } + + reader := bytes.NewReader(data) + return c.decodeNestedEnum(reader, value) +} diff --git a/abi/codec/defaultCodecForSimpleValues.go b/abi/codec/defaultCodecForSimpleValues.go new file mode 100644 index 00000000..7cdd6859 --- /dev/null +++ b/abi/codec/defaultCodecForSimpleValues.go @@ -0,0 +1,309 @@ +package codec + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math/big" + + twos "github.com/multiversx/mx-components-big-int/twos-complement" + "github.com/multiversx/mx-sdk-go/abi/values" +) + +func (c *defaultCodec) encodeNestedBool(writer io.Writer, value values.BoolValue) error { + if value.Value { + _, err := writer.Write([]byte{trueAsByte}) + return err + } + + _, err := writer.Write([]byte{falseAsByte}) + return err +} + +func (c *defaultCodec) decodeNestedBool(reader io.Reader, value *values.BoolValue) error { + data, err := readBytesExactly(reader, 1) + if err != nil { + return err + } + + value.Value, err = c.byteToBool(data[0]) + if err != nil { + return err + } + + return nil +} + +func (c *defaultCodec) encodeTopLevelBool(writer io.Writer, value values.BoolValue) error { + if !value.Value { + // For "false", write nothing. + return nil + } + + _, err := writer.Write([]byte{trueAsByte}) + return err +} + +func (c *defaultCodec) decodeTopLevelBool(data []byte, value *values.BoolValue) error { + if len(data) == 0 { + value.Value = false + return nil + } + + if len(data) == 1 { + boolValue, err := c.byteToBool(data[0]) + if err != nil { + return err + } + + value.Value = boolValue + return nil + } + + return fmt.Errorf("unexpected boolean value: %v", data) +} + +func (c *defaultCodec) byteToBool(data uint8) (bool, error) { + switch data { + case trueAsByte: + return true, nil + case falseAsByte: + return false, nil + default: + return false, fmt.Errorf("unexpected boolean value: %d", data) + } +} + +func (c *defaultCodec) encodeNestedNumber(writer io.Writer, value any, numBytes int) error { + buffer := new(bytes.Buffer) + + err := binary.Write(buffer, binary.BigEndian, value) + if err != nil { + return err + } + + data := buffer.Bytes() + if len(data) != numBytes { + return fmt.Errorf("unexpected number of bytes: %d != %d", len(data), numBytes) + } + + _, err = writer.Write(data) + if err != nil { + return err + } + + return nil +} + +func (c *defaultCodec) decodeNestedNumber(reader io.Reader, value any, numBytes int) error { + data, err := readBytesExactly(reader, numBytes) + if err != nil { + return err + } + + buffer := bytes.NewReader(data) + err = binary.Read(buffer, binary.BigEndian, value) + if err != nil { + return err + } + + return nil +} + +func (c *defaultCodec) encodeTopLevelUnsignedNumber(writer io.Writer, value uint64) error { + b := big.NewInt(0).SetUint64(value) + data := b.Bytes() + _, err := writer.Write(data) + return err +} + +func (c *defaultCodec) encodeTopLevelSignedNumber(writer io.Writer, value int64) error { + b := big.NewInt(0).SetInt64(value) + data := b.Bytes() + _, err := writer.Write(data) + return err +} + +func (c *defaultCodec) decodeTopLevelUnsignedNumber(data []byte, maxValue uint64) (uint64, error) { + b := big.NewInt(0).SetBytes(data) + if !b.IsUint64() { + return 0, fmt.Errorf("decoded value is too large (does not fit an uint64): %s", b) + } + + n := b.Uint64() + if n > maxValue { + return 0, fmt.Errorf("decoded value is too large: %d > %d", n, maxValue) + } + + return n, nil +} + +func (c *defaultCodec) decodeTopLevelSignedNumber(data []byte, maxValue int64) (int64, error) { + b := big.NewInt(0).SetBytes(data) + if !b.IsInt64() { + return 0, fmt.Errorf("decoded value is too large (does not fit an int64): %s", b) + } + + n := b.Int64() + if n > maxValue { + return 0, fmt.Errorf("decoded value is too large: %d > %d", n, maxValue) + } + + return n, nil +} + +func (c *defaultCodec) encodeNestedBigNumber(writer io.Writer, value *big.Int) error { + data := twos.ToBytes(value) + dataLength := len(data) + + // Write the length of the payload + err := c.encodeLength(writer, uint32(dataLength)) + if err != nil { + return err + } + + // Write the payload + _, err = writer.Write(data) + if err != nil { + return err + } + + return nil +} + +func (c *defaultCodec) encodeTopLevelBigNumber(writer io.Writer, value *big.Int) error { + data := twos.ToBytes(value) + _, err := writer.Write(data) + if err != nil { + return err + } + + return nil +} + +func (c *defaultCodec) decodeNestedBigNumber(reader io.Reader) (*big.Int, error) { + // Read the length of the payload + length, err := c.decodeLength(reader) + if err != nil { + return nil, err + } + + // Read the payload + data, err := readBytesExactly(reader, int(length)) + if err != nil { + return nil, err + } + + return twos.FromBytes(data), nil +} + +func (c *defaultCodec) decodeTopLevelBigNumber(data []byte) *big.Int { + return twos.FromBytes(data) +} + +func (c *defaultCodec) encodeLength(writer io.Writer, length uint32) error { + bytes := make([]byte, 4) + binary.BigEndian.PutUint32(bytes, length) + + _, err := writer.Write(bytes) + if err != nil { + return err + } + + return nil +} + +func (c *defaultCodec) decodeLength(reader io.Reader) (uint32, error) { + bytes, err := readBytesExactly(reader, 4) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint32(bytes), nil +} + +func (c *defaultCodec) encodeNestedString(writer io.Writer, value values.StringValue) error { + data := []byte(value.Value) + err := c.encodeLength(writer, uint32(len(data))) + if err != nil { + return err + } + + _, err = writer.Write(data) + return err +} + +func (c *defaultCodec) decodeNestedString(reader io.Reader, value *values.StringValue) error { + length, err := c.decodeLength(reader) + if err != nil { + return err + } + + data, err := readBytesExactly(reader, int(length)) + if err != nil { + return err + } + + value.Value = string(data) + return nil +} + +func (c *defaultCodec) encodeNestedBytes(writer io.Writer, value values.BytesValue) error { + err := c.encodeLength(writer, uint32(len(value.Value))) + if err != nil { + return err + } + + _, err = writer.Write(value.Value) + return err +} + +func (c *defaultCodec) decodeNestedBytes(reader io.Reader, value *values.BytesValue) error { + length, err := c.decodeLength(reader) + if err != nil { + return err + } + + data, err := readBytesExactly(reader, int(length)) + if err != nil { + return err + } + + value.Value = data + return nil +} + +func (c *defaultCodec) encodeNestedAddress(writer io.Writer, value values.AddressValue) error { + return c.encodeTopLevelAddress(writer, value) +} + +func (c *defaultCodec) encodeTopLevelAddress(writer io.Writer, value values.AddressValue) error { + err := checkPubKeyLength(value.Value) + if err != nil { + return err + } + + _, err = writer.Write(value.Value) + return err +} + +func (c *defaultCodec) decodeNestedAddress(reader io.Reader, value *values.AddressValue) error { + data, err := readBytesExactly(reader, pubKeyLength) + if err != nil { + return err + } + + value.Value = data + return nil +} + +func (c *defaultCodec) decodeTopLevelAddress(data []byte, value *values.AddressValue) error { + err := checkPubKeyLength(data) + if err != nil { + return err + } + + value.Value = data + return nil +} diff --git a/abi/codec/defaultCodec_test.go b/abi/codec/defaultCodec_test.go new file mode 100644 index 00000000..0fbcf435 --- /dev/null +++ b/abi/codec/defaultCodec_test.go @@ -0,0 +1,726 @@ +package codec + +import ( + "encoding/hex" + "math/big" + "testing" + + "github.com/multiversx/mx-sdk-go/abi/values" + "github.com/stretchr/testify/require" +) + +func TestCodec_EncodeNested(t *testing.T) { + codec := NewDefaultCodec() + + doTest := func(t *testing.T, value any, expected string) { + encoded, err := codec.EncodeNested(value) + require.NoError(t, err) + require.Equal(t, expected, hex.EncodeToString(encoded)) + } + + t.Run("bool", func(t *testing.T) { + doTest(t, values.BoolValue{Value: false}, "00") + doTest(t, values.BoolValue{Value: true}, "01") + }) + + t.Run("u8, i8", func(t *testing.T) { + doTest(t, values.U8Value{Value: 0x00}, "00") + doTest(t, values.U8Value{Value: 0x01}, "01") + doTest(t, values.U8Value{Value: 0x42}, "42") + doTest(t, values.U8Value{Value: 0xff}, "ff") + + doTest(t, values.I8Value{Value: 0x00}, "00") + doTest(t, values.I8Value{Value: 0x01}, "01") + doTest(t, values.I8Value{Value: -1}, "ff") + doTest(t, values.I8Value{Value: -128}, "80") + doTest(t, values.I8Value{Value: 127}, "7f") + }) + + t.Run("u16", func(t *testing.T) { + doTest(t, values.U16Value{Value: 0x00}, "0000") + doTest(t, values.U16Value{Value: 0x11}, "0011") + doTest(t, values.U16Value{Value: 0x1234}, "1234") + doTest(t, values.U16Value{Value: 0xffff}, "ffff") + }) + + t.Run("u32", func(t *testing.T) { + doTest(t, values.U32Value{Value: 0x00000000}, "00000000") + doTest(t, values.U32Value{Value: 0x00000011}, "00000011") + doTest(t, values.U32Value{Value: 0x00001122}, "00001122") + doTest(t, values.U32Value{Value: 0x00112233}, "00112233") + doTest(t, values.U32Value{Value: 0x11223344}, "11223344") + doTest(t, values.U32Value{Value: 0xffffffff}, "ffffffff") + }) + + t.Run("u64", func(t *testing.T) { + doTest(t, values.U64Value{Value: 0x0000000000000000}, "0000000000000000") + doTest(t, values.U64Value{Value: 0x0000000000000011}, "0000000000000011") + doTest(t, values.U64Value{Value: 0x0000000000001122}, "0000000000001122") + doTest(t, values.U64Value{Value: 0x0000000000112233}, "0000000000112233") + doTest(t, values.U64Value{Value: 0x0000000011223344}, "0000000011223344") + doTest(t, values.U64Value{Value: 0x0000001122334455}, "0000001122334455") + doTest(t, values.U64Value{Value: 0x0000112233445566}, "0000112233445566") + doTest(t, values.U64Value{Value: 0x0011223344556677}, "0011223344556677") + doTest(t, values.U64Value{Value: 0x1122334455667788}, "1122334455667788") + doTest(t, values.U64Value{Value: 0xffffffffffffffff}, "ffffffffffffffff") + }) + + t.Run("bigInt", func(t *testing.T) { + doTest(t, values.BigIntValue{Value: big.NewInt(0)}, "00000000") + doTest(t, values.BigIntValue{Value: big.NewInt(1)}, "0000000101") + doTest(t, values.BigIntValue{Value: big.NewInt(-1)}, "00000001ff") + }) + + t.Run("address", func(t *testing.T) { + data, _ := hex.DecodeString("0139472eff6886771a982f3083da5d421f24c29181e63888228dc81ca60d69e1") + doTest(t, values.AddressValue{Value: data}, "0139472eff6886771a982f3083da5d421f24c29181e63888228dc81ca60d69e1") + }) + + t.Run("address (bad)", func(t *testing.T) { + data, _ := hex.DecodeString("0139472eff6886771a982f3083da5d42") + _, err := codec.EncodeNested(values.AddressValue{Value: data}) + require.ErrorContains(t, err, "public key (address) has invalid length") + }) + + t.Run("string", func(t *testing.T) { + doTest(t, values.StringValue{Value: ""}, "00000000") + doTest(t, values.StringValue{Value: "abc"}, "00000003616263") + }) + + t.Run("bytes", func(t *testing.T) { + doTest(t, values.BytesValue{Value: []byte{}}, "00000000") + doTest(t, values.BytesValue{Value: []byte{'a', 'b', 'c'}}, "00000003616263") + }) + + t.Run("struct", func(t *testing.T) { + fooStruct := values.StructValue{ + Fields: []values.Field{ + { + Value: values.U8Value{Value: 0x01}, + }, + { + Value: values.U16Value{Value: 0x4142}, + }, + }, + } + + doTest(t, fooStruct, "014142") + }) + + t.Run("enum (discriminant == 0)", func(t *testing.T) { + fooEnum := values.EnumValue{ + Discriminant: 0, + } + + doTest(t, fooEnum, "00") + }) + + t.Run("enum (discriminant != 0)", func(t *testing.T) { + fooEnum := values.EnumValue{ + Discriminant: 42, + } + + doTest(t, fooEnum, "2a") + }) + + t.Run("enum with values.Fields", func(t *testing.T) { + fooEnum := values.EnumValue{ + Discriminant: 42, + Fields: []values.Field{ + { + Value: values.U8Value{Value: 0x01}, + }, + { + Value: values.U16Value{Value: 0x4142}, + }, + }, + } + + doTest(t, fooEnum, "2a014142") + }) + + t.Run("option with value", func(t *testing.T) { + fooOption := values.OptionValue{ + Value: values.U16Value{Value: 0x08}, + } + + doTest(t, fooOption, "010008") + }) + + t.Run("option without value", func(t *testing.T) { + fooOption := values.OptionValue{ + Value: nil, + } + + doTest(t, fooOption, "00") + }) + + t.Run("list", func(t *testing.T) { + fooList := values.InputListValue{ + Items: []any{ + values.U16Value{Value: 1}, + values.U16Value{Value: 2}, + values.U16Value{Value: 3}, + }, + } + + doTest(t, fooList, "00000003000100020003") + }) +} + +func TestCodec_EncodeTopLevel(t *testing.T) { + codec := NewDefaultCodec() + + doTest := func(t *testing.T, value any, expected string) { + encoded, err := codec.EncodeTopLevel(value) + require.NoError(t, err) + require.Equal(t, expected, hex.EncodeToString(encoded)) + } + + t.Run("bool", func(t *testing.T) { + doTest(t, values.BoolValue{Value: false}, "") + doTest(t, values.BoolValue{Value: true}, "01") + }) + + t.Run("u8", func(t *testing.T) { + doTest(t, values.U8Value{Value: 0x00}, "") + doTest(t, values.U8Value{Value: 0x01}, "01") + }) + + t.Run("u16", func(t *testing.T) { + doTest(t, values.U16Value{Value: 0x0042}, "42") + }) + + t.Run("u32", func(t *testing.T) { + doTest(t, values.U32Value{Value: 0x00004242}, "4242") + }) + + t.Run("u64", func(t *testing.T) { + doTest(t, values.U64Value{Value: 0x0042434445464748}, "42434445464748") + }) + + t.Run("bigInt", func(t *testing.T) { + doTest(t, values.BigIntValue{Value: big.NewInt(0)}, "") + doTest(t, values.BigIntValue{Value: big.NewInt(1)}, "01") + doTest(t, values.BigIntValue{Value: big.NewInt(-1)}, "ff") + }) + + t.Run("address", func(t *testing.T) { + data, _ := hex.DecodeString("0139472eff6886771a982f3083da5d421f24c29181e63888228dc81ca60d69e1") + doTest(t, values.AddressValue{Value: data}, "0139472eff6886771a982f3083da5d421f24c29181e63888228dc81ca60d69e1") + }) + + t.Run("address (bad)", func(t *testing.T) { + data, _ := hex.DecodeString("0139472eff6886771a982f3083da5d42") + _, err := codec.EncodeTopLevel(values.AddressValue{Value: data}) + require.ErrorContains(t, err, "public key (address) has invalid length") + }) + + t.Run("struct", func(t *testing.T) { + fooStruct := values.StructValue{ + Fields: []values.Field{ + { + Value: values.U8Value{Value: 0x01}, + }, + { + Value: values.U16Value{Value: 0x4142}, + }, + }, + } + + doTest(t, fooStruct, "014142") + }) + + t.Run("enum (discriminant == 0)", func(t *testing.T) { + fooEnum := values.EnumValue{ + Discriminant: 0, + } + + doTest(t, fooEnum, "") + }) + + t.Run("enum (discriminant != 0)", func(t *testing.T) { + fooEnum := values.EnumValue{ + Discriminant: 42, + } + + doTest(t, fooEnum, "2a") + }) + + t.Run("enum with values.Fields", func(t *testing.T) { + fooEnum := values.EnumValue{ + Discriminant: 42, + Fields: []values.Field{ + { + Value: values.U8Value{Value: 0x01}, + }, + { + Value: values.U16Value{Value: 0x4142}, + }, + }, + } + + doTest(t, fooEnum, "2a014142") + }) +} + +func TestCodec_DecodeNested(t *testing.T) { + codec := NewDefaultCodec() + + t.Run("bool (true)", func(t *testing.T) { + data, _ := hex.DecodeString("01") + destination := &values.BoolValue{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BoolValue{Value: true}, destination) + }) + + t.Run("bool (false)", func(t *testing.T) { + data, _ := hex.DecodeString("00") + destination := &values.BoolValue{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BoolValue{Value: false}, destination) + }) + + t.Run("u8", func(t *testing.T) { + data, _ := hex.DecodeString("01") + destination := &values.U8Value{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U8Value{Value: 0x01}, destination) + }) + + t.Run("u16", func(t *testing.T) { + data, _ := hex.DecodeString("4142") + destination := &values.U16Value{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U16Value{Value: 0x4142}, destination) + }) + + t.Run("u32", func(t *testing.T) { + data, _ := hex.DecodeString("41424344") + destination := &values.U32Value{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U32Value{Value: 0x41424344}, destination) + }) + + t.Run("u64", func(t *testing.T) { + data, _ := hex.DecodeString("4142434445464748") + destination := &values.U64Value{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U64Value{Value: 0x4142434445464748}, destination) + }) + + t.Run("u16, should err because it cannot read 2 bytes", func(t *testing.T) { + data, _ := hex.DecodeString("01") + destination := &values.U16Value{} + + err := codec.DecodeNested(data, destination) + require.ErrorContains(t, err, "cannot read exactly 2 bytes") + }) + + t.Run("u32, should err because it cannot read 4 bytes", func(t *testing.T) { + data, _ := hex.DecodeString("4142") + destination := &values.U32Value{} + + err := codec.DecodeNested(data, destination) + require.ErrorContains(t, err, "cannot read exactly 4 bytes") + }) + + t.Run("u64, should err because it cannot read 8 bytes", func(t *testing.T) { + data, _ := hex.DecodeString("41424344") + destination := &values.U64Value{} + + err := codec.DecodeNested(data, destination) + require.ErrorContains(t, err, "cannot read exactly 8 bytes") + }) + + t.Run("bigInt", func(t *testing.T) { + data, _ := hex.DecodeString("00000000") + destination := &values.BigIntValue{} + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BigIntValue{Value: big.NewInt(0)}, destination) + + data, _ = hex.DecodeString("0000000101") + destination = &values.BigIntValue{} + err = codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BigIntValue{Value: big.NewInt(1)}, destination) + + data, _ = hex.DecodeString("00000001ff") + destination = &values.BigIntValue{} + err = codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BigIntValue{Value: big.NewInt(-1)}, destination) + }) + + t.Run("address", func(t *testing.T) { + data, _ := hex.DecodeString("0139472eff6886771a982f3083da5d421f24c29181e63888228dc81ca60d69e1") + + destination := &values.AddressValue{} + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.AddressValue{Value: data}, destination) + }) + + t.Run("address (bad)", func(t *testing.T) { + data, _ := hex.DecodeString("0139472eff6886771a982f3083da5d42") + + destination := &values.AddressValue{} + err := codec.DecodeNested(data, destination) + require.ErrorContains(t, err, "cannot read exactly 32 bytes") + }) + + t.Run("string", func(t *testing.T) { + data, _ := hex.DecodeString("00000000") + destination := &values.StringValue{} + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.StringValue{}, destination) + + data, _ = hex.DecodeString("00000003616263") + destination = &values.StringValue{} + err = codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.StringValue{Value: "abc"}, destination) + }) + + t.Run("bytes", func(t *testing.T) { + data, _ := hex.DecodeString("00000000") + destination := &values.BytesValue{} + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BytesValue{Value: []byte{}}, destination) + + data, _ = hex.DecodeString("00000003616263") + destination = &values.BytesValue{} + err = codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BytesValue{Value: []byte{'a', 'b', 'c'}}, destination) + }) + + t.Run("struct", func(t *testing.T) { + data, _ := hex.DecodeString("014142") + + destination := &values.StructValue{ + Fields: []values.Field{ + { + Value: &values.U8Value{}, + }, + { + Value: &values.U16Value{}, + }, + }, + } + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.StructValue{ + Fields: []values.Field{ + { + Value: &values.U8Value{Value: 0x01}, + }, + { + Value: &values.U16Value{Value: 0x4142}, + }, + }, + }, destination) + }) + + t.Run("enum (discriminant == 0)", func(t *testing.T) { + data, _ := hex.DecodeString("00") + destination := &values.EnumValue{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.EnumValue{ + Discriminant: 0x00, + }, destination) + }) + + t.Run("enum (discriminant != 0)", func(t *testing.T) { + data, _ := hex.DecodeString("01") + destination := &values.EnumValue{} + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.EnumValue{ + Discriminant: 0x01, + }, destination) + }) + + t.Run("enum with values.Fields", func(t *testing.T) { + data, _ := hex.DecodeString("01014142") + + destination := &values.EnumValue{ + Fields: []values.Field{ + { + Value: &values.U8Value{}, + }, + { + Value: &values.U16Value{}, + }, + }, + } + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.EnumValue{ + Discriminant: 0x01, + Fields: []values.Field{ + { + Value: &values.U8Value{Value: 0x01}, + }, + { + Value: &values.U16Value{Value: 0x4142}, + }, + }, + }, destination) + }) + + t.Run("option with value", func(t *testing.T) { + data, _ := hex.DecodeString("010008") + + destination := &values.OptionValue{ + Value: &values.U16Value{}, + } + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.OptionValue{ + Value: &values.U16Value{Value: 8}, + }, destination) + }) + + t.Run("option without value", func(t *testing.T) { + data, _ := hex.DecodeString("00") + + destination := &values.OptionValue{ + Value: &values.U16Value{}, + } + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, &values.OptionValue{ + Value: nil, + }, destination) + }) + + t.Run("list", func(t *testing.T) { + data, _ := hex.DecodeString("00000003000100020003") + + destination := &values.OutputListValue{ + ItemCreator: func() any { return &values.U16Value{} }, + Items: []any{}, + } + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, + []any{ + &values.U16Value{Value: 1}, + &values.U16Value{Value: 2}, + &values.U16Value{Value: 3}, + }, destination.Items) + }) +} + +func TestCodec_DecodeTopLevel(t *testing.T) { + codec := NewDefaultCodec() + + t.Run("bool (true)", func(t *testing.T) { + data, _ := hex.DecodeString("01") + destination := &values.BoolValue{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BoolValue{Value: true}, destination) + }) + + t.Run("bool (false)", func(t *testing.T) { + data, _ := hex.DecodeString("") + destination := &values.BoolValue{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BoolValue{Value: false}, destination) + }) + + t.Run("u8", func(t *testing.T) { + data, _ := hex.DecodeString("01") + destination := &values.U8Value{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U8Value{Value: 0x01}, destination) + }) + + t.Run("u16", func(t *testing.T) { + data, _ := hex.DecodeString("02") + destination := &values.U16Value{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U16Value{Value: 0x0002}, destination) + }) + + t.Run("u32", func(t *testing.T) { + data, _ := hex.DecodeString("03") + destination := &values.U32Value{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U32Value{Value: 0x00000003}, destination) + }) + + t.Run("u64", func(t *testing.T) { + data, _ := hex.DecodeString("04") + destination := &values.U64Value{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.U64Value{Value: 0x0000000000000004}, destination) + }) + + t.Run("u8, should err because decoded value is too large", func(t *testing.T) { + data, _ := hex.DecodeString("4142") + destination := &values.U8Value{} + + err := codec.DecodeTopLevel(data, destination) + require.ErrorContains(t, err, "decoded value is too large") + }) + + t.Run("u16, should err because decoded value is too large", func(t *testing.T) { + data, _ := hex.DecodeString("41424344") + destination := &values.U16Value{} + + err := codec.DecodeTopLevel(data, destination) + require.ErrorContains(t, err, "decoded value is too large") + }) + + t.Run("u32, should err because decoded value is too large", func(t *testing.T) { + data, _ := hex.DecodeString("4142434445464748") + destination := &values.U32Value{} + + err := codec.DecodeTopLevel(data, destination) + require.ErrorContains(t, err, "decoded value is too large") + }) + + t.Run("u64, should err because decoded value is too large", func(t *testing.T) { + data, _ := hex.DecodeString("41424344454647489876") + destination := &values.U64Value{} + + err := codec.DecodeTopLevel(data, destination) + require.ErrorContains(t, err, "decoded value is too large") + }) + + t.Run("bigInt", func(t *testing.T) { + data, _ := hex.DecodeString("") + destination := &values.BigIntValue{} + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BigIntValue{Value: big.NewInt(0)}, destination) + + data, _ = hex.DecodeString("01") + destination = &values.BigIntValue{} + err = codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BigIntValue{Value: big.NewInt(1)}, destination) + + data, _ = hex.DecodeString("ff") + destination = &values.BigIntValue{} + err = codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.BigIntValue{Value: big.NewInt(-1)}, destination) + }) + + t.Run("struct", func(t *testing.T) { + data, _ := hex.DecodeString("014142") + + destination := &values.StructValue{ + Fields: []values.Field{ + { + Value: &values.U8Value{}, + }, + { + Value: &values.U16Value{}, + }, + }, + } + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.StructValue{ + Fields: []values.Field{ + { + Value: &values.U8Value{Value: 0x01}, + }, + { + Value: &values.U16Value{Value: 0x4142}, + }, + }, + }, destination) + }) + + t.Run("enum (discriminant == 0)", func(t *testing.T) { + data, _ := hex.DecodeString("") + destination := &values.EnumValue{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.EnumValue{ + Discriminant: 0x00, + }, destination) + }) + + t.Run("enum (discriminant != 0)", func(t *testing.T) { + data, _ := hex.DecodeString("01") + destination := &values.EnumValue{} + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.EnumValue{ + Discriminant: 0x01, + }, destination) + }) + + t.Run("enum with values.Fields", func(t *testing.T) { + data, _ := hex.DecodeString("01014142") + + destination := &values.EnumValue{ + Fields: []values.Field{ + { + Value: &values.U8Value{}, + }, + { + Value: &values.U16Value{}, + }, + }, + } + + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, &values.EnumValue{ + Discriminant: 0x01, + Fields: []values.Field{ + { + Value: &values.U8Value{Value: 0x01}, + }, + { + Value: &values.U16Value{Value: 0x4142}, + }, + }, + }, destination) + }) +} diff --git a/abi/codec/shared.go b/abi/codec/shared.go new file mode 100644 index 00000000..de31eeb7 --- /dev/null +++ b/abi/codec/shared.go @@ -0,0 +1,31 @@ +package codec + +import ( + "fmt" + "io" +) + +func readBytesExactly(reader io.Reader, numBytes int) ([]byte, error) { + if numBytes == 0 { + return []byte{}, nil + } + + data := make([]byte, numBytes) + n, err := reader.Read(data) + if err != nil { + return nil, err + } + if n != numBytes { + return nil, fmt.Errorf("cannot read exactly %d bytes", numBytes) + } + + return data, err +} + +func checkPubKeyLength(pubkey []byte) error { + if len(pubkey) != pubKeyLength { + return fmt.Errorf("public key (address) has invalid length: %d", len(pubkey)) + } + + return nil +} diff --git a/abi/go.mod b/abi/go.mod new file mode 100644 index 00000000..66f0b7bc --- /dev/null +++ b/abi/go.mod @@ -0,0 +1,14 @@ +module github.com/multiversx/mx-sdk-go/abi + +go 1.20 + +require ( + github.com/multiversx/mx-components-big-int v1.0.0 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/abi/go.sum b/abi/go.sum new file mode 100644 index 00000000..ddf741c1 --- /dev/null +++ b/abi/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/multiversx/mx-components-big-int v1.0.0 h1:Wkr8lSzK2nDqixOrrBa47VNuqdhV1m/aJhaP1EMaiS8= +github.com/multiversx/mx-components-big-int v1.0.0/go.mod h1:maIEMgHlNE2u78JaDD0oLzri+ShgU4okHfzP3LWGdQM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/abi/serializer/constants.go b/abi/serializer/constants.go new file mode 100644 index 00000000..c307b212 --- /dev/null +++ b/abi/serializer/constants.go @@ -0,0 +1,3 @@ +package serializer + +const partsSeparator = "@" diff --git a/abi/serializer/interface.go b/abi/serializer/interface.go new file mode 100644 index 00000000..0faeab4a --- /dev/null +++ b/abi/serializer/interface.go @@ -0,0 +1,8 @@ +package serializer + +type valuesCodec interface { + EncodeNested(value any) ([]byte, error) + EncodeTopLevel(value any) ([]byte, error) + DecodeNested(data []byte, value any) error + DecodeTopLevel(data []byte, value any) error +} diff --git a/abi/serializer/parts.go b/abi/serializer/parts.go new file mode 100644 index 00000000..87d2d981 --- /dev/null +++ b/abi/serializer/parts.go @@ -0,0 +1,99 @@ +package serializer + +import ( + "errors" + "fmt" +) + +// partsHolder holds data parts (e.g. raw contract call arguments, raw contract return values). +// It allows one to easily construct parts (thus functioning as a builder of parts). +// It also allows one to focus on a specific part to read from (thus functioning as a reader of parts: think of a pick-up head). +// Both functionalities (building and reading) are kept within this single abstraction, for convenience. +type partsHolder struct { + parts [][]byte + focusedPartIndex uint32 +} + +// newPartsHolder creates a new partsHolder, which has the given parts. +// Focus is on the first part, if any, or "beyond the last part" otherwise. +func newPartsHolder(parts [][]byte) *partsHolder { + return &partsHolder{ + parts: parts, + focusedPartIndex: 0, + } +} + +// newEmptyPartsHolder creates a new partsHolder, which has no parts. +// Parts are created by calling appendEmptyPart(). +// Focus is "beyond the last part" (since there is no part). +func newEmptyPartsHolder() *partsHolder { + return &partsHolder{ + parts: [][]byte{}, + focusedPartIndex: 0, + } +} + +func (holder *partsHolder) getParts() [][]byte { + return holder.parts +} + +func (holder *partsHolder) getNumParts() uint32 { + return uint32(len(holder.parts)) +} + +func (holder *partsHolder) getPart(index uint32) ([]byte, error) { + if index >= holder.getNumParts() { + return nil, fmt.Errorf("part index %d is out of range", index) + } + + return holder.parts[index], nil +} + +func (holder *partsHolder) appendToLastPart(data []byte) error { + if !holder.hasAnyPart() { + return errors.New("cannot write, since there is no part to write to") + } + + holder.parts[len(holder.parts)-1] = append(holder.parts[len(holder.parts)-1], data...) + return nil +} + +func (holder *partsHolder) hasAnyPart() bool { + return len(holder.parts) > 0 +} + +func (holder *partsHolder) appendEmptyPart() { + holder.parts = append(holder.parts, []byte{}) +} + +// readWholeFocusedPart reads the whole focused part, if any. Otherwise, it returns an error. +func (holder *partsHolder) readWholeFocusedPart() ([]byte, error) { + if holder.isFocusedBeyondLastPart() { + return nil, fmt.Errorf("cannot wholly read part %d: unexpected end of data", holder.focusedPartIndex) + } + + part, err := holder.getPart(uint32(holder.focusedPartIndex)) + if err != nil { + return nil, err + } + + return part, nil +} + +// focusOnNextPart focuses on the next part, if any. Otherwise, it returns an error. +func (holder *partsHolder) focusOnNextPart() error { + if holder.isFocusedBeyondLastPart() { + return fmt.Errorf( + "cannot focus on next part, since the focus is already beyond the last part; focused part index is %d", + holder.focusedPartIndex, + ) + } + + holder.focusedPartIndex++ + return nil +} + +// isFocusedBeyondLastPart returns true if the focus is already beyond the last part. +func (holder *partsHolder) isFocusedBeyondLastPart() bool { + return holder.focusedPartIndex >= holder.getNumParts() +} diff --git a/abi/serializer/serializer.go b/abi/serializer/serializer.go new file mode 100644 index 00000000..4600faa4 --- /dev/null +++ b/abi/serializer/serializer.go @@ -0,0 +1,224 @@ +package serializer + +import ( + "encoding/hex" + "errors" + "strings" + + "github.com/multiversx/mx-sdk-go/abi/values" +) + +type serializer struct { + codec valuesCodec +} + +func NewSerializer(codec valuesCodec) *serializer { + return &serializer{ + codec: codec, + } +} + +func (s *serializer) Serialize(inputValues []any) (string, error) { + parts, err := s.SerializeToParts(inputValues) + if err != nil { + return "", err + } + + return s.encodeParts(parts), nil +} + +func (s *serializer) SerializeToParts(inputValues []any) ([][]byte, error) { + partsHolder := newEmptyPartsHolder() + + err := s.doSerialize(partsHolder, inputValues) + if err != nil { + return nil, err + } + + return partsHolder.getParts(), nil +} + +func (s *serializer) doSerialize(partsHolder *partsHolder, inputValues []any) error { + var err error + + for i, value := range inputValues { + if value == nil { + return errors.New("cannot serialize nil value") + } + + switch value.(type) { + case values.InputMultiValue: + err = s.serializeInputMultiValue(partsHolder, value.(values.InputMultiValue)) + case values.InputVariadicValues: + if i != len(inputValues)-1 { + return errors.New("variadic values must be last among input values") + } + + err = s.serializeInputVariadicValues(partsHolder, value.(values.InputVariadicValues)) + default: + partsHolder.appendEmptyPart() + err = s.serializeDirectlyEncodableValue(partsHolder, value) + } + + if err != nil { + return err + } + } + + return nil +} + +func (s *serializer) Deserialize(data string, outputValues []any) error { + parts, err := s.decodeIntoParts(data) + if err != nil { + return err + } + + return s.DeserializeParts(parts, outputValues) +} + +func (s *serializer) DeserializeParts(parts [][]byte, outputValues []any) error { + partsHolder := newPartsHolder(parts) + + err := s.doDeserialize(partsHolder, outputValues) + if err != nil { + return err + } + + return nil +} + +func (s *serializer) doDeserialize(partsHolder *partsHolder, outputValues []any) error { + var err error + + for i, value := range outputValues { + if value == nil { + return errors.New("cannot deserialize into nil value") + } + + switch value.(type) { + case *values.OutputMultiValue: + err = s.deserializeOutputMultiValue(partsHolder, value.(*values.OutputMultiValue)) + case *values.OutputVariadicValues: + if i != len(outputValues)-1 { + return errors.New("variadic values must be last among output values") + } + + err = s.deserializeOutputVariadicValues(partsHolder, value.(*values.OutputVariadicValues)) + default: + err = s.deserializeDirectlyEncodableValue(partsHolder, value) + } + + if err != nil { + return err + } + } + + return nil +} + +func (s *serializer) serializeInputMultiValue(partsHolder *partsHolder, value values.InputMultiValue) error { + for _, item := range value.Items { + err := s.doSerialize(partsHolder, []any{item}) + if err != nil { + return err + } + } + + return nil +} + +func (s *serializer) serializeInputVariadicValues(partsHolder *partsHolder, value values.InputVariadicValues) error { + for _, item := range value.Items { + err := s.doSerialize(partsHolder, []any{item}) + if err != nil { + return err + } + } + + return nil +} + +func (s *serializer) serializeDirectlyEncodableValue(partsHolder *partsHolder, value any) error { + data, err := s.codec.EncodeTopLevel(value) + if err != nil { + return err + } + + return partsHolder.appendToLastPart(data) +} + +func (s *serializer) deserializeOutputMultiValue(partsHolder *partsHolder, value *values.OutputMultiValue) error { + for _, item := range value.Items { + err := s.doDeserialize(partsHolder, []any{item}) + if err != nil { + return err + } + } + + return nil +} + +func (s *serializer) deserializeOutputVariadicValues(partsHolder *partsHolder, value *values.OutputVariadicValues) error { + if value.ItemCreator == nil { + return errors.New("cannot deserialize variadic values: item creator is nil") + } + + for !partsHolder.isFocusedBeyondLastPart() { + newItem := value.ItemCreator() + + err := s.doDeserialize(partsHolder, []any{newItem}) + if err != nil { + return err + } + + value.Items = append(value.Items, newItem) + } + + return nil +} + +func (s *serializer) deserializeDirectlyEncodableValue(partsHolder *partsHolder, value any) error { + part, err := partsHolder.readWholeFocusedPart() + if err != nil { + return err + } + + err = s.codec.DecodeTopLevel(part, value) + if err != nil { + return err + } + + err = partsHolder.focusOnNextPart() + if err != nil { + return err + } + + return nil +} + +func (s *serializer) encodeParts(parts [][]byte) string { + partsHex := make([]string, len(parts)) + + for i, part := range parts { + partsHex[i] = hex.EncodeToString(part) + } + + return strings.Join(partsHex, partsSeparator) +} + +func (s *serializer) decodeIntoParts(encoded string) ([][]byte, error) { + partsHex := strings.Split(encoded, partsSeparator) + parts := make([][]byte, len(partsHex)) + + for i, partHex := range partsHex { + part, err := hex.DecodeString(partHex) + if err != nil { + return nil, err + } + + parts[i] = part + } + + return parts, nil +} diff --git a/abi/serializer/serializer_test.go b/abi/serializer/serializer_test.go new file mode 100644 index 00000000..db7e1817 --- /dev/null +++ b/abi/serializer/serializer_test.go @@ -0,0 +1,321 @@ +package serializer + +import ( + "testing" + + "github.com/multiversx/mx-sdk-go/abi/codec" + "github.com/multiversx/mx-sdk-go/abi/values" + "github.com/stretchr/testify/require" +) + +func TestSerializer_Serialize(t *testing.T) { + serializer := NewSerializer(codec.NewDefaultCodec()) + + t.Run("u8", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.U8Value{Value: 0x42}, + }) + + require.NoError(t, err) + require.Equal(t, "42", data) + }) + + t.Run("u16", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.U16Value{Value: 0x4243}, + }) + + require.NoError(t, err) + require.Equal(t, "4243", data) + }) + + t.Run("u8, u16", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.U8Value{Value: 0x42}, + values.U16Value{Value: 0x4243}, + }) + + require.NoError(t, err) + require.Equal(t, "42@4243", data) + }) + + t.Run("multi", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.InputMultiValue{ + Items: []any{ + values.U8Value{Value: 0x42}, + values.U16Value{Value: 0x4243}, + values.U32Value{Value: 0x42434445}, + }, + }, + }) + + require.NoError(t, err) + require.Equal(t, "42@4243@42434445", data) + }) + + t.Run("u8, multi", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.U8Value{Value: 0x42}, + values.InputMultiValue{ + Items: []any{ + values.U8Value{Value: 0x42}, + values.U16Value{Value: 0x4243}, + values.U32Value{Value: 0x42434445}, + }, + }, + }) + + require.NoError(t, err) + require.Equal(t, "42@42@4243@42434445", data) + }) + + t.Run("multi, multi>", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.InputMultiValue{ + Items: []any{ + values.InputMultiValue{ + Items: []any{ + values.U8Value{Value: 0x42}, + values.U16Value{Value: 0x4243}, + }, + }, + values.InputMultiValue{ + Items: []any{ + values.U8Value{Value: 0x44}, + values.U16Value{Value: 0x4445}, + }, + }, + }, + }, + }) + + require.NoError(t, err) + require.Equal(t, "42@4243@44@4445", data) + }) + + t.Run("variadic, of different types", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.InputVariadicValues{ + Items: []any{ + values.U8Value{Value: 0x42}, + values.U16Value{Value: 0x4243}, + }, + }, + }) + + // For now, the serializer does not perform such a strict type check. + // Although doable, it would be slightly complex and, if done, might be even dropped in the future + // (with respect to the decoder that is embedded in Rust-based smart contracts). + require.Nil(t, err) + require.Equal(t, "42@4243", data) + }) + + t.Run("variadic, u8: should err because variadic must be last", func(t *testing.T) { + _, err := serializer.Serialize([]any{ + values.InputVariadicValues{ + Items: []any{ + values.U8Value{Value: 0x42}, + values.U8Value{Value: 0x43}, + }, + }, + values.U8Value{Value: 0x44}, + }) + + require.ErrorContains(t, err, "variadic values must be last among input values") + }) + + t.Run("u8, variadic", func(t *testing.T) { + data, err := serializer.Serialize([]any{ + values.U8Value{Value: 0x41}, + values.InputVariadicValues{ + Items: []any{ + values.U8Value{Value: 0x42}, + values.U8Value{Value: 0x43}, + }, + }, + }) + + require.Nil(t, err) + require.Equal(t, "41@42@43", data) + }) +} + +func TestSerializer_Deserialize(t *testing.T) { + serializer := NewSerializer(codec.NewDefaultCodec()) + + t.Run("nil destination", func(t *testing.T) { + err := serializer.Deserialize("", []any{nil}) + require.ErrorContains(t, err, "cannot deserialize into nil value") + }) + + t.Run("u8", func(t *testing.T) { + outputValues := []any{ + &values.U8Value{}, + } + + err := serializer.Deserialize("42", outputValues) + + require.Nil(t, err) + require.Equal(t, []any{ + &values.U8Value{Value: 0x42}, + }, outputValues) + }) + + t.Run("u16", func(t *testing.T) { + outputValues := []any{ + &values.U16Value{}, + } + + err := serializer.Deserialize("4243", outputValues) + + require.Nil(t, err) + require.Equal(t, []any{ + &values.U16Value{Value: 0x4243}, + }, outputValues) + }) + + t.Run("u8, u16", func(t *testing.T) { + outputValues := []any{ + &values.U8Value{}, + &values.U16Value{}, + } + + err := serializer.Deserialize("42@4243", outputValues) + + require.Nil(t, err) + require.Equal(t, []any{ + &values.U8Value{Value: 0x42}, + &values.U16Value{Value: 0x4243}, + }, outputValues) + }) + + t.Run("multi", func(t *testing.T) { + outputValues := []any{ + &values.OutputMultiValue{ + Items: []any{ + &values.U8Value{}, + &values.U16Value{}, + &values.U32Value{}, + }, + }, + } + + err := serializer.Deserialize("42@4243@42434445", outputValues) + + require.Nil(t, err) + require.Equal(t, []any{ + &values.OutputMultiValue{ + Items: []any{ + &values.U8Value{Value: 0x42}, + &values.U16Value{Value: 0x4243}, + &values.U32Value{Value: 0x42434445}, + }, + }, + }, outputValues) + }) + + t.Run("u8, multi", func(t *testing.T) { + outputValues := []any{ + &values.U8Value{}, + &values.OutputMultiValue{ + Items: []any{ + &values.U8Value{}, + &values.U16Value{}, + &values.U32Value{}, + }, + }, + } + + err := serializer.Deserialize("42@42@4243@42434445", outputValues) + + require.Nil(t, err) + require.Equal(t, []any{ + &values.U8Value{Value: 0x42}, + &values.OutputMultiValue{ + Items: []any{ + &values.U8Value{Value: 0x42}, + &values.U16Value{Value: 0x4243}, + &values.U32Value{Value: 0x42434445}, + }, + }, + }, outputValues) + }) + + t.Run("variadic, should err because of nil item creator", func(t *testing.T) { + destination := &values.OutputVariadicValues{ + Items: []any{}, + } + + err := serializer.Deserialize("", []any{destination}) + require.ErrorContains(t, err, "cannot deserialize variadic values: item creator is nil") + }) + + t.Run("empty: u8", func(t *testing.T) { + destination := &values.OutputVariadicValues{ + Items: []any{}, + ItemCreator: func() any { return &values.U8Value{} }, + } + + err := serializer.Deserialize("", []any{destination}) + require.NoError(t, err) + require.Equal(t, []any{&values.U8Value{Value: 0}}, destination.Items) + }) + + t.Run("variadic", func(t *testing.T) { + destination := &values.OutputVariadicValues{ + Items: []any{}, + ItemCreator: func() any { return &values.U8Value{} }, + } + + err := serializer.Deserialize("2A@2B@2C", []any{destination}) + require.NoError(t, err) + + require.Equal(t, []any{ + &values.U8Value{Value: 42}, + &values.U8Value{Value: 43}, + &values.U8Value{Value: 44}, + }, destination.Items) + }) + + t.Run("varidic, with empty items", func(t *testing.T) { + destination := &values.OutputVariadicValues{ + Items: []any{}, + ItemCreator: func() any { return &values.U8Value{} }, + } + + err := serializer.Deserialize("@01@", []any{destination}) + require.NoError(t, err) + + require.Equal(t, []any{ + &values.U8Value{Value: 0}, + &values.U8Value{Value: 1}, + &values.U8Value{Value: 0}, + }, destination.Items) + }) + + t.Run("varidic", func(t *testing.T) { + destination := &values.OutputVariadicValues{ + Items: []any{}, + ItemCreator: func() any { return &values.U32Value{} }, + } + + err := serializer.Deserialize("AABBCCDD@DDCCBBAA", []any{destination}) + require.NoError(t, err) + + require.Equal(t, []any{ + &values.U32Value{Value: 0xAABBCCDD}, + &values.U32Value{Value: 0xDDCCBBAA}, + }, destination.Items) + }) + + t.Run("varidic, should err because decoded value is too large", func(t *testing.T) { + destination := &values.OutputVariadicValues{ + Items: []any{}, + ItemCreator: func() any { return &values.U8Value{} }, + } + + err := serializer.Deserialize("0100", []any{destination}) + require.ErrorContains(t, err, "cannot decode (top-level) *values.U8Value, because of: decoded value is too large: 256 > 255") + }) +} diff --git a/abi/values/values.go b/abi/values/values.go new file mode 100644 index 00000000..ea2b254a --- /dev/null +++ b/abi/values/values.go @@ -0,0 +1,108 @@ +package values + +import "math/big" + +type U8Value struct { + Value uint8 +} + +type U16Value struct { + Value uint16 +} + +type U32Value struct { + Value uint32 +} + +type U64Value struct { + Value uint64 +} + +type I8Value struct { + Value int8 +} + +type I16Value struct { + Value int16 +} + +type I32Value struct { + Value int32 +} + +type I64Value struct { + Value int64 +} + +type BigIntValue struct { + Value *big.Int +} + +type AddressValue struct { + Value []byte +} + +type BytesValue struct { + Value []byte +} + +type StringValue struct { + Value string +} + +type BoolValue struct { + Value bool +} + +type OptionValue struct { + Value any +} + +type Field struct { + Name string + Value any +} + +type StructValue struct { + Fields []Field +} + +type TupleValue struct { + Fields []Field +} + +type EnumValue struct { + Discriminant uint8 + Fields []Field +} + +type InputListValue struct { + Items []any +} + +type OutputListValue struct { + Items []any + ItemCreator func() any +} + +type InputMultiValue struct { + Items []any +} + +type InputVariadicValues struct { + Items []any +} + +type OutputMultiValue struct { + Items []any +} + +type OutputVariadicValues struct { + Items []any + ItemCreator func() any +} + +type OptionalValue struct { + Value any + IsSet bool +}