From 5beb08c033424584aa5cb29b3128cd12165dc881 Mon Sep 17 00:00:00 2001 From: Anshul Data Date: Wed, 28 Aug 2024 18:01:16 +0530 Subject: [PATCH 1/6] Support for Parameterized type * Separate AnyType. This will be helpful in match method * Added support for ParameterizedFixedChar/VarChar/FixedBinary/Decimal * Added parser support for Parameterized/PrecisionTimestamp/PrecisionTimestampTz --- extensions/simple_extension.go | 15 +++ types/any_type.go | 58 +++++++++++ types/any_type_test.go | 33 +++++++ types/parameterized_decimal_type.go | 55 +++++++++++ types/parameterized_types.go | 109 +++++++++++++++++++++ types/parameterized_types_test.go | 66 +++++++++++++ types/parser/type_parser.go | 145 +++++++++++++++++++++++----- types/parser/type_parser_test.go | 8 ++ types/precison_timestamp_types.go | 13 +++ types/types.go | 60 +++++++----- 10 files changed, 516 insertions(+), 46 deletions(-) create mode 100644 types/any_type.go create mode 100644 types/any_type_test.go create mode 100644 types/parameterized_decimal_type.go create mode 100644 types/parameterized_types.go create mode 100644 types/parameterized_types_test.go diff --git a/extensions/simple_extension.go b/extensions/simple_extension.go index b191ea1..d629c1e 100644 --- a/extensions/simple_extension.go +++ b/extensions/simple_extension.go @@ -3,11 +3,13 @@ package extensions import ( + "errors" "fmt" "reflect" "strings" substraitgo "github.com/substrait-io/substrait-go" + "github.com/substrait-io/substrait-go/types" "github.com/substrait-io/substrait-go/types/parser" ) @@ -57,6 +59,7 @@ type TypeVariation struct { type Argument interface { toTypeString() string + ArgType() (types.Type, error) } type EnumArg struct { @@ -69,6 +72,10 @@ func (EnumArg) toTypeString() string { return "req" } +func (EnumArg) ArgType() (types.Type, error) { + return nil, errors.New("unimplemented") +} + type ValueArg struct { Name string `yaml:",omitempty"` Description string `yaml:",omitempty"` @@ -80,6 +87,10 @@ func (v ValueArg) 
toTypeString() string { return v.Value.Expr.(*parser.Type).ShortType() } +func (v ValueArg) ArgType() (types.Type, error) { + return v.Value.Expr.(*parser.Type).Type() +} + type TypeArg struct { Name string `yaml:",omitempty"` Description string `yaml:",omitempty"` @@ -88,6 +99,10 @@ type TypeArg struct { func (TypeArg) toTypeString() string { return "type" } +func (TypeArg) ArgType() (types.Type, error) { + return nil, errors.New("unimplemented") +} + type ArgumentList []Argument func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error { diff --git a/types/any_type.go b/types/any_type.go new file mode 100644 index 0000000..7d7f2e8 --- /dev/null +++ b/types/any_type.go @@ -0,0 +1,58 @@ +package types + +import ( + "fmt" + + "github.com/substrait-io/substrait-go/proto" +) + +// AnyType to represent AnyType, this type is to indicate "any" type of argument +// This type is not used in function invocation. It is only used in function definition +type AnyType struct { + Name string + Nullability Nullability +} + +func (*AnyType) isRootRef() {} +func (m *AnyType) WithNullability(nullability Nullability) Type { + m.Nullability = nullability + return m +} +func (m *AnyType) GetType() Type { return m } +func (m *AnyType) GetNullability() Nullability { + return m.Nullability +} +func (*AnyType) GetTypeVariationReference() uint32 { + panic("not allowed") +} +func (*AnyType) Equals(rhs Type) bool { + // equal to every other type + return true +} + +func (*AnyType) ToProtoFuncArg() *proto.FunctionArgument { + panic("not allowed") +} + +func (*AnyType) ToProto() *proto.Type { + panic("not allowed") +} + +func (t *AnyType) ShortString() string { return t.Name } +func (t *AnyType) String() string { + return fmt.Sprintf("%s%s", t.Name, strNullable(t)) +} + +// Below methods are for parser Def interface + +func (*AnyType) Optional() bool { + panic("not allowed") +} + +func (m *AnyType) ShortType() string { + return "any" +} + +func (m *AnyType) Type() (Type, error) { + 
return m, nil +} diff --git a/types/any_type_test.go b/types/any_type_test.go new file mode 100644 index 0000000..8362617 --- /dev/null +++ b/types/any_type_test.go @@ -0,0 +1,33 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" +) + +func TestAnyType(t *testing.T) { + for _, td := range []struct { + testName string + argName string + nullability types.Nullability + expectedString string + }{ + {"any", "any", types.NullabilityNullable, "any?"}, + {"anyrequired", "any", types.NullabilityRequired, "any"}, + {"anyOtherName", "any1", types.NullabilityNullable, "any1?"}, + {"T name", "T", types.NullabilityNullable, "T?"}, + } { + t.Run(td.testName, func(t *testing.T) { + arg := &types.AnyType{ + Name: td.argName, + Nullability: td.nullability, + } + require.Equal(t, td.expectedString, arg.String()) + require.Equal(t, td.nullability, arg.GetNullability()) + require.Equal(t, td.argName, arg.ShortString()) + require.Equal(t, "any", arg.ShortType()) + }) + } +} diff --git a/types/parameterized_decimal_type.go b/types/parameterized_decimal_type.go new file mode 100644 index 0000000..5a08f29 --- /dev/null +++ b/types/parameterized_decimal_type.go @@ -0,0 +1,55 @@ +package types + +import ( + "fmt" + + "github.com/substrait-io/substrait-go/proto" +) + +type ParameterizedDecimalType struct { + Nullability Nullability + TypeVariationRef uint32 + Precision IntegerParam + Scale IntegerParam +} + +func (*ParameterizedDecimalType) isRootRef() {} +func (m *ParameterizedDecimalType) WithNullability(n Nullability) Type { + m.Nullability = n + return m +} + +func (m *ParameterizedDecimalType) GetType() Type { return m } +func (m *ParameterizedDecimalType) GetNullability() Nullability { return m.Nullability } +func (m *ParameterizedDecimalType) GetTypeVariationReference() uint32 { + return m.TypeVariationRef +} +func (m *ParameterizedDecimalType) Equals(rhs Type) bool { + if o, ok := 
rhs.(*ParameterizedDecimalType); ok { + return *o == *m + } + return false +} + +func (*ParameterizedDecimalType) ToProtoFuncArg() *proto.FunctionArgument { + // parameterized type are never on wire so to proto is not supported + panic("not supported") +} + +func (m *ParameterizedDecimalType) ShortString() string { + t := &DecimalType{} + return t.ShortString() +} + +func (m *ParameterizedDecimalType) String() string { + return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString()) +} + +func (m *ParameterizedDecimalType) ParameterString() string { + return fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) +} + +func (m *ParameterizedDecimalType) BaseString() string { + t := &DecimalType{} + return t.BaseString() +} diff --git a/types/parameterized_types.go b/types/parameterized_types.go new file mode 100644 index 0000000..0a88e76 --- /dev/null +++ b/types/parameterized_types.go @@ -0,0 +1,109 @@ +package types + +import ( + "fmt" + + "github.com/substrait-io/substrait-go/proto" +) + +type IntegerParam struct { + Name string +} + +func (m IntegerParam) Equals(o IntegerParam) bool { + return m == o +} + +func (p IntegerParam) ToProto() *proto.ParameterizedType_IntegerParameter { + panic("not implemented") +} + +func (m *IntegerParam) String() string { + return m.Name +} + +type ParameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct { + Nullability Nullability + TypeVariationRef uint32 + IntegerOption IntegerParam +} + +func (m *ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) ParameterizedSingleIntegerType { + m.IntegerOption = integerOption + return m +} + +func (*ParameterizedTypeSingleIntegerParam[T]) isRootRef() {} +func (m *ParameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type { + m.Nullability = n + return m +} + +func (m *ParameterizedTypeSingleIntegerParam[T]) GetType() Type { 
return m } +func (m *ParameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability } +func (m *ParameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 { + return m.TypeVariationRef +} +func (m *ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool { + if o, ok := rhs.(*ParameterizedTypeSingleIntegerParam[T]); ok { + return *o == *m + } + return false +} + +func (*ParameterizedTypeSingleIntegerParam[T]) ToProtoFuncArg() *proto.FunctionArgument { + // parameterized type are never on wire so to proto is not supported + panic("not supported") +} + +func (m *ParameterizedTypeSingleIntegerParam[T]) ShortString() string { + switch any(m).(type) { + case *ParameterizedVarCharType: + t := &VarCharType{} + return t.ShortString() + case *ParameterizedFixedCharType: + t := &FixedCharType{} + return t.ShortString() + case *ParameterizedFixedBinaryType: + t := &FixedBinaryType{} + return t.ShortString() + case *ParameterizedPrecisionTimestampType: + t := &PrecisionTimestampType{} + return t.ShortString() + case *ParameterizedPrecisionTimestampTzType: + t := &PrecisionTimestampTzType{} + return t.ShortString() + default: + panic("unknown type") + } +} + +func (m *ParameterizedTypeSingleIntegerParam[T]) String() string { + return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString()) +} + +func (m *ParameterizedTypeSingleIntegerParam[T]) ParameterString() string { + return fmt.Sprintf("<%s>", m.IntegerOption.String()) +} + +func (m *ParameterizedTypeSingleIntegerParam[T]) BaseString() string { + switch any(m).(type) { + case *ParameterizedVarCharType: + t := &VarCharType{} + return t.BaseString() + case *ParameterizedFixedCharType: + t := &FixedCharType{} + return t.BaseString() + case *ParameterizedFixedBinaryType: + t := &FixedBinaryType{} + return t.BaseString() + case *ParameterizedPrecisionTimestampType: + t := &PrecisionTimestampType{} + return t.BaseString() + case 
*ParameterizedPrecisionTimestampTzType: + t := &PrecisionTimestampTzType{} + return t.BaseString() + default: + panic("unknown type") + } +} diff --git a/types/parameterized_types_test.go b/types/parameterized_types_test.go new file mode 100644 index 0000000..eba0f86 --- /dev/null +++ b/types/parameterized_types_test.go @@ -0,0 +1,66 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" +) + +func TestParameterizedVarCharType(t *testing.T) { + for _, td := range []struct { + name string + typ types.ParameterizedSingleIntegerType + nullability types.Nullability + integerOption types.IntegerParam + expectedString string + expectedBaseString string + expectedShortString string + }{ + {"nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "varchar?", "varchar", "vchar"}, + {"non nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "varchar", "varchar", "vchar"}, + {"nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "char?", "char", "fchar"}, + {"non nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "char", "char", "fchar"}, + {"nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "fixedbinary?", "fixedbinary", "fbin"}, + {"non nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "fixedbinary", "fixedbinary", "fbin"}, + {"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp?", "precision_timestamp", "prets"}, + {"non nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, 
types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp", "precision_timestamp", "prets"}, + {"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz?", "precision_timestamp_tz", "pretstz"}, + {"non nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz", "precision_timestamp_tz", "pretstz"}, + } { + t.Run(td.name, func(t *testing.T) { + pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability) + require.Equal(t, td.expectedString, pt.String()) + parameterizeType, ok := pt.(types.ParameterizedType) + require.True(t, ok) + require.Equal(t, td.expectedBaseString, parameterizeType.BaseString()) + require.Equal(t, td.expectedShortString, pt.ShortString()) + require.True(t, pt.Equals(pt)) + }) + } +} + +func TestParameterizedDecimalType(t *testing.T) { + for _, td := range []struct { + name string + precision string + scale string + nullability types.Nullability + expectedString string + expectedBaseString string + expectedShortString string + }{ + {"nullable decimal", "P", "S", types.NullabilityNullable, "decimal?", "decimal", "dec"}, + {"non nullable decimal", "P", "S", types.NullabilityRequired, "decimal", "decimal", "dec"}, + } { + t.Run(td.name, func(t *testing.T) { + precision := types.IntegerParam{Name: td.precision} + scale := types.IntegerParam{Name: td.scale} + pt := &types.ParameterizedDecimalType{Precision: precision, Scale: scale, Nullability: td.nullability} + require.Equal(t, td.expectedString, pt.String()) + require.Equal(t, td.expectedBaseString, pt.BaseString()) + require.Equal(t, td.expectedShortString, pt.ShortString()) + require.True(t, pt.Equals(pt)) + }) + } +} diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index ae5e4a2..6a3b0b9 100644 --- 
a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -3,6 +3,7 @@ package parser import ( + "errors" "io" "strconv" "strings" @@ -21,6 +22,14 @@ type TypeExpression struct { func (t TypeExpression) String() string { return t.Expr.String() } +func (t TypeExpression) Type() (types.Type, error) { + typeDef, ok := t.Expr.(Def) + if !ok { + return nil, errors.New("type expression doesn't represent type") + } + return typeDef.Type() +} + func (t TypeExpression) MarshalYAML() (interface{}, error) { return t.Expr.String(), nil } @@ -104,7 +113,7 @@ func (t *typename) Capture(values []string) error { } type nonParamType struct { - TypeName typename `parser:"@(AnyType | Template | IntType | Boolean | FPType | Temporal | BinaryType)"` + TypeName typename `parser:"@(IntType | Boolean | FPType | Temporal | BinaryType)"` Nullability bool `parser:"@'?'?"` // Variation int `parser:"'[' @\d+ ']'?"` } @@ -120,10 +129,6 @@ func (t *nonParamType) String() string { } func (t *nonParamType) ShortType() string { - if strings.HasPrefix(string(t.TypeName), "any") { - return "any" - } - return types.GetShortTypeName(types.TypeName(t.TypeName)) } @@ -188,6 +193,10 @@ func (p *lengthType) ShortType() string { switch p.TypeName { case "fixedchar", "varchar", "fixedbinary": return types.GetShortTypeName(types.TypeName(p.TypeName)) + case "precision_timestamp": + return "prets" + case "precision_timestamp_tz": + return "pretstz" } return "" } @@ -200,16 +209,62 @@ func (p *lengthType) Optional() bool { return false } func (p *lengthType) Type() (types.Type, error) { var n types.Nullability - lit, ok := p.NumericParam.Expr.(*IntegerLiteral) - if !ok { + + var typ types.Type + var err error + switch t := p.NumericParam.Expr.(type) { + case *IntegerLiteral: + typ, err = getFixedTypeFromConcreteParam(p.TypeName, t) + case *ParamName: + typ, err = getParameterizedTypeSingleParam(p.TypeName, t) + default: return nil, substraitgo.ErrNotImplemented } + if err != nil { + return nil, err + } 
+ return typ.WithNullability(n), nil +} - typ, err := types.FixedTypeNameToType(types.TypeName(p.TypeName)) +func getFixedTypeFromConcreteParam(name string, param *IntegerLiteral) (types.Type, error) { + typeName := types.TypeName(name) + switch typeName { + case types.TypeNamePrecisionTimestamp: + precision, err := types.ProtoToTimePrecision(param.Value) + if err != nil { + return nil, err + } + return types.NewPrecisionTimestampType(precision), nil + case types.TypeNamePrecisionTimestampTz: + precision, err := types.ProtoToTimePrecision(param.Value) + if err != nil { + return nil, err + } + return types.NewPrecisionTimestampTzType(precision), nil + } + typ, err := types.FixedTypeNameToType(typeName) if err != nil { return nil, err } - return typ.WithLength(lit.Value).WithNullability(n), nil + return typ.WithLength(param.Value), nil +} + +func getParameterizedTypeSingleParam(typeName string, param *ParamName) (types.Type, error) { + intParam := types.IntegerParam{Name: param.Name} + switch types.TypeName(typeName) { + case types.TypeNameVarChar: + return &types.ParameterizedVarCharType{IntegerOption: intParam}, nil + case types.TypeNameFixedChar: + return &types.ParameterizedFixedCharType{IntegerOption: intParam}, nil + case types.TypeNameFixedBinary: + return &types.ParameterizedFixedBinaryType{IntegerOption: intParam}, nil + case types.TypeNamePrecisionTimestamp: + return &types.ParameterizedPrecisionTimestampType{IntegerOption: intParam}, nil + case types.TypeNamePrecisionTimestampTz: + return &types.ParameterizedPrecisionTimestampTzType{IntegerOption: intParam}, nil + default: + return nil, substraitgo.ErrNotImplemented + } } type decimalType struct { @@ -232,21 +287,28 @@ func (d *decimalType) Optional() bool { return d.Nullability } func (d *decimalType) Type() (types.Type, error) { var n types.Nullability - p, ok := d.Precision.Expr.(*IntegerLiteral) - if !ok { - return nil, substraitgo.ErrNotImplemented + pi, ok1 := d.Precision.Expr.(*IntegerLiteral) + si, 
ok2 := d.Scale.Expr.(*IntegerLiteral) + if ok1 && ok2 { + // concrete decimal param + return &types.DecimalType{ + Nullability: n, + Precision: pi.Value, + Scale: si.Value, + }, nil } - s, ok := d.Scale.Expr.(*IntegerLiteral) - if !ok { - return nil, substraitgo.ErrNotImplemented + ps, ok1 := d.Precision.Expr.(*ParamName) + ss, ok2 := d.Scale.Expr.(*ParamName) + if ok1 && ok2 { + // parameterized decimal param + return &types.ParameterizedDecimalType{ + Nullability: n, + Precision: types.IntegerParam{Name: ps.Name}, + Scale: types.IntegerParam{Name: ss.Name}, + }, nil } - - return &types.DecimalType{ - Nullability: n, - Precision: p.Value, - Scale: s.Value, - }, nil + return nil, substraitgo.ErrNotImplemented } type structType struct { @@ -352,6 +414,43 @@ func (m *mapType) Type() (types.Type, error) { }, nil } +// parser token for any +type anyType struct { + TypeName typename `parser:"@(AnyType|Template)"` + Nullability bool `parser:"@'?'?"` +} + +func (anyType) Optional() bool { return false } + +func (t anyType) String() string { + opt := string(t.TypeName) + if t.Nullability { + opt += "?" 
+ } + return opt +} + +func (t anyType) ShortType() string { + if strings.HasPrefix(string(t.TypeName), "any") { + return "any" + } + return string(t.TypeName) +} + +func (t anyType) Type() (types.Type, error) { + var n types.Nullability + if t.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + typeName := string(t.TypeName) + if strings.HasPrefix(typeName, "any") { + return &types.AnyType{Name: "any", Nullability: n}, nil + } + return &types.AnyType{Name: typeName, Nullability: n}, nil +} + var ( def = lexer.MustSimple([]lexer.SimpleRule{ {Name: "whitespace", Pattern: `[ \t]+`}, @@ -362,7 +461,7 @@ var ( {Name: "FPType", Pattern: `fp(32|64)`}, {Name: "Temporal", Pattern: `timestamp(_tz)?|date|time|interval_day|interval_year`}, {Name: "BinaryType", Pattern: `string|binary|uuid`}, - {Name: "LengthType", Pattern: `fixedchar|varchar|fixedbinary`}, + {Name: "LengthType", Pattern: `fixedchar|varchar|fixedbinary|precision_timestamp_tz|precision_timestamp`}, {Name: "Int", Pattern: `[-+]?\d+`}, {Name: "ParamType", Pattern: `(?i)(struct|list|decimal|map)`}, {Name: "Identifier", Pattern: `[a-zA-Z_$][a-zA-Z_$0-9]*`}, @@ -389,7 +488,7 @@ func (p *Parser) ParseBytes(expr []byte) (*TypeExpression, error) { func New() (*Parser, error) { parser, err := participle.Build[TypeExpression]( participle.Union[Expression](&Type{}, &IntegerLiteral{}, &ParamName{}), - participle.Union[Def](&nonParamType{}, &mapType{}, &listType{}, &structType{}, &lengthType{}, &decimalType{}), + participle.Union[Def](&anyType{}, &nonParamType{}, &mapType{}, &listType{}, &structType{}, &lengthType{}, &decimalType{}), participle.CaseInsensitive("Boolean", "ParamType", "IntType", "FPType", "Temporal", "BinaryType", "LengthType"), participle.Lexer(def), participle.UseLookahead(3), diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index 0abc300..20390d4 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go 
@@ -29,6 +29,14 @@ func TestParser(t *testing.T) { {"struct", "struct", "struct", &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, {"map>", "map>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, {"map?>", "map?>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + {"precision_timestamp<5>", "precision_timestamp<5>", "prets", &types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}, + {"precision_timestamp_tz<5>", "precision_timestamp_tz<5>", "pretstz", &types.PrecisionTimestampTzType{PrecisionTimestampType: types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}}, + {"varchar", "varchar", "vchar", &types.ParameterizedVarCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"fixedchar", "fixedchar", "fchar", &types.ParameterizedFixedCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"fixedbinary", "fixedbinary", "fbin", &types.ParameterizedFixedBinaryType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"precision_timestamp", "precision_timestamp", "prets", &types.ParameterizedPrecisionTimestampType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"precision_timestamp_tz", 
"precision_timestamp_tz", "pretstz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"decimal", "decimal", "dec", &types.ParameterizedDecimalType{Precision: types.IntegerParam{Name: "P"}, Scale: types.IntegerParam{Name: "S"}}}, } p, err := parser.New() diff --git a/types/precison_timestamp_types.go b/types/precison_timestamp_types.go index e27dd0c..e88d656 100644 --- a/types/precison_timestamp_types.go +++ b/types/precison_timestamp_types.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "reflect" "github.com/substrait-io/substrait-go/proto" ) @@ -97,6 +98,14 @@ func (m *PrecisionTimestampType) String() string { m.Precision.ToProtoVal()) } +func (m *PrecisionTimestampType) ParameterString() string { + return fmt.Sprintf("%d", m.Precision.ToProtoVal()) +} + +func (m *PrecisionTimestampType) BaseString() string { + return typeNames[reflect.TypeOf(m)] +} + // PrecisionTimestampTzType this is used to represent a type of Precision timestamp with TimeZone type PrecisionTimestampTzType struct { PrecisionTimestampType @@ -145,3 +154,7 @@ func (m *PrecisionTimestampTzType) Equals(rhs Type) bool { return false } func (*PrecisionTimestampTzType) ShortString() string { return "pretstz" } + +func (m *PrecisionTimestampTzType) BaseString() string { + return typeNames[reflect.TypeOf(m)] +} diff --git a/types/types.go b/types/types.go index a837860..7e2b8b0 100644 --- a/types/types.go +++ b/types/types.go @@ -44,10 +44,12 @@ const ( TypeNameIntervalDay TypeName = "interval_day" TypeNameUUID TypeName = "uuid" - TypeNameFixedBinary TypeName = "fixedbinary" - TypeNameFixedChar TypeName = "fixedchar" - TypeNameVarChar TypeName = "varchar" - TypeNameDecimal TypeName = "decimal" + TypeNameFixedBinary TypeName = "fixedbinary" + TypeNameFixedChar TypeName = "fixedchar" + TypeNameVarChar TypeName = "varchar" + TypeNameDecimal TypeName = "decimal" + TypeNamePrecisionTimestamp TypeName = "precision_timestamp" + TypeNamePrecisionTimestampTz 
TypeName = "precision_timestamp_tz" ) var simpleTypeNameMap = map[TypeName]Type{ @@ -379,6 +381,11 @@ type ( ParameterizedType WithLength(int32) FixedType } + + ParameterizedSingleIntegerType interface { + ParameterizedType + WithIntegerOption(param IntegerParam) ParameterizedSingleIntegerType + } ) // TypeToProto properly constructs the appropriate protobuf message @@ -523,6 +530,8 @@ var typeNames = map[reflect.Type]string{ reflect.TypeOf(&FixedBinary{}): "fixedbinary", reflect.TypeOf(&emptyFixedChar): "char", reflect.TypeOf(&VarChar{}): "varchar", + reflect.TypeOf(&PrecisionTimestampType{}): "precision_timestamp", + reflect.TypeOf(&PrecisionTimestampTzType{}): "precision_timestamp_tz", } var shortNames = map[reflect.Type]string{ @@ -604,25 +613,30 @@ func (s *PrimitiveType[T]) String() string { // create type aliases to the generic structs type ( - BooleanType = PrimitiveType[bool] - Int8Type = PrimitiveType[int8] - Int16Type = PrimitiveType[int16] - Int32Type = PrimitiveType[int32] - Int64Type = PrimitiveType[int64] - Float32Type = PrimitiveType[float32] - Float64Type = PrimitiveType[float64] - StringType = PrimitiveType[string] - BinaryType = PrimitiveType[[]byte] - TimestampType = PrimitiveType[Timestamp] - DateType = PrimitiveType[Date] - TimeType = PrimitiveType[Time] - TimestampTzType = PrimitiveType[TimestampTz] - IntervalYearType = PrimitiveType[IntervalYearToMonth] - IntervalDayType = PrimitiveType[IntervalDayToSecond] - UUIDType = PrimitiveType[UUID] - FixedCharType = FixedLenType[FixedChar] - VarCharType = FixedLenType[VarChar] - FixedBinaryType = FixedLenType[FixedBinary] + BooleanType = PrimitiveType[bool] + Int8Type = PrimitiveType[int8] + Int16Type = PrimitiveType[int16] + Int32Type = PrimitiveType[int32] + Int64Type = PrimitiveType[int64] + Float32Type = PrimitiveType[float32] + Float64Type = PrimitiveType[float64] + StringType = PrimitiveType[string] + BinaryType = PrimitiveType[[]byte] + TimestampType = PrimitiveType[Timestamp] + DateType = 
PrimitiveType[Date] + TimeType = PrimitiveType[Time] + TimestampTzType = PrimitiveType[TimestampTz] + IntervalYearType = PrimitiveType[IntervalYearToMonth] + IntervalDayType = PrimitiveType[IntervalDayToSecond] + UUIDType = PrimitiveType[UUID] + FixedCharType = FixedLenType[FixedChar] + VarCharType = FixedLenType[VarChar] + FixedBinaryType = FixedLenType[FixedBinary] + ParameterizedVarCharType = ParameterizedTypeSingleIntegerParam[VarCharType] + ParameterizedFixedCharType = ParameterizedTypeSingleIntegerParam[FixedCharType] + ParameterizedFixedBinaryType = ParameterizedTypeSingleIntegerParam[FixedBinaryType] + ParameterizedPrecisionTimestampType = ParameterizedTypeSingleIntegerParam[PrecisionTimestampType] + ParameterizedPrecisionTimestampTzType = ParameterizedTypeSingleIntegerParam[PrecisionTimestampTzType] ) // FixedLenType is any of the types which also need to track their specific From a1234111f5cf4b2b6a9dd6423669f79268726254 Mon Sep 17 00:00:00 2001 From: Anshul Data Date: Mon, 2 Sep 2024 13:49:43 +0530 Subject: [PATCH 2/6] Address review comments --- types/any_type.go | 26 +++++----- types/parameterized_decimal_type.go | 33 +++++++------ types/parameterized_types.go | 73 +++++++++++++++-------------- types/parameterized_types_test.go | 2 +- types/parser/type_parser.go | 16 +++---- types/parser/type_parser_test.go | 36 +++++++------- types/types.go | 1 + 7 files changed, 98 insertions(+), 89 deletions(-) diff --git a/types/any_type.go b/types/any_type.go index 7d7f2e8..2fbf352 100644 --- a/types/any_type.go +++ b/types/any_type.go @@ -13,46 +13,46 @@ type AnyType struct { Nullability Nullability } -func (*AnyType) isRootRef() {} -func (m *AnyType) WithNullability(nullability Nullability) Type { +func (AnyType) isRootRef() {} +func (m AnyType) WithNullability(nullability Nullability) Type { m.Nullability = nullability return m } -func (m *AnyType) GetType() Type { return m } -func (m *AnyType) GetNullability() Nullability { +func (m AnyType) GetType() Type { 
return m } +func (m AnyType) GetNullability() Nullability { return m.Nullability } -func (*AnyType) GetTypeVariationReference() uint32 { +func (AnyType) GetTypeVariationReference() uint32 { panic("not allowed") } -func (*AnyType) Equals(rhs Type) bool { +func (AnyType) Equals(rhs Type) bool { // equal to every other type return true } -func (*AnyType) ToProtoFuncArg() *proto.FunctionArgument { +func (AnyType) ToProtoFuncArg() *proto.FunctionArgument { panic("not allowed") } -func (*AnyType) ToProto() *proto.Type { +func (AnyType) ToProto() *proto.Type { panic("not allowed") } -func (t *AnyType) ShortString() string { return t.Name } -func (t *AnyType) String() string { +func (t AnyType) ShortString() string { return t.Name } +func (t AnyType) String() string { return fmt.Sprintf("%s%s", t.Name, strNullable(t)) } // Below methods are for parser Def interface -func (*AnyType) Optional() bool { +func (AnyType) Optional() bool { panic("not allowed") } -func (m *AnyType) ShortType() string { +func (m AnyType) ShortType() string { return "any" } -func (m *AnyType) Type() (Type, error) { +func (m AnyType) Type() (Type, error) { return m, nil } diff --git a/types/parameterized_decimal_type.go b/types/parameterized_decimal_type.go index 5a08f29..d6946e6 100644 --- a/types/parameterized_decimal_type.go +++ b/types/parameterized_decimal_type.go @@ -6,6 +6,9 @@ import ( "github.com/substrait-io/substrait-go/proto" ) +// ParameterizedDecimalType is a decimal type with precision and scale parameters of string type +// example: Decimal(P,S). 
Kindly note concrete types Decimal(10, 2) are not represented by this type +// Concrete type is represented by DecimalType type ParameterizedDecimalType struct { Nullability Nullability TypeVariationRef uint32 @@ -13,43 +16,43 @@ type ParameterizedDecimalType struct { Scale IntegerParam } -func (*ParameterizedDecimalType) isRootRef() {} -func (m *ParameterizedDecimalType) WithNullability(n Nullability) Type { +func (ParameterizedDecimalType) isRootRef() {} +func (m ParameterizedDecimalType) WithNullability(n Nullability) Type { m.Nullability = n return m } -func (m *ParameterizedDecimalType) GetType() Type { return m } -func (m *ParameterizedDecimalType) GetNullability() Nullability { return m.Nullability } -func (m *ParameterizedDecimalType) GetTypeVariationReference() uint32 { +func (m ParameterizedDecimalType) GetType() Type { return m } +func (m ParameterizedDecimalType) GetNullability() Nullability { return m.Nullability } +func (m ParameterizedDecimalType) GetTypeVariationReference() uint32 { return m.TypeVariationRef } -func (m *ParameterizedDecimalType) Equals(rhs Type) bool { - if o, ok := rhs.(*ParameterizedDecimalType); ok { - return *o == *m +func (m ParameterizedDecimalType) Equals(rhs Type) bool { + if o, ok := rhs.(ParameterizedDecimalType); ok { + return o == m } return false } -func (*ParameterizedDecimalType) ToProtoFuncArg() *proto.FunctionArgument { +func (ParameterizedDecimalType) ToProtoFuncArg() *proto.FunctionArgument { // parameterized type are never on wire so to proto is not supported panic("not supported") } -func (m *ParameterizedDecimalType) ShortString() string { - t := &DecimalType{} +func (m ParameterizedDecimalType) ShortString() string { + t := DecimalType{} return t.ShortString() } -func (m *ParameterizedDecimalType) String() string { +func (m ParameterizedDecimalType) String() string { return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString()) } -func (m *ParameterizedDecimalType) ParameterString() string 
{ +func (m ParameterizedDecimalType) ParameterString() string { return fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) } -func (m *ParameterizedDecimalType) BaseString() string { - t := &DecimalType{} +func (m ParameterizedDecimalType) BaseString() string { + t := DecimalType{} return t.BaseString() } diff --git a/types/parameterized_types.go b/types/parameterized_types.go index 0a88e76..5ebf146 100644 --- a/types/parameterized_types.go +++ b/types/parameterized_types.go @@ -6,6 +6,8 @@ import ( "github.com/substrait-io/substrait-go/proto" ) +// IntegerParam represents a single integer parameter for a parameterized type +// Example: VARCHAR(L1) -> L1 is the integer parameter type IntegerParam struct { Name string } @@ -18,90 +20,91 @@ func (p IntegerParam) ToProto() *proto.ParameterizedType_IntegerParameter { panic("not implemented") } -func (m *IntegerParam) String() string { +func (m IntegerParam) String() string { return m.Name } +// ParameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter type ParameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct { Nullability Nullability TypeVariationRef uint32 IntegerOption IntegerParam } -func (m *ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) ParameterizedSingleIntegerType { +func (m ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) ParameterizedSingleIntegerType { m.IntegerOption = integerOption return m } -func (*ParameterizedTypeSingleIntegerParam[T]) isRootRef() {} -func (m *ParameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type { +func (ParameterizedTypeSingleIntegerParam[T]) isRootRef() {} +func (m ParameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type { m.Nullability = n return m } -func (m *ParameterizedTypeSingleIntegerParam[T]) 
GetType() Type { return m } -func (m *ParameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability } -func (m *ParameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 { +func (m ParameterizedTypeSingleIntegerParam[T]) GetType() Type { return m } +func (m ParameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability } +func (m ParameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 { return m.TypeVariationRef } -func (m *ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool { - if o, ok := rhs.(*ParameterizedTypeSingleIntegerParam[T]); ok { - return *o == *m +func (m ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool { + if o, ok := rhs.(ParameterizedTypeSingleIntegerParam[T]); ok { + return o == m } return false } -func (*ParameterizedTypeSingleIntegerParam[T]) ToProtoFuncArg() *proto.FunctionArgument { +func (ParameterizedTypeSingleIntegerParam[T]) ToProtoFuncArg() *proto.FunctionArgument { // parameterized type are never on wire so to proto is not supported panic("not supported") } -func (m *ParameterizedTypeSingleIntegerParam[T]) ShortString() string { +func (m ParameterizedTypeSingleIntegerParam[T]) ShortString() string { switch any(m).(type) { - case *ParameterizedVarCharType: - t := &VarCharType{} + case ParameterizedVarCharType: + t := VarCharType{} return t.ShortString() - case *ParameterizedFixedCharType: - t := &FixedCharType{} + case ParameterizedFixedCharType: + t := FixedCharType{} return t.ShortString() - case *ParameterizedFixedBinaryType: - t := &FixedBinaryType{} + case ParameterizedFixedBinaryType: + t := FixedBinaryType{} return t.ShortString() - case *ParameterizedPrecisionTimestampType: - t := &PrecisionTimestampType{} + case ParameterizedPrecisionTimestampType: + t := PrecisionTimestampType{} return t.ShortString() - case *ParameterizedPrecisionTimestampTzType: - t := &PrecisionTimestampTzType{} + case 
ParameterizedPrecisionTimestampTzType: + t := PrecisionTimestampTzType{} return t.ShortString() default: panic("unknown type") } } -func (m *ParameterizedTypeSingleIntegerParam[T]) String() string { +func (m ParameterizedTypeSingleIntegerParam[T]) String() string { return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString()) } -func (m *ParameterizedTypeSingleIntegerParam[T]) ParameterString() string { +func (m ParameterizedTypeSingleIntegerParam[T]) ParameterString() string { return fmt.Sprintf("<%s>", m.IntegerOption.String()) } -func (m *ParameterizedTypeSingleIntegerParam[T]) BaseString() string { +func (m ParameterizedTypeSingleIntegerParam[T]) BaseString() string { switch any(m).(type) { - case *ParameterizedVarCharType: - t := &VarCharType{} + case ParameterizedVarCharType: + t := VarCharType{} return t.BaseString() - case *ParameterizedFixedCharType: - t := &FixedCharType{} + case ParameterizedFixedCharType: + t := FixedCharType{} return t.BaseString() - case *ParameterizedFixedBinaryType: - t := &FixedBinaryType{} + case ParameterizedFixedBinaryType: + t := FixedBinaryType{} return t.BaseString() - case *ParameterizedPrecisionTimestampType: - t := &PrecisionTimestampType{} + case ParameterizedPrecisionTimestampType: + t := PrecisionTimestampType{} return t.BaseString() - case *ParameterizedPrecisionTimestampTzType: - t := &PrecisionTimestampTzType{} + case ParameterizedPrecisionTimestampTzType: + t := PrecisionTimestampTzType{} return t.BaseString() default: panic("unknown type") diff --git a/types/parameterized_types_test.go b/types/parameterized_types_test.go index eba0f86..c0d9845 100644 --- a/types/parameterized_types_test.go +++ b/types/parameterized_types_test.go @@ -56,7 +56,7 @@ func TestParameterizedDecimalType(t *testing.T) { t.Run(td.name, func(t *testing.T) { precision := types.IntegerParam{Name: td.precision} scale := types.IntegerParam{Name: td.scale} - pt := &types.ParameterizedDecimalType{Precision: precision, Scale: 
scale, Nullability: td.nullability} + pt := types.ParameterizedDecimalType{Precision: precision, Scale: scale, Nullability: td.nullability} require.Equal(t, td.expectedString, pt.String()) require.Equal(t, td.expectedBaseString, pt.BaseString()) require.Equal(t, td.expectedShortString, pt.ShortString()) diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index 6a3b0b9..a338e2b 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -253,15 +253,15 @@ func getParameterizedTypeSingleParam(typeName string, param *ParamName) (types.T intParam := types.IntegerParam{Name: param.Name} switch types.TypeName(typeName) { case types.TypeNameVarChar: - return &types.ParameterizedVarCharType{IntegerOption: intParam}, nil + return types.ParameterizedVarCharType{IntegerOption: intParam}, nil case types.TypeNameFixedChar: - return &types.ParameterizedFixedCharType{IntegerOption: intParam}, nil + return types.ParameterizedFixedCharType{IntegerOption: intParam}, nil case types.TypeNameFixedBinary: - return &types.ParameterizedFixedBinaryType{IntegerOption: intParam}, nil + return types.ParameterizedFixedBinaryType{IntegerOption: intParam}, nil case types.TypeNamePrecisionTimestamp: - return &types.ParameterizedPrecisionTimestampType{IntegerOption: intParam}, nil + return types.ParameterizedPrecisionTimestampType{IntegerOption: intParam}, nil case types.TypeNamePrecisionTimestampTz: - return &types.ParameterizedPrecisionTimestampTzType{IntegerOption: intParam}, nil + return types.ParameterizedPrecisionTimestampTzType{IntegerOption: intParam}, nil default: return nil, substraitgo.ErrNotImplemented } @@ -302,7 +302,7 @@ func (d *decimalType) Type() (types.Type, error) { ss, ok2 := d.Scale.Expr.(*ParamName) if ok1 && ok2 { // parameterized decimal param - return &types.ParameterizedDecimalType{ + return types.ParameterizedDecimalType{ Nullability: n, Precision: types.IntegerParam{Name: ps.Name}, Scale: types.IntegerParam{Name: ss.Name}, @@ -446,9 
+446,9 @@ func (t anyType) Type() (types.Type, error) { } typeName := string(t.TypeName) if strings.HasPrefix(typeName, "any") { - return &types.AnyType{Name: "any", Nullability: n}, nil + return types.AnyType{Name: "any", Nullability: n}, nil } - return &types.AnyType{Name: typeName, Nullability: n}, nil + return types.AnyType{Name: typeName, Nullability: n}, nil } var ( diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index 20390d4..9cc6122 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -3,6 +3,7 @@ package parser_test import ( + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -13,10 +14,10 @@ import ( func TestParser(t *testing.T) { tests := []struct { - expr string - expected string - shortName string - typ types.Type + expr string + expected string + shortName string + expectedTyp types.Type }{ {"2", "2", "", nil}, {"-2", "-2", "", nil}, @@ -31,27 +32,28 @@ func TestParser(t *testing.T) { {"map?>", "map?>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, {"precision_timestamp<5>", "precision_timestamp<5>", "prets", &types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}, {"precision_timestamp_tz<5>", "precision_timestamp_tz<5>", "pretstz", &types.PrecisionTimestampTzType{PrecisionTimestampType: types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}}, - {"varchar", "varchar", "vchar", &types.ParameterizedVarCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"fixedchar", "fixedchar", "fchar", &types.ParameterizedFixedCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"fixedbinary", 
"fixedbinary", "fbin", &types.ParameterizedFixedBinaryType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"precision_timestamp", "precision_timestamp", "prets", &types.ParameterizedPrecisionTimestampType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"decimal", "decimal", "dec", &types.ParameterizedDecimalType{Precision: types.IntegerParam{Name: "P"}, Scale: types.IntegerParam{Name: "S"}}}, + {"varchar", "varchar", "vchar", types.ParameterizedVarCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"fixedchar", "fixedchar", "fchar", types.ParameterizedFixedCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"fixedbinary", "fixedbinary", "fbin", types.ParameterizedFixedBinaryType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"precision_timestamp", "precision_timestamp", "prets", types.ParameterizedPrecisionTimestampType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", types.ParameterizedPrecisionTimestampTzType{IntegerOption: types.IntegerParam{Name: "L1"}}}, + {"decimal", "decimal", "dec", types.ParameterizedDecimalType{Precision: types.IntegerParam{Name: "P"}, Scale: types.IntegerParam{Name: "S"}}}, } p, err := parser.New() require.NoError(t, err) - for _, tt := range tests { - t.Run(tt.expr, func(t *testing.T) { - d, err := p.ParseString(tt.expr) + for _, td := range tests { + t.Run(td.expr, func(t *testing.T) { + d, err := p.ParseString(td.expr) assert.NoError(t, err) - assert.Equal(t, tt.expected, d.Expr.String()) - if tt.shortName != "" { - assert.Equal(t, tt.shortName, d.Expr.(*parser.Type).ShortType()) + assert.Equal(t, td.expected, d.Expr.String()) + if td.shortName != "" { + assert.Equal(t, td.shortName, d.Expr.(*parser.Type).ShortType()) typ, err := d.Expr.(*parser.Type).Type() assert.NoError(t, err) - 
assert.True(t, tt.typ.Equals(typ)) + assert.Equal(t, reflect.TypeOf(td.expectedTyp), reflect.TypeOf(typ)) + assert.True(t, td.expectedTyp.Equals(typ)) } }) } diff --git a/types/types.go b/types/types.go index 7e2b8b0..6640efc 100644 --- a/types/types.go +++ b/types/types.go @@ -697,6 +697,7 @@ func (s *FixedLenType[T]) WithLength(length int32) FixedType { return &out } +// DecimalType is a decimal type with concrete precision and scale parameters, e.g. Decimal(10, 2). type DecimalType struct { Nullability Nullability TypeVariationRef uint32 From 2a81d6069a0a98afdba8011f4750aed1f057e9fd Mon Sep 17 00:00:00 2001 From: Anshul Data Date: Tue, 3 Sep 2024 13:21:33 +0530 Subject: [PATCH 3/6] Address review comments --- types/any_type.go | 33 +++++------------------------ types/any_type_test.go | 1 - types/parameterized_decimal_type.go | 18 ++-------------- types/parameterized_types.go | 19 ++++------------- types/parameterized_types_test.go | 5 +---- types/parser/type_parser.go | 7 +++++- types/parser/type_parser_test.go | 10 +++++---- types/types.go | 13 ++++++++---- 8 files changed, 33 insertions(+), 73 deletions(-) diff --git a/types/any_type.go b/types/any_type.go index 2fbf352..43aad3b 100644 --- a/types/any_type.go +++ b/types/any_type.go @@ -2,15 +2,14 @@ package types import ( "fmt" - - "github.com/substrait-io/substrait-go/proto" ) // AnyType to represent AnyType, this type is to indicate "any" type of argument // This type is not used in function invocation. 
It is only used in function definition type AnyType struct { - Name string - Nullability Nullability + Name string + TypeVariationRef uint32 + Nullability Nullability } func (AnyType) isRootRef() {} @@ -22,37 +21,15 @@ func (m AnyType) GetType() Type { return m } func (m AnyType) GetNullability() Nullability { return m.Nullability } -func (AnyType) GetTypeVariationReference() uint32 { - panic("not allowed") +func (m AnyType) GetTypeVariationReference() uint32 { + return m.TypeVariationRef } func (AnyType) Equals(rhs Type) bool { // equal to every other type return true } -func (AnyType) ToProtoFuncArg() *proto.FunctionArgument { - panic("not allowed") -} - -func (AnyType) ToProto() *proto.Type { - panic("not allowed") -} - func (t AnyType) ShortString() string { return t.Name } func (t AnyType) String() string { return fmt.Sprintf("%s%s", t.Name, strNullable(t)) } - -// Below methods are for parser Def interface - -func (AnyType) Optional() bool { - panic("not allowed") -} - -func (m AnyType) ShortType() string { - return "any" -} - -func (m AnyType) Type() (Type, error) { - return m, nil -} diff --git a/types/any_type_test.go b/types/any_type_test.go index 8362617..53d9edb 100644 --- a/types/any_type_test.go +++ b/types/any_type_test.go @@ -27,7 +27,6 @@ func TestAnyType(t *testing.T) { require.Equal(t, td.expectedString, arg.String()) require.Equal(t, td.nullability, arg.GetNullability()) require.Equal(t, td.argName, arg.ShortString()) - require.Equal(t, "any", arg.ShortType()) }) } } diff --git a/types/parameterized_decimal_type.go b/types/parameterized_decimal_type.go index d6946e6..a445fe1 100644 --- a/types/parameterized_decimal_type.go +++ b/types/parameterized_decimal_type.go @@ -2,8 +2,6 @@ package types import ( "fmt" - - "github.com/substrait-io/substrait-go/proto" ) // ParameterizedDecimalType is a decimal type with precision and scale parameters of string type @@ -34,25 +32,13 @@ func (m ParameterizedDecimalType) Equals(rhs Type) bool { return false } 
-func (ParameterizedDecimalType) ToProtoFuncArg() *proto.FunctionArgument { - // parameterized type are never on wire so to proto is not supported - panic("not supported") -} - func (m ParameterizedDecimalType) ShortString() string { t := DecimalType{} return t.ShortString() } func (m ParameterizedDecimalType) String() string { - return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString()) -} - -func (m ParameterizedDecimalType) ParameterString() string { - return fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) -} - -func (m ParameterizedDecimalType) BaseString() string { t := DecimalType{} - return t.BaseString() + parameterString := fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) + return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) } diff --git a/types/parameterized_types.go b/types/parameterized_types.go index 5ebf146..21b0993 100644 --- a/types/parameterized_types.go +++ b/types/parameterized_types.go @@ -2,8 +2,6 @@ package types import ( "fmt" - - "github.com/substrait-io/substrait-go/proto" ) // IntegerParam represents a single integer parameter for a parameterized type @@ -16,10 +14,6 @@ func (m IntegerParam) Equals(o IntegerParam) bool { return m == o } -func (p IntegerParam) ToProto() *proto.ParameterizedType_IntegerParameter { - panic("not implemented") -} - func (m IntegerParam) String() string { return m.Name } @@ -31,7 +25,7 @@ type ParameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBi IntegerOption IntegerParam } -func (m ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) ParameterizedSingleIntegerType { +func (m ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) Type { m.IntegerOption = integerOption return m } @@ -54,11 +48,6 @@ func (m ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool { return false } -func (ParameterizedTypeSingleIntegerParam[T]) ToProtoFuncArg() 
*proto.FunctionArgument { - // parameterized type are never on wire so to proto is not supported - panic("not supported") -} - func (m ParameterizedTypeSingleIntegerParam[T]) ShortString() string { switch any(m).(type) { case ParameterizedVarCharType: @@ -82,14 +71,14 @@ func (m ParameterizedTypeSingleIntegerParam[T]) ShortString() string { } func (m ParameterizedTypeSingleIntegerParam[T]) String() string { - return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString()) + return fmt.Sprintf("%s%s%s", m.baseString(), strNullable(m), m.parameterString()) } -func (m ParameterizedTypeSingleIntegerParam[T]) ParameterString() string { +func (m ParameterizedTypeSingleIntegerParam[T]) parameterString() string { return fmt.Sprintf("<%s>", m.IntegerOption.String()) } -func (m ParameterizedTypeSingleIntegerParam[T]) BaseString() string { +func (m ParameterizedTypeSingleIntegerParam[T]) baseString() string { switch any(m).(type) { case ParameterizedVarCharType: t := VarCharType{} diff --git a/types/parameterized_types_test.go b/types/parameterized_types_test.go index c0d9845..e620e13 100644 --- a/types/parameterized_types_test.go +++ b/types/parameterized_types_test.go @@ -31,9 +31,6 @@ func TestParameterizedVarCharType(t *testing.T) { t.Run(td.name, func(t *testing.T) { pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability) require.Equal(t, td.expectedString, pt.String()) - parameterizeType, ok := pt.(types.ParameterizedType) - require.True(t, ok) - require.Equal(t, td.expectedBaseString, parameterizeType.BaseString()) require.Equal(t, td.expectedShortString, pt.ShortString()) require.True(t, pt.Equals(pt)) }) @@ -58,7 +55,7 @@ func TestParameterizedDecimalType(t *testing.T) { scale := types.IntegerParam{Name: td.scale} pt := types.ParameterizedDecimalType{Precision: precision, Scale: scale, Nullability: td.nullability} require.Equal(t, td.expectedString, pt.String()) - require.Equal(t, td.expectedBaseString, pt.BaseString()) + 
//require.Equal(t, td.expectedBaseString, pt.BaseString()) require.Equal(t, td.expectedShortString, pt.ShortString()) require.True(t, pt.Equals(pt)) }) diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index a338e2b..5ac099a 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -287,6 +287,11 @@ func (d *decimalType) Optional() bool { return d.Nullability } func (d *decimalType) Type() (types.Type, error) { var n types.Nullability + if d.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } pi, ok1 := d.Precision.Expr.(*IntegerLiteral) si, ok2 := d.Scale.Expr.(*IntegerLiteral) if ok1 && ok2 { @@ -420,7 +425,7 @@ type anyType struct { Nullability bool `parser:"@'?'?"` } -func (anyType) Optional() bool { return false } +func (t anyType) Optional() bool { return t.Nullability } func (t anyType) String() string { opt := string(t.TypeName) diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index 9cc6122..ac66753 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -24,9 +24,9 @@ func TestParser(t *testing.T) { {"i16?", "i16?", "i16", &types.Int16Type{Nullability: types.NullabilityNullable}}, {"boolean", "boolean", "bool", &types.BooleanType{Nullability: types.NullabilityRequired}}, {"fixedchar<5>", "fixedchar<5>", "fchar", &types.FixedCharType{Length: 5}}, - {"decimal<10,5>", "decimal<10, 5>", "dec", &types.DecimalType{Precision: 10, Scale: 5}}, - {"list>", "list>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5}, Nullability: types.NullabilityRequired}}, - {"list?>", "list?>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5}, Nullability: types.NullabilityNullable}}, + {"decimal<10,5>", "decimal<10, 5>", "dec", &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}}, + {"list>", "list>", "list", &types.ListType{Type: 
&types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"list?>", "list?>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, {"struct", "struct", "struct", &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, {"map>", "map>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, {"map?>", "map?>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, @@ -37,7 +37,9 @@ func TestParser(t *testing.T) { {"fixedbinary", "fixedbinary", "fbin", types.ParameterizedFixedBinaryType{IntegerOption: types.IntegerParam{Name: "L1"}}}, {"precision_timestamp", "precision_timestamp", "prets", types.ParameterizedPrecisionTimestampType{IntegerOption: types.IntegerParam{Name: "L1"}}}, {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", types.ParameterizedPrecisionTimestampTzType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"decimal", "decimal", "dec", types.ParameterizedDecimalType{Precision: types.IntegerParam{Name: "P"}, Scale: types.IntegerParam{Name: "S"}}}, + {"decimal", "decimal", 
"dec", types.ParameterizedDecimalType{Precision: types.IntegerParam{Name: "P"}, Scale: types.IntegerParam{Name: "S"}, Nullability: types.NullabilityRequired}}, + {"any", "any", "any", types.AnyType{Nullability: types.NullabilityRequired}}, + {"any1?", "any1?", "any", types.AnyType{Nullability: types.NullabilityNullable}}, } p, err := parser.New() diff --git a/types/types.go b/types/types.go index 6640efc..278d9a9 100644 --- a/types/types.go +++ b/types/types.go @@ -185,7 +185,7 @@ type ( // TypeFromProto returns the appropriate Type object from a protobuf // type message. -func TypeFromProto(t *proto.Type) Type { +func TypeFromProto(t *proto.Type) FuncArgType { switch t := t.Kind.(type) { case *proto.Type_Bool: return &BooleanType{ @@ -355,10 +355,14 @@ type ( fmt.Stringer } + // FuncArgType this represents a type which can be a function argument + FuncArgType interface { + FuncArg + Type + } // Type corresponds to the proto.Type message and represents // a specific type. Type interface { - FuncArg isRootRef() fmt.Stringer ShortString() string @@ -371,6 +375,7 @@ type ( WithNullability(Nullability) Type } + // ParameterizedType this representa a concrete type with parameters ParameterizedType interface { Type ParameterString() string @@ -383,8 +388,8 @@ type ( } ParameterizedSingleIntegerType interface { - ParameterizedType - WithIntegerOption(param IntegerParam) ParameterizedSingleIntegerType + Type + WithIntegerOption(param IntegerParam) Type } ) From a6ef7c8e3da9c1a8e9848a046f564b1c73f53ea1 Mon Sep 17 00:00:00 2001 From: Anshul Data Date: Tue, 10 Sep 2024 10:50:40 +0530 Subject: [PATCH 4/6] Address review comments --- expr/literals.go | 25 ++--- expr/string_test.go | 2 +- extensions/simple_extension.go | 16 +-- extensions/variants.go | 55 ++++++++++ extensions/variants_test.go | 41 +++++++ functions/types.go | 15 +-- literal/utils_test.go | 7 +- types/any_type_test.go | 10 +- .../abstract_parameter_type.go | 23 ++++ .../concrete_parameter_type.go | 23 ++++ 
types/parameter_types/leaf_parameter_type.go | 18 ++++ .../leaf_parameter_type_test.go | 34 ++++++ types/parameterized_decimal_type.go | 50 ++++++--- types/parameterized_decimal_type_test.go | 59 ++++++++++ types/parameterized_list_type.go | 59 ++++++++++ types/parameterized_list_type_test.go | 41 +++++++ types/parameterized_map_type.go | 65 +++++++++++ types/parameterized_map_type_test.go | 46 ++++++++ ...parameterized_single_integer_param_type.go | 101 ++++++++++++++++++ ...eterized_single_integer_param_type_test.go | 46 ++++++++ types/parameterized_struct_type.go | 80 ++++++++++++++ types/parameterized_struct_type_test.go | 45 ++++++++ types/parameterized_types.go | 101 ------------------ types/parameterized_types_test.go | 63 ----------- types/parser/type_parser.go | 81 +++++++++++--- types/parser/type_parser_test.go | 27 +++-- types/types.go | 85 ++++++++++----- types/types_test.go | 2 +- 28 files changed, 949 insertions(+), 271 deletions(-) create mode 100644 types/parameter_types/abstract_parameter_type.go create mode 100644 types/parameter_types/concrete_parameter_type.go create mode 100644 types/parameter_types/leaf_parameter_type.go create mode 100644 types/parameter_types/leaf_parameter_type_test.go create mode 100644 types/parameterized_decimal_type_test.go create mode 100644 types/parameterized_list_type.go create mode 100644 types/parameterized_list_type_test.go create mode 100644 types/parameterized_map_type.go create mode 100644 types/parameterized_map_type_test.go create mode 100644 types/parameterized_single_integer_param_type.go create mode 100644 types/parameterized_single_integer_param_type_test.go create mode 100644 types/parameterized_struct_type.go create mode 100644 types/parameterized_struct_type_test.go delete mode 100644 types/parameterized_types.go delete mode 100644 types/parameterized_types_test.go diff --git a/expr/literals.go b/expr/literals.go index dc8863e..bb45a79 100644 --- a/expr/literals.go +++ b/expr/literals.go @@ -12,6 
+12,7 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/proto" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" "golang.org/x/exp/slices" "google.golang.org/protobuf/types/known/anypb" ) @@ -454,8 +455,8 @@ func (t *ProtoLiteral) ToProtoLiteral() *proto.Expression_Literal { lit.LiteralType = &proto.Expression_Literal_Decimal_{ Decimal: &proto.Expression_Literal_Decimal{ Value: v, - Precision: literalType.Precision, - Scale: literalType.Scale, + Precision: literalType.Precision.ToProtoVal(), + Scale: literalType.Scale.ToProtoVal(), }, } case *types.PrecisionTimestampType: @@ -527,7 +528,7 @@ func NewFixedCharLiteral(val types.FixedChar, nullable bool) *PrimitiveLiteral[t Value: val, Type: &types.FixedCharType{ Nullability: getNullability(nullable), - Length: int32(len(val)), + Length: parameter_types.LeafIntParamConcreteType(len(val)), }, } } @@ -614,7 +615,7 @@ func NewFixedBinaryLiteral(val types.FixedBinary, nullable bool) *ByteSliceLiter return &ByteSliceLiteral[types.FixedBinary]{ Value: val, Type: &types.FixedLenType[types.FixedBinary]{ - Length: int32(len(val)), + Length: parameter_types.LeafIntParamConcreteType(len(val)), Nullability: getNullability(nullable), }, } @@ -686,8 +687,8 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) { Value: v.Value, Type: &types.DecimalType{ Nullability: getNullability(nullable), - Precision: v.Precision, - Scale: v.Scale, + Precision: parameter_types.LeafIntParamConcreteType(v.Precision), + Scale: parameter_types.LeafIntParamConcreteType(v.Scale), }, }, nil case *types.UserDefinedLiteral: @@ -709,7 +710,7 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) { Value: v.Value, Type: &types.VarCharType{ Nullability: getNullability(nullable), - Length: int32(v.Length), + Length: parameter_types.LeafIntParamConcreteType(v.Length), }, }, nil case *types.PrecisionTimestamp: 
@@ -823,7 +824,7 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &PrimitiveLiteral[types.FixedChar]{ Value: types.FixedChar(lit.FixedChar), Type: &types.FixedCharType{ - Length: int32(len(lit.FixedChar)), + Length: parameter_types.LeafIntParamConcreteType(len(lit.FixedChar)), TypeVariationRef: l.TypeVariationReference, Nullability: nullability, }} @@ -831,7 +832,7 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &ProtoLiteral{ Value: lit.VarChar.Value, Type: &types.VarCharType{ - Length: int32(lit.VarChar.Length), + Length: parameter_types.LeafIntParamConcreteType(lit.VarChar.Length), Nullability: nullability, TypeVariationRef: l.TypeVariationReference, }, @@ -840,7 +841,7 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &ByteSliceLiteral[types.FixedBinary]{ Value: lit.FixedBinary, Type: &types.FixedBinaryType{ - Length: int32(len(lit.FixedBinary)), + Length: parameter_types.LeafIntParamConcreteType(len(lit.FixedBinary)), TypeVariationRef: l.TypeVariationReference, Nullability: nullability, }} @@ -848,8 +849,8 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &ProtoLiteral{ Value: lit.Decimal.Value, Type: &types.DecimalType{ - Scale: lit.Decimal.Scale, - Precision: lit.Decimal.Precision, + Scale: parameter_types.LeafIntParamConcreteType(lit.Decimal.Scale), + Precision: parameter_types.LeafIntParamConcreteType(lit.Decimal.Precision), Nullability: nullability, TypeVariationRef: l.TypeVariationReference, }, diff --git a/expr/string_test.go b/expr/string_test.go index 8f1e63b..3fe2acf 100644 --- a/expr/string_test.go +++ b/expr/string_test.go @@ -35,7 +35,7 @@ func TestLiteralToString(t *testing.T) { Value: expr.NewFixedCharLiteral(types.FixedChar("bar"), false), }, }, true), - }, true), "list?>>([map?>([{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}])])"}, + }, true), "list?>>([map?>([{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}])])"}, 
{MustLiteral(expr.NewLiteral(float32(1.5), false)), "fp32(1.5)"}, {MustLiteral(expr.NewLiteral(&types.VarChar{Value: "foobar", Length: 7}, true)), "varchar?<7>(foobar)"}, {expr.NewPrecisionTimestampLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "precisiontimestamp?<0>(123456)"}, diff --git a/extensions/simple_extension.go b/extensions/simple_extension.go index d629c1e..f055d73 100644 --- a/extensions/simple_extension.go +++ b/extensions/simple_extension.go @@ -3,13 +3,11 @@ package extensions import ( - "errors" "fmt" "reflect" "strings" substraitgo "github.com/substrait-io/substrait-go" - "github.com/substrait-io/substrait-go/types" "github.com/substrait-io/substrait-go/types/parser" ) @@ -59,7 +57,7 @@ type TypeVariation struct { type Argument interface { toTypeString() string - ArgType() (types.Type, error) + marker() // unexported marker method } type EnumArg struct { @@ -72,9 +70,7 @@ func (EnumArg) toTypeString() string { return "req" } -func (EnumArg) ArgType() (types.Type, error) { - return nil, errors.New("unimplemented") -} +func (v EnumArg) marker() {} type ValueArg struct { Name string `yaml:",omitempty"` @@ -87,9 +83,7 @@ func (v ValueArg) toTypeString() string { return v.Value.Expr.(*parser.Type).ShortType() } -func (v ValueArg) ArgType() (types.Type, error) { - return v.Value.Expr.(*parser.Type).Type() -} +func (v ValueArg) marker() {} type TypeArg struct { Name string `yaml:",omitempty"` @@ -99,9 +93,7 @@ type TypeArg struct { func (TypeArg) toTypeString() string { return "type" } -func (TypeArg) ArgType() (types.Type, error) { - return nil, errors.New("unimplemented") -} +func (v TypeArg) marker() {} type ArgumentList []Argument diff --git a/extensions/variants.go b/extensions/variants.go index 5619be8..2a14d54 100644 --- a/extensions/variants.go +++ b/extensions/variants.go @@ -8,6 +8,7 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" + 
"github.com/substrait-io/substrait-go/types/parameter_types" "github.com/substrait-io/substrait-go/types/parser" ) @@ -375,3 +376,57 @@ func (s *WindowFunctionVariant) Intermediate() (types.Type, error) { func (s *WindowFunctionVariant) Ordered() bool { return s.impl.Ordered } func (s *WindowFunctionVariant) MaxSet() int { return s.impl.MaxSet } func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowType } + +// HasSyncParams This API returns if params share a leaf param name +func HasSyncParams(params []types.Type) bool { + // get list of parameters from Abstract parameter type + // if any of the parameter is common, it indicates parameters are same across parameters + existingParamMap := make(map[string]bool) + for _, p := range params { + pat, ok := p.(types.ParameterizedAbstractType) + if !ok { + // not a type which contains abstract parameters, so continue + continue + } + // get list of parameters for each abstract parameter type + // note, this can be more than one parameter because of nested abstract types + // e.g. Decimal or List, VARCHAR>> + abstractParams := pat.GetAbstractParameters() + var leafParams []string + for _, abstractParam := range abstractParams { + leafParams = append(leafParams, getLeafAbstractParams(abstractParam)...) + } + // all leaf params for this parameters are found + // if map contains any of the leaf params, parameters are synced + for _, leafParam := range leafParams { + if _, ok := existingParamMap[leafParam]; ok { + return true + } + } + // add all params to map, kindly note we can't add these params + // in previous loop to avoid having same leaf abstract type in same param + // e.g. 
Decimal has no sync param + for _, leafParam := range leafParams { + existingParamMap[leafParam] = true + } + } + return false +} + +// from a parameter of abstract type, get the leaf parameters +// an abstract parameter can be a leaf type or a parameterized type itself +// if it is a leaf type, its param name is returned +// if it is parameterized type, leaf type is found recursively +func getLeafAbstractParams(abstractTypes parameter_types.AbstractParameterType) []string { + // if it is not a leaf type recurse + if pat, ok := abstractTypes.(types.ParameterizedAbstractType); ok { + var outLeafParams []string + for _, p := range pat.GetAbstractParameters() { + childLeafParams := getLeafAbstractParams(p) + outLeafParams = append(outLeafParams, childLeafParams...) + } + return outLeafParams + } + // for leaf type, return the param name + return []string{abstractTypes.GetAbstractParamName()} +} diff --git a/extensions/variants_test.go b/extensions/variants_test.go index dee19e2..9477b47 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -6,8 +6,10 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/extensions" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" "github.com/substrait-io/substrait-go/types/parser" ) @@ -65,3 +67,42 @@ func TestEvaluateTypeExpression(t *testing.T) { }) } } + +func TestHasSyncParams(t *testing.T) { + + apt_P := parameter_types.LeafIntParamAbstractType("P") + apt_Q := parameter_types.LeafIntParamAbstractType("Q") + cpt_38 := parameter_types.LeafIntParamConcreteType(38) + + fct_P := &types.ParameterizedFixedCharType{IntegerOption: apt_P} + fct_Q := &types.ParameterizedFixedCharType{IntegerOption: apt_Q} + decimal_PQ := &types.ParameterizedDecimalType{Precision: apt_P, Scale: apt_Q} + decimal_38_Q := &types.ParameterizedDecimalType{Precision: cpt_38, Scale: apt_Q} + 
list_decimal_38_Q := &types.ParameterizedListType{Type: decimal_38_Q} + map_fctQ_decimal38Q := &types.ParameterizedMapType{Key: fct_Q, Value: decimal_38_Q} + struct_fctQ_ListDecimal38Q := &types.ParameterizedStructType{Type: []types.Type{fct_Q, list_decimal_38_Q}} + for _, td := range []struct { + name string + params []types.Type + expectedHasSyncParams bool + }{ + {"No Abstract Type", []types.Type{&types.Int64Type{}}, false}, + {"No Sync Param P, Q", []types.Type{fct_P, fct_Q}, false}, + {"Sync Params P, P", []types.Type{fct_P, fct_P}, true}, + {"Sync Params P, ", []types.Type{fct_P, decimal_PQ}, true}, + {"No Sync Params P, <38, Q>", []types.Type{fct_P, decimal_38_Q}, false}, + {"Sync Params P, List>", []types.Type{fct_P, list_decimal_38_Q}, false}, + {"No Sync Params fct

<P>, Map<fct<Q>, decimal<38,Q>>", []types.Type{fct_P, map_fctQ_decimal38Q}, false}, + {"Sync Params fct<Q>, Map<fct<Q>, decimal<38,Q>>", []types.Type{fct_Q, map_fctQ_decimal38Q}, true}, + {"No Sync Params fct<P>

, struct, list<38,Q>>", []types.Type{fct_P, struct_fctQ_ListDecimal38Q}, false}, + {"Sync Params fct, struct, list<38,Q>>", []types.Type{fct_Q, struct_fctQ_ListDecimal38Q}, true}, + } { + t.Run(td.name, func(t *testing.T) { + if td.expectedHasSyncParams { + require.True(t, extensions.HasSyncParams(td.params)) + } else { + require.False(t, extensions.HasSyncParams(td.params)) + } + }) + } +} diff --git a/functions/types.go b/functions/types.go index e61e289..c4a6266 100644 --- a/functions/types.go +++ b/functions/types.go @@ -6,6 +6,7 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" ) var ( @@ -83,18 +84,18 @@ func getTypeWithParameters(typ types.Type, parameters []int32) (types.Type, erro if len(parameters) != 2 { return nil, substraitgo.ErrInvalidType } - return &types.DecimalType{Precision: parameters[0], Scale: parameters[1]}, nil + return &types.DecimalType{Precision: parameter_types.LeafIntParamConcreteType(parameters[0]), Scale: parameter_types.LeafIntParamConcreteType(parameters[1])}, nil case *types.FixedBinaryType, *types.FixedCharType, *types.VarCharType: if len(parameters) != 1 { return nil, substraitgo.ErrInvalidType } switch typ.(type) { case *types.FixedBinaryType: - return &types.FixedBinaryType{Length: parameters[0]}, nil + return &types.FixedBinaryType{Length: parameter_types.LeafIntParamConcreteType(parameters[0])}, nil case *types.FixedCharType: - return &types.FixedCharType{Length: parameters[0]}, nil + return &types.FixedCharType{Length: parameter_types.LeafIntParamConcreteType(parameters[0])}, nil case *types.VarCharType: - return &types.VarCharType{Length: parameters[0]}, nil + return &types.VarCharType{Length: parameter_types.LeafIntParamConcreteType(parameters[0])}, nil } default: if len(parameters) != 0 { @@ -147,14 +148,14 @@ type typeInfo struct { func (ti *typeInfo) getLongName() string { switch ti.typ.(type) { - 
case types.ParameterizedType: - return ti.typ.(types.ParameterizedType).BaseString() + case types.ParameterizedConcreteType: + return ti.typ.(types.ParameterizedConcreteType).BaseString() } return ti.typ.String() } func (ti *typeInfo) getLocalTypeString(input types.Type, enclosure typeEnclosure) string { - if paramType, ok := input.(types.ParameterizedType); ok { + if paramType, ok := input.(types.ParameterizedConcreteType); ok { return ti.localName + enclosure.containerStart() + paramType.ParameterString() + enclosure.containerEnd() } return ti.localName diff --git a/literal/utils_test.go b/literal/utils_test.go index c6848c7..9a3fefc 100644 --- a/literal/utils_test.go +++ b/literal/utils_test.go @@ -12,6 +12,7 @@ import ( "github.com/substrait-io/substrait-go/expr" "github.com/substrait-io/substrait-go/proto" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" ) func TestNewBool(t *testing.T) { @@ -89,8 +90,8 @@ func createDecimalLiteral(value []byte, precision int32, scale int32, isNullable Value: value[:16], Type: &types.DecimalType{ Nullability: nullability, - Precision: precision, - Scale: scale, + Precision: parameter_types.LeafIntParamConcreteType(precision), + Scale: parameter_types.LeafIntParamConcreteType(scale), }, } } @@ -674,7 +675,7 @@ func createVarCharLiteral(value string) *expr.ProtoLiteral { Value: value, Type: &types.VarCharType{ Nullability: proto.Type_NULLABILITY_REQUIRED, - Length: int32(len(value)), + Length: parameter_types.LeafIntParamConcreteType(len(value)), }, } } diff --git a/types/any_type_test.go b/types/any_type_test.go index 53d9edb..dec44cb 100644 --- a/types/any_type_test.go +++ b/types/any_type_test.go @@ -24,9 +24,13 @@ func TestAnyType(t *testing.T) { Name: td.argName, Nullability: td.nullability, } - require.Equal(t, td.expectedString, arg.String()) - require.Equal(t, td.nullability, arg.GetNullability()) - require.Equal(t, td.argName, arg.ShortString()) + anyType := 
arg.WithNullability(td.nullability) + require.Equal(t, td.expectedString, anyType.String()) + require.Equal(t, td.nullability, anyType.GetNullability()) + require.Equal(t, td.argName, anyType.ShortString()) + // any type should be equal to any other type including itself + require.True(t, anyType.Equals(anyType)) + require.True(t, anyType.Equals(&types.Int8Type{})) }) } } diff --git a/types/parameter_types/abstract_parameter_type.go b/types/parameter_types/abstract_parameter_type.go new file mode 100644 index 0000000..3e056c5 --- /dev/null +++ b/types/parameter_types/abstract_parameter_type.go @@ -0,0 +1,23 @@ +package parameter_types + +// LeafIntParamAbstractType represents an integer parameter for a parameterized type +// Example: VARCHAR(L1) -> L1 is an LeafIntParamAbstractType +// DECIMAL --> P Is an LeafIntParamAbstractType +type LeafIntParamAbstractType string + +func (m LeafIntParamAbstractType) IsCompatible(o LeafParameter) bool { + switch o.(type) { + case LeafIntParamAbstractType, LeafIntParamConcreteType: + return true + default: + return false + } +} + +func (m LeafIntParamAbstractType) String() string { + return string(m) +} + +func (m LeafIntParamAbstractType) GetAbstractParamName() string { + return string(m) +} diff --git a/types/parameter_types/concrete_parameter_type.go b/types/parameter_types/concrete_parameter_type.go new file mode 100644 index 0000000..9ea6603 --- /dev/null +++ b/types/parameter_types/concrete_parameter_type.go @@ -0,0 +1,23 @@ +package parameter_types + +import "fmt" + +// LeafIntParamConcreteType represents a single integer concrete parameter for a concrete type +// Example: VARCHAR(6) -> 6 is an LeafIntParamConcreteType +// DECIMAL --> 0 Is an LeafIntParamConcreteType but P not +type LeafIntParamConcreteType int32 + +func (m LeafIntParamConcreteType) IsCompatible(o LeafParameter) bool { + if t, ok := o.(LeafIntParamConcreteType); ok { + return t == m + } + return false +} + +func (m LeafIntParamConcreteType) String() string 
{ + return fmt.Sprintf("%d", m) +} + +func (m LeafIntParamConcreteType) ToProtoVal() int32 { + return int32(m) +} diff --git a/types/parameter_types/leaf_parameter_type.go b/types/parameter_types/leaf_parameter_type.go new file mode 100644 index 0000000..d6e4504 --- /dev/null +++ b/types/parameter_types/leaf_parameter_type.go @@ -0,0 +1,18 @@ +package parameter_types + +// LeafParameter represents a parameter type +// parameter can of concrete (38) or abstract type (P) +// or another parameterized type like VARCHAR<"L1"> +type LeafParameter interface { + // IsCompatible is type compatible with other + // compatible is other can be used in place of this type + IsCompatible(other LeafParameter) bool + String() string +} + +// AbstractParameterType represents a parameter type which is abstract. +// it can be a leaf parameter (LeafIntParamAbstractType) +// or another abstract type like "DECIMAL" +type AbstractParameterType interface { + GetAbstractParamName() string +} diff --git a/types/parameter_types/leaf_parameter_type_test.go b/types/parameter_types/leaf_parameter_type_test.go new file mode 100644 index 0000000..4ef799c --- /dev/null +++ b/types/parameter_types/leaf_parameter_type_test.go @@ -0,0 +1,34 @@ +package parameter_types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +func TestConcreteParameterType(t *testing.T) { + concreteType1 := parameter_types.LeafIntParamConcreteType(1) + require.Equal(t, "1", concreteType1.String()) +} + +func TestLeafParameterType(t *testing.T) { + var concreteType1, concreteType2, abstractType1 parameter_types.LeafParameter + + concreteType1 = parameter_types.LeafIntParamConcreteType(1) + concreteType2 = parameter_types.LeafIntParamConcreteType(2) + + abstractType1 = parameter_types.LeafIntParamAbstractType("P") + + // verify string val + require.Equal(t, "1", concreteType1.String()) + require.Equal(t, "P", abstractType1.String()) + + // 
concrete type is only compatible with same type + require.True(t, concreteType1.IsCompatible(concreteType1)) + require.False(t, concreteType1.IsCompatible(concreteType2)) + + // abstract type is compatible with both abstract and concrete type + require.True(t, abstractType1.IsCompatible(abstractType1)) + require.True(t, abstractType1.IsCompatible(concreteType2)) +} diff --git a/types/parameterized_decimal_type.go b/types/parameterized_decimal_type.go index a445fe1..05897e4 100644 --- a/types/parameterized_decimal_type.go +++ b/types/parameterized_decimal_type.go @@ -2,43 +2,65 @@ package types import ( "fmt" + + "github.com/substrait-io/substrait-go/types/parameter_types" ) -// ParameterizedDecimalType is a decimal type with precision and scale parameters of string type -// example: Decimal(P,S). Kindly note concrete types Decimal(10, 2) are not represented by this type +// ParameterizedDecimalType is a decimal type with at least one of precision and scale parameters of string type +// example: Decimal or Decimal. +// Note concrete types e.g. 
Decimal(10, 2) are not represented by this type // Concrete type is represented by DecimalType type ParameterizedDecimalType struct { Nullability Nullability TypeVariationRef uint32 - Precision IntegerParam - Scale IntegerParam + Precision parameter_types.LeafParameter + Scale parameter_types.LeafParameter } -func (ParameterizedDecimalType) isRootRef() {} -func (m ParameterizedDecimalType) WithNullability(n Nullability) Type { +func (*ParameterizedDecimalType) isRootRef() {} +func (m *ParameterizedDecimalType) WithNullability(n Nullability) Type { m.Nullability = n return m } -func (m ParameterizedDecimalType) GetType() Type { return m } -func (m ParameterizedDecimalType) GetNullability() Nullability { return m.Nullability } -func (m ParameterizedDecimalType) GetTypeVariationReference() uint32 { +func (m *ParameterizedDecimalType) GetType() Type { return m } +func (m *ParameterizedDecimalType) GetNullability() Nullability { return m.Nullability } +func (m *ParameterizedDecimalType) GetTypeVariationReference() uint32 { return m.TypeVariationRef } -func (m ParameterizedDecimalType) Equals(rhs Type) bool { - if o, ok := rhs.(ParameterizedDecimalType); ok { - return o == m +func (m *ParameterizedDecimalType) Equals(rhs Type) bool { + if o, ok := rhs.(*ParameterizedDecimalType); ok { + return *o == *m } return false } -func (m ParameterizedDecimalType) ShortString() string { +func (m *ParameterizedDecimalType) ShortString() string { t := DecimalType{} return t.ShortString() } -func (m ParameterizedDecimalType) String() string { +func (m *ParameterizedDecimalType) String() string { t := DecimalType{} parameterString := fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) } + +// GetAbstractParameters returns the abstract parameter names +// this implements interface ParameterizedAbstractType +func (m *ParameterizedDecimalType) GetAbstractParameters() 
[]parameter_types.AbstractParameterType { + var params []parameter_types.AbstractParameterType + if p, ok := m.Precision.(parameter_types.AbstractParameterType); ok { + params = append(params, p) + } + if p, ok := m.Scale.(parameter_types.AbstractParameterType); ok { + params = append(params, p) + } + return params +} + +// GetAbstractParamName this implements interface AbstractParameterType +// to indicate ParameterizedDecimalType itself can be used as a parameter of abstract type too +func (m *ParameterizedDecimalType) GetAbstractParamName() string { + return m.String() +} diff --git a/types/parameterized_decimal_type_test.go b/types/parameterized_decimal_type_test.go new file mode 100644 index 0000000..87990af --- /dev/null +++ b/types/parameterized_decimal_type_test.go @@ -0,0 +1,59 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +func TestParameterizedDecimalType(t *testing.T) { + for _, td := range []struct { + name string + precision string + scale string + nullability types.Nullability + expectedString string + expectedShortString string + }{ + {"nullable decimal", "P", "S", types.NullabilityNullable, "decimal?", "dec"}, + {"non nullable decimal", "P", "S", types.NullabilityRequired, "decimal", "dec"}, + } { + t.Run(td.name, func(t *testing.T) { + precision := parameter_types.LeafIntParamAbstractType(td.precision) + scale := parameter_types.LeafIntParamAbstractType(td.scale) + pd := &types.ParameterizedDecimalType{Precision: precision, Scale: scale} + pdType := pd.WithNullability(td.nullability) + require.Equal(t, td.expectedString, pdType.String()) + require.Equal(t, td.expectedShortString, pdType.ShortString()) + require.True(t, pdType.Equals(pdType)) + + pdAbsParamType, ok := pdType.(parameter_types.AbstractParameterType) + require.True(t, ok) + require.Equal(t, td.expectedString, 
pdAbsParamType.GetAbstractParamName()) + pdAbstractType, ok := pdType.(types.ParameterizedAbstractType) + require.True(t, ok) + require.Len(t, pdAbstractType.GetAbstractParameters(), 2) + }) + } +} + +func TestParameterizedDecimalSingleAbstractParam(t *testing.T) { + precision := parameter_types.LeafIntParamConcreteType(38) + scale := parameter_types.LeafIntParamAbstractType("S") + + pd := &types.ParameterizedDecimalType{Precision: precision, Scale: scale} + pdType := pd.WithNullability(types.NullabilityNullable) + require.Equal(t, "decimal?<38,S>", pdType.String()) + require.Equal(t, "dec", pdType.ShortString()) + require.True(t, pdType.Equals(pdType)) + + pdAbsParamType, ok := pdType.(parameter_types.AbstractParameterType) + require.True(t, ok) + require.Equal(t, "decimal?<38,S>", pdAbsParamType.GetAbstractParamName()) + pdAbstractType, ok := pdType.(types.ParameterizedAbstractType) + require.True(t, ok) + // only one abstract param + require.Len(t, pdAbstractType.GetAbstractParameters(), 1) +} diff --git a/types/parameterized_list_type.go b/types/parameterized_list_type.go new file mode 100644 index 0000000..10a13b0 --- /dev/null +++ b/types/parameterized_list_type.go @@ -0,0 +1,59 @@ +package types + +import ( + "fmt" + + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +// ParameterizedListType is a list type having parameter of ParameterizedAbstractType +// basically a list of which type is another abstract parameter +// example: List. 
Kindly note concrete types List is not represented by this type +// Concrete type is represented by ListType +type ParameterizedListType struct { + Nullability Nullability + TypeVariationRef uint32 + Type ParameterizedAbstractType +} + +func (*ParameterizedListType) isRootRef() {} +func (m *ParameterizedListType) WithNullability(n Nullability) Type { + m.Nullability = n + return m +} + +func (m *ParameterizedListType) GetType() Type { return m } +func (m *ParameterizedListType) GetNullability() Nullability { return m.Nullability } +func (m *ParameterizedListType) GetTypeVariationReference() uint32 { + return m.TypeVariationRef +} +func (m *ParameterizedListType) Equals(rhs Type) bool { + if o, ok := rhs.(*ParameterizedListType); ok { + return m.Nullability == o.Nullability && m.TypeVariationRef == o.TypeVariationRef && + m.Type.Equals(o.Type) + } + return false +} + +func (m *ParameterizedListType) ShortString() string { + t := ListType{} + return t.ShortString() +} + +func (m *ParameterizedListType) String() string { + t := ListType{} + parameterString := fmt.Sprintf("<%s>", m.Type) + return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) +} + +// GetAbstractParameters returns the abstract parameter names +// this implements interface ParameterizedAbstractType +func (m *ParameterizedListType) GetAbstractParameters() []parameter_types.AbstractParameterType { + return []parameter_types.AbstractParameterType{m.Type.(parameter_types.AbstractParameterType)} +} + +// GetAbstractParamName this implements interface AbstractParameterType +// to indicate ParameterizedListType itself can be used as a parameter of abstract type too +func (m *ParameterizedListType) GetAbstractParamName() string { + return m.String() +} diff --git a/types/parameterized_list_type_test.go b/types/parameterized_list_type_test.go new file mode 100644 index 0000000..fc0a7b5 --- /dev/null +++ b/types/parameterized_list_type_test.go @@ -0,0 +1,41 @@ +package types_test + +import 
( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +func TestParameterizedListType(t *testing.T) { + decimalType := &types.ParameterizedDecimalType{ + Precision: parameter_types.LeafIntParamAbstractType("P"), + Scale: parameter_types.LeafIntParamAbstractType("S"), + Nullability: types.NullabilityRequired, + } + for _, td := range []struct { + name string + typ types.ParameterizedAbstractType + nullability types.Nullability + expectedString string + expectedShortString string + }{ + {"list", decimalType, types.NullabilityNullable, "list?>", "list"}, + } { + t.Run(td.name, func(t *testing.T) { + pl := &types.ParameterizedListType{Type: td.typ} + plType := pl.WithNullability(td.nullability) + require.Equal(t, td.expectedString, plType.String()) + require.Equal(t, td.expectedShortString, plType.ShortString()) + require.True(t, plType.Equals(plType)) + + plAbsParamType, ok := plType.(parameter_types.AbstractParameterType) + require.True(t, ok) + require.Equal(t, td.expectedString, plAbsParamType.GetAbstractParamName()) + plAbstractType, ok := plType.(types.ParameterizedAbstractType) + require.True(t, ok) + require.Len(t, plAbstractType.GetAbstractParameters(), 1) + }) + } +} diff --git a/types/parameterized_map_type.go b/types/parameterized_map_type.go new file mode 100644 index 0000000..93c32fa --- /dev/null +++ b/types/parameterized_map_type.go @@ -0,0 +1,65 @@ +package types + +import ( + "fmt" + + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +// ParameterizedMapType is a struct having at least one of key or value of type ParameterizedAbstractType +// If All arguments are concrete they are represented by MapType +type ParameterizedMapType struct { + Nullability Nullability + TypeVariationRef uint32 + Key Type + Value Type +} + +func (*ParameterizedMapType) isRootRef() {} +func (m *ParameterizedMapType) WithNullability(n 
Nullability) Type { + m.Nullability = n + return m +} + +func (m *ParameterizedMapType) GetType() Type { return m } +func (m *ParameterizedMapType) GetNullability() Nullability { return m.Nullability } +func (m *ParameterizedMapType) GetTypeVariationReference() uint32 { + return m.TypeVariationRef +} +func (m *ParameterizedMapType) Equals(rhs Type) bool { + if o, ok := rhs.(*ParameterizedMapType); ok { + return m.Nullability == o.Nullability && m.TypeVariationRef == o.TypeVariationRef && + m.Key.Equals(o.Key) && m.Value.Equals(o.Value) + } + return false +} + +func (m *ParameterizedMapType) ShortString() string { + t := MapType{} + return t.ShortString() +} + +func (m *ParameterizedMapType) String() string { + t := MapType{} + parameterString := fmt.Sprintf("<%s, %s>", m.Key.String(), m.Value.String()) + return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) +} + +// GetAbstractParameters returns the abstract parameter names +// this implements interface ParameterizedAbstractType +func (m *ParameterizedMapType) GetAbstractParameters() []parameter_types.AbstractParameterType { + var abstractParams []parameter_types.AbstractParameterType + if abs, ok := m.Key.(parameter_types.AbstractParameterType); ok { + abstractParams = append(abstractParams, abs) + } + if abs, ok := m.Value.(parameter_types.AbstractParameterType); ok { + abstractParams = append(abstractParams, abs) + } + return abstractParams +} + +// GetAbstractParamName this implements interface AbstractParameterType +// to indicate ParameterizedStructType itself can be used as a parameter of abstract type too +func (m *ParameterizedMapType) GetAbstractParamName() string { + return m.String() +} diff --git a/types/parameterized_map_type_test.go b/types/parameterized_map_type_test.go new file mode 100644 index 0000000..d882186 --- /dev/null +++ b/types/parameterized_map_type_test.go @@ -0,0 +1,46 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + 
"github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +func TestParameterizedMapType(t *testing.T) { + decimalType := &types.ParameterizedDecimalType{ + Precision: parameter_types.LeafIntParamAbstractType("P"), + Scale: parameter_types.LeafIntParamAbstractType("S"), + Nullability: types.NullabilityRequired, + } + int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} + listType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} + for _, td := range []struct { + name string + Key types.Type + Value types.Type + nullability types.Nullability + expectedString string + expectedShortString string + expectedNrAbstractParam int + }{ + {"single abstract param", decimalType, int8Type, types.NullabilityNullable, "map?, i8?>", "map", 1}, + {"both abstract param", decimalType, listType, types.NullabilityNullable, "map?, list?>>", "map", 2}, + } { + t.Run(td.name, func(t *testing.T) { + pm := &types.ParameterizedMapType{Key: td.Key, Value: td.Value} + pmType := pm.WithNullability(td.nullability) + require.Equal(t, td.expectedString, pmType.String()) + require.Equal(t, td.expectedShortString, pmType.ShortString()) + require.True(t, pmType.Equals(pmType)) + + pmAbsParamType, ok := pmType.(parameter_types.AbstractParameterType) + require.True(t, ok) + require.Equal(t, td.expectedString, pmAbsParamType.GetAbstractParamName()) + pmAbstractType, ok := pmType.(types.ParameterizedAbstractType) + require.True(t, ok) + require.Len(t, pmAbstractType.GetAbstractParameters(), td.expectedNrAbstractParam) + }) + } +} diff --git a/types/parameterized_single_integer_param_type.go b/types/parameterized_single_integer_param_type.go new file mode 100644 index 0000000..a417b1f --- /dev/null +++ b/types/parameterized_single_integer_param_type.go @@ -0,0 +1,101 @@ +package types + +import ( + "fmt" + + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +// 
parameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter +type parameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct { + Nullability Nullability + TypeVariationRef uint32 + IntegerOption parameter_types.LeafIntParamAbstractType +} + +func (m parameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption parameter_types.LeafIntParamAbstractType) Type { + m.IntegerOption = integerOption + return m +} + +func (parameterizedTypeSingleIntegerParam[T]) isRootRef() {} +func (m parameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type { + m.Nullability = n + return m +} + +func (m parameterizedTypeSingleIntegerParam[T]) GetType() Type { return m } +func (m parameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability } +func (m parameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 { + return m.TypeVariationRef +} +func (m parameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool { + if o, ok := rhs.(parameterizedTypeSingleIntegerParam[T]); ok { + return o == m + } + return false +} + +func (m parameterizedTypeSingleIntegerParam[T]) ShortString() string { + switch any(m).(type) { + case ParameterizedVarCharType: + t := VarCharType{} + return t.ShortString() + case ParameterizedFixedCharType: + t := FixedCharType{} + return t.ShortString() + case ParameterizedFixedBinaryType: + t := FixedBinaryType{} + return t.ShortString() + case ParameterizedPrecisionTimestampType: + t := PrecisionTimestampType{} + return t.ShortString() + case ParameterizedPrecisionTimestampTzType: + t := PrecisionTimestampTzType{} + return t.ShortString() + default: + panic("unknown type") + } +} + +func (m parameterizedTypeSingleIntegerParam[T]) String() string { + return fmt.Sprintf("%s%s%s", m.baseString(), strNullable(m), m.parameterString()) +} + +func (m 
parameterizedTypeSingleIntegerParam[T]) parameterString() string { + return fmt.Sprintf("<%s>", m.IntegerOption.GetAbstractParamName()) +} + +func (m parameterizedTypeSingleIntegerParam[T]) baseString() string { + switch any(m).(type) { + case ParameterizedVarCharType: + t := VarCharType{} + return t.BaseString() + case ParameterizedFixedCharType: + t := FixedCharType{} + return t.BaseString() + case ParameterizedFixedBinaryType: + t := FixedBinaryType{} + return t.BaseString() + case ParameterizedPrecisionTimestampType: + t := PrecisionTimestampType{} + return t.BaseString() + case ParameterizedPrecisionTimestampTzType: + t := PrecisionTimestampTzType{} + return t.BaseString() + default: + panic("unknown type") + } +} + +// GetAbstractParameters returns the abstract parameter names +// this implements interface ParameterizedAbstractType +func (m parameterizedTypeSingleIntegerParam[T]) GetAbstractParameters() []parameter_types.AbstractParameterType { + return []parameter_types.AbstractParameterType{m.IntegerOption} +} + +// GetAbstractParamName this implements interface AbstractParameterType +// basically, this type itself can be used as a parameter of abstract type too +func (m parameterizedTypeSingleIntegerParam[T]) GetAbstractParamName() string { + return m.String() +} diff --git a/types/parameterized_single_integer_param_type_test.go b/types/parameterized_single_integer_param_type_test.go new file mode 100644 index 0000000..2412c0c --- /dev/null +++ b/types/parameterized_single_integer_param_type_test.go @@ -0,0 +1,46 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +// a type to indicate all single integer type. 
+// helpful in initializing different single type integer type to the same interface +type parameterizedSingleIntegerType interface { + types.Type + WithIntegerOption(param parameter_types.LeafIntParamAbstractType) types.Type +} + +func TestParameterizedSingleIntegerType(t *testing.T) { + for _, td := range []struct { + name string + typ parameterizedSingleIntegerType + nullability types.Nullability + integerOption parameter_types.LeafIntParamAbstractType + expectedString string + expectedBaseString string + expectedShortString string + }{ + {"nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "varchar?", "varchar", "vchar"}, + {"non nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "varchar", "varchar", "vchar"}, + {"nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "char?", "char", "fchar"}, + {"non nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "char", "char", "fchar"}, + {"nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "fixedbinary?", "fixedbinary", "fbin"}, + {"non nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "fixedbinary", "fixedbinary", "fbin"}, + {"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp?", "precision_timestamp", "prets"}, + {"non nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp", "precision_timestamp", "prets"}, + 
{"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp_tz?", "precision_timestamp_tz", "pretstz"}, + {"non nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp_tz", "precision_timestamp_tz", "pretstz"}, + } { + t.Run(td.name, func(t *testing.T) { + pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability) + require.Equal(t, td.expectedString, pt.String()) + require.Equal(t, td.expectedShortString, pt.ShortString()) + require.True(t, pt.Equals(pt)) + }) + } +} diff --git a/types/parameterized_struct_type.go b/types/parameterized_struct_type.go new file mode 100644 index 0000000..6c66efc --- /dev/null +++ b/types/parameterized_struct_type.go @@ -0,0 +1,80 @@ +package types + +import ( + "fmt" + "strings" + + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +// ParameterizedStructType is a struct having at least one parameter of type ParameterizedAbstractType +// example: Struct. 
+// If All arguments are concrete they are represented by StructType +type ParameterizedStructType struct { + Nullability Nullability + TypeVariationRef uint32 + Type []Type +} + +func (*ParameterizedStructType) isRootRef() {} +func (m *ParameterizedStructType) WithNullability(n Nullability) Type { + m.Nullability = n + return m +} + +func (m *ParameterizedStructType) GetType() Type { return m } +func (m *ParameterizedStructType) GetNullability() Nullability { return m.Nullability } +func (m *ParameterizedStructType) GetTypeVariationReference() uint32 { + return m.TypeVariationRef +} +func (m *ParameterizedStructType) Equals(rhs Type) bool { + if o, ok := rhs.(*ParameterizedStructType); ok { + if m.Nullability != o.Nullability || len(m.Type) != len(o.Type) || + m.TypeVariationRef != o.TypeVariationRef { + return false + } + for i := range m.Type { + if !m.Type[i].Equals(o.Type[i]) { + return false + } + } + return true + } + return false +} + +func (m *ParameterizedStructType) ShortString() string { + t := StructType{} + return t.ShortString() +} + +func (m *ParameterizedStructType) String() string { + sb := strings.Builder{} + for i, typ := range m.Type { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(typ.String()) + } + t := StructType{} + parameterString := fmt.Sprintf("<%s>", sb.String()) + return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) +} + +// GetAbstractParameters returns the abstract parameter names +// this implements interface ParameterizedAbstractType +func (m *ParameterizedStructType) GetAbstractParameters() []parameter_types.AbstractParameterType { + var abstractParams []parameter_types.AbstractParameterType + for _, typ := range m.Type { + if abs, ok := typ.(parameter_types.AbstractParameterType); ok { + abstractParams = append(abstractParams, abs) + } + } + return abstractParams +} + +// GetAbstractParamName this implements interface AbstractParameterType +// to indicate ParameterizedStructType itself can be 
used as a parameter of abstract type too +func (m *ParameterizedStructType) GetAbstractParamName() string { + return m.String() +} diff --git a/types/parameterized_struct_type_test.go b/types/parameterized_struct_type_test.go new file mode 100644 index 0000000..0e70b6f --- /dev/null +++ b/types/parameterized_struct_type_test.go @@ -0,0 +1,45 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" +) + +func TestParameterizedStructType(t *testing.T) { + decimalType := &types.ParameterizedDecimalType{ + Precision: parameter_types.LeafIntParamAbstractType("P"), + Scale: parameter_types.LeafIntParamAbstractType("S"), + Nullability: types.NullabilityRequired, + } + int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} + listType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} + for _, td := range []struct { + name string + types []types.Type + nullability types.Nullability + expectedString string + expectedShortString string + expectedNrAbstractParam int + }{ + {"single abstract param", []types.Type{decimalType}, types.NullabilityNullable, "struct?>", "struct", 1}, + {"multiple abstract param", []types.Type{decimalType, int8Type, listType}, types.NullabilityRequired, "struct, i8?, list?>>", "struct", 2}, + } { + t.Run(td.name, func(t *testing.T) { + ps := &types.ParameterizedStructType{Type: td.types} + psType := ps.WithNullability(td.nullability) + require.Equal(t, td.expectedString, psType.String()) + require.Equal(t, td.expectedShortString, psType.ShortString()) + require.True(t, psType.Equals(psType)) + + psAbsParamType, ok := psType.(parameter_types.AbstractParameterType) + require.True(t, ok) + require.Equal(t, td.expectedString, psAbsParamType.GetAbstractParamName()) + psAbstractType, ok := psType.(types.ParameterizedAbstractType) + require.True(t, ok) + 
require.Len(t, psAbstractType.GetAbstractParameters(), td.expectedNrAbstractParam) + }) + } +} diff --git a/types/parameterized_types.go b/types/parameterized_types.go deleted file mode 100644 index 21b0993..0000000 --- a/types/parameterized_types.go +++ /dev/null @@ -1,101 +0,0 @@ -package types - -import ( - "fmt" -) - -// IntegerParam represents a single integer parameter for a parameterized type -// Example: VARCHAR(L1) -> L1 is the integer parameter -type IntegerParam struct { - Name string -} - -func (m IntegerParam) Equals(o IntegerParam) bool { - return m == o -} - -func (m IntegerParam) String() string { - return m.Name -} - -// ParameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter -type ParameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct { - Nullability Nullability - TypeVariationRef uint32 - IntegerOption IntegerParam -} - -func (m ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) Type { - m.IntegerOption = integerOption - return m -} - -func (ParameterizedTypeSingleIntegerParam[T]) isRootRef() {} -func (m ParameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type { - m.Nullability = n - return m -} - -func (m ParameterizedTypeSingleIntegerParam[T]) GetType() Type { return m } -func (m ParameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability } -func (m ParameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 { - return m.TypeVariationRef -} -func (m ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool { - if o, ok := rhs.(ParameterizedTypeSingleIntegerParam[T]); ok { - return o == m - } - return false -} - -func (m ParameterizedTypeSingleIntegerParam[T]) ShortString() string { - switch any(m).(type) { - case ParameterizedVarCharType: - t := VarCharType{} - return t.ShortString() - 
case ParameterizedFixedCharType: - t := FixedCharType{} - return t.ShortString() - case ParameterizedFixedBinaryType: - t := FixedBinaryType{} - return t.ShortString() - case ParameterizedPrecisionTimestampType: - t := PrecisionTimestampType{} - return t.ShortString() - case ParameterizedPrecisionTimestampTzType: - t := PrecisionTimestampTzType{} - return t.ShortString() - default: - panic("unknown type") - } -} - -func (m ParameterizedTypeSingleIntegerParam[T]) String() string { - return fmt.Sprintf("%s%s%s", m.baseString(), strNullable(m), m.parameterString()) -} - -func (m ParameterizedTypeSingleIntegerParam[T]) parameterString() string { - return fmt.Sprintf("<%s>", m.IntegerOption.String()) -} - -func (m ParameterizedTypeSingleIntegerParam[T]) baseString() string { - switch any(m).(type) { - case ParameterizedVarCharType: - t := VarCharType{} - return t.BaseString() - case ParameterizedFixedCharType: - t := FixedCharType{} - return t.BaseString() - case ParameterizedFixedBinaryType: - t := FixedBinaryType{} - return t.BaseString() - case ParameterizedPrecisionTimestampType: - t := PrecisionTimestampType{} - return t.BaseString() - case ParameterizedPrecisionTimestampTzType: - t := PrecisionTimestampTzType{} - return t.BaseString() - default: - panic("unknown type") - } -} diff --git a/types/parameterized_types_test.go b/types/parameterized_types_test.go deleted file mode 100644 index e620e13..0000000 --- a/types/parameterized_types_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package types_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - "github.com/substrait-io/substrait-go/types" -) - -func TestParameterizedVarCharType(t *testing.T) { - for _, td := range []struct { - name string - typ types.ParameterizedSingleIntegerType - nullability types.Nullability - integerOption types.IntegerParam - expectedString string - expectedBaseString string - expectedShortString string - }{ - {"nullable varchar", &types.ParameterizedVarCharType{}, 
types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "varchar?", "varchar", "vchar"}, - {"non nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "varchar", "varchar", "vchar"}, - {"nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "char?", "char", "fchar"}, - {"non nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "char", "char", "fchar"}, - {"nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "fixedbinary?", "fixedbinary", "fbin"}, - {"non nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "fixedbinary", "fixedbinary", "fbin"}, - {"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp?", "precision_timestamp", "prets"}, - {"non nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp", "precision_timestamp", "prets"}, - {"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz?", "precision_timestamp_tz", "pretstz"}, - {"non nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz", "precision_timestamp_tz", "pretstz"}, - } { - t.Run(td.name, func(t *testing.T) { - pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability) - require.Equal(t, td.expectedString, pt.String()) - require.Equal(t, td.expectedShortString, pt.ShortString()) - require.True(t, pt.Equals(pt)) - }) - } -} - -func 
TestParameterizedDecimalType(t *testing.T) { - for _, td := range []struct { - name string - precision string - scale string - nullability types.Nullability - expectedString string - expectedBaseString string - expectedShortString string - }{ - {"nullable decimal", "P", "S", types.NullabilityNullable, "decimal?", "decimal", "dec"}, - {"non nullable decimal", "P", "S", types.NullabilityRequired, "decimal", "decimal", "dec"}, - } { - t.Run(td.name, func(t *testing.T) { - precision := types.IntegerParam{Name: td.precision} - scale := types.IntegerParam{Name: td.scale} - pt := types.ParameterizedDecimalType{Precision: precision, Scale: scale, Nullability: td.nullability} - require.Equal(t, td.expectedString, pt.String()) - //require.Equal(t, td.expectedBaseString, pt.BaseString()) - require.Equal(t, td.expectedShortString, pt.ShortString()) - require.True(t, pt.Equals(pt)) - }) - } -} diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index 5ac099a..4e5abd0 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -12,6 +12,7 @@ import ( "github.com/alecthomas/participle/v2/lexer" substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" ) var defaultParser *Parser @@ -175,12 +176,17 @@ func (l *listType) Type() (types.Type, error) { if err != nil { return nil, err } + if abstractParam, ok1 := ret.(types.ParameterizedAbstractType); ok1 { + return &types.ParameterizedListType{ + Nullability: n, + Type: abstractParam, + }, nil + } return &types.ListType{ Nullability: n, Type: ret, }, nil } - return nil, substraitgo.ErrNotImplemented } @@ -250,7 +256,7 @@ func getFixedTypeFromConcreteParam(name string, param *IntegerLiteral) (types.Ty } func getParameterizedTypeSingleParam(typeName string, param *ParamName) (types.Type, error) { - intParam := types.IntegerParam{Name: param.Name} + intParam := 
parameter_types.LeafIntParamAbstractType(param.Name) switch types.TypeName(typeName) { case types.TypeNameVarChar: return types.ParameterizedVarCharType{IntegerOption: intParam}, nil @@ -280,7 +286,7 @@ func (d *decimalType) String() string { if d.Nullability { opt = "?" } - return "decimal" + opt + "<" + d.Precision.Expr.String() + ", " + d.Scale.Expr.String() + ">" + return "decimal" + opt + "<" + d.Precision.Expr.String() + "," + d.Scale.Expr.String() + ">" } func (d *decimalType) Optional() bool { return d.Nullability } @@ -292,28 +298,43 @@ func (d *decimalType) Type() (types.Type, error) { } else { n = types.NullabilityRequired } - pi, ok1 := d.Precision.Expr.(*IntegerLiteral) - si, ok2 := d.Scale.Expr.(*IntegerLiteral) - if ok1 && ok2 { + pi, isPrecisionConcrete := d.Precision.Expr.(*IntegerLiteral) + si, isScaleConcrete := d.Scale.Expr.(*IntegerLiteral) + if isPrecisionConcrete && isScaleConcrete { // concrete decimal param return &types.DecimalType{ Nullability: n, - Precision: pi.Value, - Scale: si.Value, + Precision: parameter_types.LeafIntParamConcreteType(pi.Value), + Scale: parameter_types.LeafIntParamConcreteType(si.Value), }, nil } - ps, ok1 := d.Precision.Expr.(*ParamName) - ss, ok2 := d.Scale.Expr.(*ParamName) - if ok1 && ok2 { - // parameterized decimal param - return types.ParameterizedDecimalType{ + // there is at least one abstract param, so it is parameterized type + + ps, isPrecisionAbstract := d.Precision.Expr.(*ParamName) + ss, isScaleAbstract := d.Scale.Expr.(*ParamName) + if isPrecisionAbstract && isScaleAbstract { + // both abstract param + return &types.ParameterizedDecimalType{ Nullability: n, - Precision: types.IntegerParam{Name: ps.Name}, - Scale: types.IntegerParam{Name: ss.Name}, + Precision: parameter_types.LeafIntParamAbstractType(ps.Name), + Scale: parameter_types.LeafIntParamAbstractType(ss.Name), }, nil } - return nil, substraitgo.ErrNotImplemented + + // one abstract and one concrete + if isPrecisionConcrete { + return 
&types.ParameterizedDecimalType{ + Nullability: n, + Precision: parameter_types.LeafIntParamConcreteType(pi.Value), + Scale: parameter_types.LeafIntParamAbstractType(ss.Name), + }, nil + } + return &types.ParameterizedDecimalType{ + Nullability: n, + Precision: parameter_types.LeafIntParamAbstractType(ps.Name), + Scale: parameter_types.LeafIntParamConcreteType(si.Value), + }, nil } type structType struct { @@ -351,6 +372,7 @@ func (t *structType) Type() (types.Type, error) { } var err error typeList := make([]types.Type, len(t.Types)) + anyAbstractParamPresent := false for i, typ := range t.Types { tp, ok := typ.Expr.(*Type) if !ok { @@ -360,6 +382,15 @@ func (t *structType) Type() (types.Type, error) { if typeList[i], err = tp.Type(); err != nil { return nil, err } + if _, ok1 := typeList[i].(types.ParameterizedAbstractType); ok1 { + anyAbstractParamPresent = true + } + } + if anyAbstractParamPresent { + return &types.ParameterizedStructType{ + Nullability: n, + Type: typeList, + }, nil } return &types.StructType{ Nullability: n, @@ -380,7 +411,7 @@ func (m *mapType) String() string { if m.Nullability { opt = "?" 
} - return "map" + opt + "<" + m.Key.Expr.String() + "," + m.Value.Expr.String() + ">" + return "map" + opt + "<" + m.Key.Expr.String() + ", " + m.Value.Expr.String() + ">" } func (m *mapType) Optional() bool { return m.Nullability } @@ -412,6 +443,22 @@ func (m *mapType) Type() (types.Type, error) { if err != nil { return nil, err } + + anyAbstractParamPresent := false + if _, ok1 := key.(types.ParameterizedAbstractType); ok1 { + anyAbstractParamPresent = true + } + if _, ok1 := value.(types.ParameterizedAbstractType); ok1 { + anyAbstractParamPresent = true + } + if anyAbstractParamPresent { + return &types.ParameterizedMapType{ + Key: key, + Value: value, + Nullability: n, + }, nil + } + return &types.MapType{ Key: key, Value: value, diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index ac66753..a33abe7 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parameter_types" "github.com/substrait-io/substrait-go/types/parser" ) @@ -24,22 +25,26 @@ func TestParser(t *testing.T) { {"i16?", "i16?", "i16", &types.Int16Type{Nullability: types.NullabilityNullable}}, {"boolean", "boolean", "bool", &types.BooleanType{Nullability: types.NullabilityRequired}}, {"fixedchar<5>", "fixedchar<5>", "fchar", &types.FixedCharType{Length: 5}}, - {"decimal<10,5>", "decimal<10, 5>", "dec", &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}}, - {"list>", "list>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, - {"list?>", "list?>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + 
{"decimal<10,5>", "decimal<10,5>", "dec", &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}}, + {"list>", "list>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"list?>", "list?>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, {"struct", "struct", "struct", &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, - {"map>", "map>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, - {"map?>", "map?>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + {"map>", "map>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, + {"map?>", "map?>", "map", &types.MapType{Key: 
&types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, {"precision_timestamp<5>", "precision_timestamp<5>", "prets", &types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}, {"precision_timestamp_tz<5>", "precision_timestamp_tz<5>", "pretstz", &types.PrecisionTimestampTzType{PrecisionTimestampType: types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}}, - {"varchar", "varchar", "vchar", types.ParameterizedVarCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"fixedchar", "fixedchar", "fchar", types.ParameterizedFixedCharType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"fixedbinary", "fixedbinary", "fbin", types.ParameterizedFixedBinaryType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"precision_timestamp", "precision_timestamp", "prets", types.ParameterizedPrecisionTimestampType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", types.ParameterizedPrecisionTimestampTzType{IntegerOption: types.IntegerParam{Name: "L1"}}}, - {"decimal", "decimal", "dec", types.ParameterizedDecimalType{Precision: types.IntegerParam{Name: "P"}, Scale: types.IntegerParam{Name: "S"}, Nullability: types.NullabilityRequired}}, + {"varchar", "varchar", "vchar", types.ParameterizedVarCharType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, + {"fixedchar", "fixedchar", "fchar", types.ParameterizedFixedCharType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, + {"fixedbinary", "fixedbinary", "fbin", types.ParameterizedFixedBinaryType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, + {"precision_timestamp", 
"precision_timestamp", "prets", types.ParameterizedPrecisionTimestampType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, + {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", types.ParameterizedPrecisionTimestampTzType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, + {"decimal", "decimal", "dec", &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}}, + {"decimal<38,S>", "decimal<38,S>", "dec", &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamConcreteType(38), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}}, {"any", "any", "any", types.AnyType{Nullability: types.NullabilityRequired}}, {"any1?", "any1?", "any", types.AnyType{Nullability: types.NullabilityNullable}}, + {"list>", "list>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"struct>, i16>", "struct>, i16>", "struct", &types.ParameterizedStructType{Type: []types.Type{&types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}, Nullability: types.NullabilityNullable}, &types.Int16Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, + {"map, i16>", "map, i16>", "map", &types.ParameterizedMapType{Key: &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}, Value: &types.Int16Type{Nullability: types.NullabilityRequired}, 
Nullability: types.NullabilityRequired}}, } p, err := parser.New() diff --git a/types/types.go b/types/types.go index 278d9a9..689629f 100644 --- a/types/types.go +++ b/types/types.go @@ -12,6 +12,7 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/proto" + "github.com/substrait-io/substrait-go/types/parameter_types" ) type Version = proto.Version @@ -271,26 +272,26 @@ func TypeFromProto(t *proto.Type) FuncArgType { return &FixedBinaryType{ Nullability: t.FixedBinary.Nullability, TypeVariationRef: t.FixedBinary.TypeVariationReference, - Length: t.FixedBinary.Length, + Length: parameter_types.LeafIntParamConcreteType(t.FixedBinary.Length), } case *proto.Type_FixedChar_: return &FixedCharType{ Nullability: t.FixedChar.Nullability, TypeVariationRef: t.FixedChar.TypeVariationReference, - Length: t.FixedChar.Length, + Length: parameter_types.LeafIntParamConcreteType(t.FixedChar.Length), } case *proto.Type_Varchar: return &VarCharType{ Nullability: t.Varchar.Nullability, TypeVariationRef: t.Varchar.TypeVariationReference, - Length: t.Varchar.Length, + Length: parameter_types.LeafIntParamConcreteType(t.Varchar.Length), } case *proto.Type_Decimal_: return &DecimalType{ Nullability: t.Decimal.Nullability, TypeVariationRef: t.Decimal.TypeVariationReference, - Scale: t.Decimal.Scale, - Precision: t.Decimal.Precision, + Scale: parameter_types.LeafIntParamConcreteType(t.Decimal.Scale), + Precision: parameter_types.LeafIntParamConcreteType(t.Decimal.Precision), } case *proto.Type_Struct_: fields := make([]Type, len(t.Struct.Types)) @@ -375,21 +376,22 @@ type ( WithNullability(Nullability) Type } - // ParameterizedType this representa a concrete type with parameters - ParameterizedType interface { + // ParameterizedConcreteType this represents a concrete type with parameters + ParameterizedConcreteType interface { Type ParameterString() string BaseString() string } - FixedType interface { - ParameterizedType - WithLength(int32) 
FixedType + // ParameterizedAbstractType this represents a type which has at least one abstract parameter + ParameterizedAbstractType interface { + Type + GetAbstractParameters() []parameter_types.AbstractParameterType } - ParameterizedSingleIntegerType interface { - Type - WithIntegerOption(param IntegerParam) Type + FixedType interface { + ParameterizedConcreteType + WithLength(int32) FixedType } ) @@ -480,19 +482,19 @@ func TypeToProto(t Type) *proto.Type { case *FixedCharType: return &proto.Type{Kind: &proto.Type_FixedChar_{ FixedChar: &proto.Type_FixedChar{ - Length: t.Length, + Length: t.Length.ToProtoVal(), Nullability: t.Nullability, TypeVariationReference: t.TypeVariationRef}}} case *VarCharType: return &proto.Type{Kind: &proto.Type_Varchar{ Varchar: &proto.Type_VarChar{ - Length: t.Length, + Length: t.Length.ToProtoVal(), Nullability: t.Nullability, TypeVariationReference: t.TypeVariationRef}}} case *FixedBinaryType: return &proto.Type{Kind: &proto.Type_FixedBinary_{ FixedBinary: &proto.Type_FixedBinary{ - Length: t.Length, + Length: t.Length.ToProtoVal(), Nullability: t.Nullability, TypeVariationReference: t.TypeVariationRef}}} case *DecimalType: @@ -637,11 +639,11 @@ type ( FixedCharType = FixedLenType[FixedChar] VarCharType = FixedLenType[VarChar] FixedBinaryType = FixedLenType[FixedBinary] - ParameterizedVarCharType = ParameterizedTypeSingleIntegerParam[VarCharType] - ParameterizedFixedCharType = ParameterizedTypeSingleIntegerParam[FixedCharType] - ParameterizedFixedBinaryType = ParameterizedTypeSingleIntegerParam[FixedBinaryType] - ParameterizedPrecisionTimestampType = ParameterizedTypeSingleIntegerParam[PrecisionTimestampType] - ParameterizedPrecisionTimestampTzType = ParameterizedTypeSingleIntegerParam[PrecisionTimestampTzType] + ParameterizedVarCharType = parameterizedTypeSingleIntegerParam[VarCharType] + ParameterizedFixedCharType = parameterizedTypeSingleIntegerParam[FixedCharType] + ParameterizedFixedBinaryType = 
parameterizedTypeSingleIntegerParam[FixedBinaryType] + ParameterizedPrecisionTimestampType = parameterizedTypeSingleIntegerParam[PrecisionTimestampType] + ParameterizedPrecisionTimestampTzType = parameterizedTypeSingleIntegerParam[PrecisionTimestampTzType] ) // FixedLenType is any of the types which also need to track their specific @@ -649,7 +651,7 @@ type ( type FixedLenType[T FixedChar | VarChar | FixedBinary] struct { Nullability Nullability TypeVariationRef uint32 - Length int32 + Length parameter_types.LeafIntParamConcreteType } func (*FixedLenType[T]) isRootRef() {} @@ -698,7 +700,7 @@ func (s *FixedLenType[T]) BaseString() string { func (s *FixedLenType[T]) WithLength(length int32) FixedType { out := *s - out.Length = length + out.Length = parameter_types.LeafIntParamConcreteType(length) return &out } @@ -706,7 +708,7 @@ func (s *FixedLenType[T]) WithLength(length int32) FixedType { type DecimalType struct { Nullability Nullability TypeVariationRef uint32 - Scale, Precision int32 + Scale, Precision parameter_types.LeafIntParamConcreteType } func (*DecimalType) isRootRef() {} @@ -736,7 +738,7 @@ func (s *DecimalType) ToProtoFuncArg() *proto.FunctionArgument { func (s *DecimalType) ToProto() *proto.Type { return &proto.Type{Kind: &proto.Type_Decimal_{ Decimal: &proto.Type_Decimal{ - Scale: s.Scale, Precision: s.Precision, + Scale: s.Scale.ToProtoVal(), Precision: s.Precision.ToProtoVal(), Nullability: s.Nullability, TypeVariationReference: s.TypeVariationRef}}} } @@ -828,6 +830,21 @@ func (t *StructType) String() string { return b.String() } +func (t *StructType) ParameterString() string { + sb := strings.Builder{} + for i, typ := range t.Types { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(typ.String()) + } + return sb.String() +} + +func (*StructType) BaseString() string { + return "struct" +} + type ListType struct { Nullability Nullability TypeVariationRef uint32 @@ -879,6 +896,14 @@ func (t *ListType) String() string { return "list" + 
strNullable(t) + "<" + t.Type.String() + ">" } +func (s *ListType) ParameterString() string { + return s.Type.String() +} + +func (*ListType) BaseString() string { + return "list" +} + type MapType struct { Nullability Nullability TypeVariationRef uint32 @@ -927,7 +952,15 @@ func (t *MapType) ToProtoFuncArg() *proto.FunctionArgument { func (t *MapType) ShortString() string { return "map" } func (t *MapType) String() string { - return "map" + strNullable(t) + "<" + t.Key.String() + "," + t.Value.String() + ">" + return "map" + strNullable(t) + "<" + t.Key.String() + ", " + t.Value.String() + ">" +} + +func (t *MapType) ParameterString() string { + return fmt.Sprintf("%s, %s", t.Key.String(), t.Value.String()) +} + +func (*MapType) BaseString() string { + return "map" } // TypeParam represents a type parameter for a user defined type diff --git a/types/types_test.go b/types/types_test.go index f705c8b..49e6e8c 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -42,7 +42,7 @@ func TestTypeToString(t *testing.T) { "struct?>", "struct"}, {&ListType{Type: &Int8Type{}}, "list", "list"}, {&MapType{Key: &StringType{}, Value: &DecimalType{Precision: 10, Scale: 2}}, - "map>", "map"}, + "map>", "map"}, } for _, tt := range tests { From 17834d49c12011f8007b0c2992afaccc18e8a4c7 Mon Sep 17 00:00:00 2001 From: Anshul Data Date: Wed, 11 Sep 2024 19:25:00 +0530 Subject: [PATCH 5/6] Address review comments * Made parameter of function argument as separate interface "FuncDefArgType" --- expr/functions.go | 14 +- expr/literals.go | 25 +- extensions/variants.go | 30 +- extensions/variants_test.go | 32 +- functions/types.go | 17 +- literal/utils_test.go | 7 +- types/any_type.go | 31 +- types/any_type_test.go | 7 +- types/leaf_parameters/concrete_int_param.go | 26 ++ .../leaf_parameter_type.go | 15 +- .../leaf_parameter_type_test.go | 16 +- types/leaf_parameters/variable_int_param.go | 30 ++ .../abstract_parameter_type.go | 23 -- .../concrete_parameter_type.go | 23 -- 
types/parameterized_decimal_type.go | 61 ++-- types/parameterized_decimal_type_test.go | 65 ++-- types/parameterized_list_type.go | 44 +-- types/parameterized_list_type_test.go | 40 ++- types/parameterized_map_type.go | 58 ++-- types/parameterized_map_type_test.go | 45 ++- ...parameterized_single_integer_param_type.go | 84 ++--- ...eterized_single_integer_param_type_test.go | 54 ++-- types/parameterized_struct_type.go | 72 ++--- types/parameterized_struct_type_test.go | 42 ++- types/parser/type_parser.go | 290 +++++++++++------- types/parser/type_parser_test.go | 52 ++-- types/precison_timestamp_types.go | 2 + types/precison_timestamp_types_test.go | 2 + types/types.go | 73 +++-- 29 files changed, 617 insertions(+), 663 deletions(-) create mode 100644 types/leaf_parameters/concrete_int_param.go rename types/{parameter_types => leaf_parameters}/leaf_parameter_type.go (53%) rename types/{parameter_types => leaf_parameters}/leaf_parameter_type_test.go (62%) create mode 100644 types/leaf_parameters/variable_int_param.go delete mode 100644 types/parameter_types/abstract_parameter_type.go delete mode 100644 types/parameter_types/concrete_parameter_type.go diff --git a/expr/functions.go b/expr/functions.go index 9b16687..1132e78 100644 --- a/expr/functions.go +++ b/expr/functions.go @@ -478,11 +478,13 @@ func (w *WindowFunction) Invocation() types.AggregationInvocation { return w.inv func (w *WindowFunction) Decomposable() extensions.DecomposeType { return w.declaration.Decomposability() } -func (w *WindowFunction) Ordered() bool { return w.declaration.Ordered() } -func (w *WindowFunction) MaxSet() int { return w.declaration.MaxSet() } -func (w *WindowFunction) IntermediateType() (types.Type, error) { return w.declaration.Intermediate() } -func (w *WindowFunction) WindowType() extensions.WindowType { return w.declaration.WindowType() } -func (*WindowFunction) IsScalar() bool { return false } +func (w *WindowFunction) Ordered() bool { return w.declaration.Ordered() } +func 
(w *WindowFunction) MaxSet() int { return w.declaration.MaxSet() } +func (w *WindowFunction) IntermediateType() (types.FuncDefArgType, error) { + return w.declaration.Intermediate() +} +func (w *WindowFunction) WindowType() extensions.WindowType { return w.declaration.WindowType() } +func (*WindowFunction) IsScalar() bool { return false } func (*WindowFunction) isRootRef() {} @@ -773,7 +775,7 @@ func (a *AggregateFunction) Decomposable() extensions.DecomposeType { } func (a *AggregateFunction) Ordered() bool { return a.declaration.Ordered() } func (a *AggregateFunction) MaxSet() int { return a.declaration.MaxSet() } -func (a *AggregateFunction) IntermediateType() (types.Type, error) { +func (a *AggregateFunction) IntermediateType() (types.FuncDefArgType, error) { return a.declaration.Intermediate() } diff --git a/expr/literals.go b/expr/literals.go index bb45a79..dc8863e 100644 --- a/expr/literals.go +++ b/expr/literals.go @@ -12,7 +12,6 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/proto" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" "golang.org/x/exp/slices" "google.golang.org/protobuf/types/known/anypb" ) @@ -455,8 +454,8 @@ func (t *ProtoLiteral) ToProtoLiteral() *proto.Expression_Literal { lit.LiteralType = &proto.Expression_Literal_Decimal_{ Decimal: &proto.Expression_Literal_Decimal{ Value: v, - Precision: literalType.Precision.ToProtoVal(), - Scale: literalType.Scale.ToProtoVal(), + Precision: literalType.Precision, + Scale: literalType.Scale, }, } case *types.PrecisionTimestampType: @@ -528,7 +527,7 @@ func NewFixedCharLiteral(val types.FixedChar, nullable bool) *PrimitiveLiteral[t Value: val, Type: &types.FixedCharType{ Nullability: getNullability(nullable), - Length: parameter_types.LeafIntParamConcreteType(len(val)), + Length: int32(len(val)), }, } } @@ -615,7 +614,7 @@ func NewFixedBinaryLiteral(val types.FixedBinary, nullable bool) 
*ByteSliceLiter return &ByteSliceLiteral[types.FixedBinary]{ Value: val, Type: &types.FixedLenType[types.FixedBinary]{ - Length: parameter_types.LeafIntParamConcreteType(len(val)), + Length: int32(len(val)), Nullability: getNullability(nullable), }, } @@ -687,8 +686,8 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) { Value: v.Value, Type: &types.DecimalType{ Nullability: getNullability(nullable), - Precision: parameter_types.LeafIntParamConcreteType(v.Precision), - Scale: parameter_types.LeafIntParamConcreteType(v.Scale), + Precision: v.Precision, + Scale: v.Scale, }, }, nil case *types.UserDefinedLiteral: @@ -710,7 +709,7 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) { Value: v.Value, Type: &types.VarCharType{ Nullability: getNullability(nullable), - Length: parameter_types.LeafIntParamConcreteType(v.Length), + Length: int32(v.Length), }, }, nil case *types.PrecisionTimestamp: @@ -824,7 +823,7 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &PrimitiveLiteral[types.FixedChar]{ Value: types.FixedChar(lit.FixedChar), Type: &types.FixedCharType{ - Length: parameter_types.LeafIntParamConcreteType(len(lit.FixedChar)), + Length: int32(len(lit.FixedChar)), TypeVariationRef: l.TypeVariationReference, Nullability: nullability, }} @@ -832,7 +831,7 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &ProtoLiteral{ Value: lit.VarChar.Value, Type: &types.VarCharType{ - Length: parameter_types.LeafIntParamConcreteType(lit.VarChar.Length), + Length: int32(lit.VarChar.Length), Nullability: nullability, TypeVariationRef: l.TypeVariationReference, }, @@ -841,7 +840,7 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &ByteSliceLiteral[types.FixedBinary]{ Value: lit.FixedBinary, Type: &types.FixedBinaryType{ - Length: parameter_types.LeafIntParamConcreteType(len(lit.FixedBinary)), + Length: int32(len(lit.FixedBinary)), TypeVariationRef: l.TypeVariationReference, 
Nullability: nullability, }} @@ -849,8 +848,8 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal { return &ProtoLiteral{ Value: lit.Decimal.Value, Type: &types.DecimalType{ - Scale: parameter_types.LeafIntParamConcreteType(lit.Decimal.Scale), - Precision: parameter_types.LeafIntParamConcreteType(lit.Decimal.Precision), + Scale: lit.Decimal.Scale, + Precision: lit.Decimal.Precision, Nullability: nullability, TypeVariationRef: l.TypeVariationReference, }, diff --git a/extensions/variants.go b/extensions/variants.go index 2a14d54..3ee85d7 100644 --- a/extensions/variants.go +++ b/extensions/variants.go @@ -8,7 +8,7 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" "github.com/substrait-io/substrait-go/types/parser" ) @@ -66,7 +66,7 @@ func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeEx var outType types.Type if t, ok := expr.Expr.(*parser.Type); ok { var err error - outType, err = t.Type() + outType, err = t.RetType() if err != nil { return nil, err } @@ -260,9 +260,9 @@ func (s *AggregateFunctionVariant) ID() ID { return ID{URI: s.uri, Name: s.CompoundName()} } func (s *AggregateFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable } -func (s *AggregateFunctionVariant) Intermediate() (types.Type, error) { +func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error) { if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok { - return t.Type() + return t.ArgType() } return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType) } @@ -367,9 +367,9 @@ func (s *WindowFunctionVariant) ID() ID { return ID{URI: s.uri, Name: s.CompoundName()} } func (s *WindowFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable } -func (s *WindowFunctionVariant) Intermediate() 
(types.Type, error) { +func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) { if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok { - return t.Type() + return t.ArgType() } return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType) } @@ -378,20 +378,19 @@ func (s *WindowFunctionVariant) MaxSet() int { return s.impl.MaxSet } func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowType } // HasSyncParams This API returns if params share a leaf param name -func HasSyncParams(params []types.Type) bool { +func HasSyncParams(params []types.FuncDefArgType) bool { // get list of parameters from Abstract parameter type // if any of the parameter is common, it indicates parameters are same across parameters existingParamMap := make(map[string]bool) for _, p := range params { - pat, ok := p.(types.ParameterizedAbstractType) - if !ok { + if !p.HasParameterizedParam() { // not a type which contains abstract parameters, so continue continue } // get list of parameters for each abstract parameter type // note, this can be more than one parameter because of nested abstract types // e.g. Decimal or List, VARCHAR>> - abstractParams := pat.GetAbstractParameters() + abstractParams := p.GetParameterizedParams() var leafParams []string for _, abstractParam := range abstractParams { leafParams = append(leafParams, getLeafAbstractParams(abstractParam)...) 
@@ -417,16 +416,19 @@ func HasSyncParams(params []types.Type) bool { // an abstract parameter can be a leaf type or a parameterized type itself // if it is a leaf type, its param name is returned // if it is parameterized type, leaf type is found recursively -func getLeafAbstractParams(abstractTypes parameter_types.AbstractParameterType) []string { +func getLeafAbstractParams(abstractTypes interface{}) []string { + if leaf, ok := abstractTypes.(leaf_parameters.LeafParameter); ok { + return []string{leaf.String()} + } // if it is not a leaf type recurse - if pat, ok := abstractTypes.(types.ParameterizedAbstractType); ok { + if pat, ok := abstractTypes.(types.FuncDefArgType); ok { var outLeafParams []string - for _, p := range pat.GetAbstractParameters() { + for _, p := range pat.GetParameterizedParams() { childLeafParams := getLeafAbstractParams(p) outLeafParams = append(outLeafParams, childLeafParams...) } return outLeafParams } // for leaf type, return the param name - return []string{abstractTypes.GetAbstractParamName()} + panic("invalid non-leaf, non-parameterized type param") } diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 9477b47..88b4a23 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/extensions" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" "github.com/substrait-io/substrait-go/types/parser" ) @@ -70,9 +70,9 @@ func TestEvaluateTypeExpression(t *testing.T) { func TestHasSyncParams(t *testing.T) { - apt_P := parameter_types.LeafIntParamAbstractType("P") - apt_Q := parameter_types.LeafIntParamAbstractType("Q") - cpt_38 := parameter_types.LeafIntParamConcreteType(38) + apt_P := leaf_parameters.NewVariableIntParam("P") + apt_Q := leaf_parameters.NewVariableIntParam("Q") + cpt_38 := 
leaf_parameters.NewConcreteIntParam(38) fct_P := &types.ParameterizedFixedCharType{IntegerOption: apt_P} fct_Q := &types.ParameterizedFixedCharType{IntegerOption: apt_Q} @@ -80,22 +80,22 @@ func TestHasSyncParams(t *testing.T) { decimal_38_Q := &types.ParameterizedDecimalType{Precision: cpt_38, Scale: apt_Q} list_decimal_38_Q := &types.ParameterizedListType{Type: decimal_38_Q} map_fctQ_decimal38Q := &types.ParameterizedMapType{Key: fct_Q, Value: decimal_38_Q} - struct_fctQ_ListDecimal38Q := &types.ParameterizedStructType{Type: []types.Type{fct_Q, list_decimal_38_Q}} + struct_fctQ_ListDecimal38Q := &types.ParameterizedStructType{Types: []types.FuncDefArgType{fct_Q, list_decimal_38_Q}} for _, td := range []struct { name string - params []types.Type + params []types.FuncDefArgType expectedHasSyncParams bool }{ - {"No Abstract Type", []types.Type{&types.Int64Type{}}, false}, - {"No Sync Param P, Q", []types.Type{fct_P, fct_Q}, false}, - {"Sync Params P, P", []types.Type{fct_P, fct_P}, true}, - {"Sync Params P, ", []types.Type{fct_P, decimal_PQ}, true}, - {"No Sync Params P, <38, Q>", []types.Type{fct_P, decimal_38_Q}, false}, - {"Sync Params P, List>", []types.Type{fct_P, list_decimal_38_Q}, false}, - {"No Sync Params fct

, Map, decimal<38,Q>>", []types.Type{fct_P, map_fctQ_decimal38Q}, false}, - {"Sync Params fct, Map, decimal<38,Q>>", []types.Type{fct_Q, map_fctQ_decimal38Q}, true}, - {"No Sync Params fct

, struct, list<38,Q>>", []types.Type{fct_P, struct_fctQ_ListDecimal38Q}, false}, - {"Sync Params fct, struct, list<38,Q>>", []types.Type{fct_Q, struct_fctQ_ListDecimal38Q}, true}, + {"No Abstract Type", []types.FuncDefArgType{&types.Int64Type{}}, false}, + {"No Sync Param P, Q", []types.FuncDefArgType{fct_P, fct_Q}, false}, + {"Sync Params P, P", []types.FuncDefArgType{fct_P, fct_P}, true}, + {"Sync Params P, ", []types.FuncDefArgType{fct_P, decimal_PQ}, true}, + {"No Sync Params P, <38, Q>", []types.FuncDefArgType{fct_P, decimal_38_Q}, false}, + {"Sync Params P, List>", []types.FuncDefArgType{fct_P, list_decimal_38_Q}, false}, + {"No Sync Params fct

, Map, decimal<38,Q>>", []types.FuncDefArgType{fct_P, map_fctQ_decimal38Q}, false}, + {"Sync Params fct, Map, decimal<38,Q>>", []types.FuncDefArgType{fct_Q, map_fctQ_decimal38Q}, true}, + {"No Sync Params fct

, struct, list<38,Q>>", []types.FuncDefArgType{fct_P, struct_fctQ_ListDecimal38Q}, false}, + {"Sync Params fct, struct, list<38,Q>>", []types.FuncDefArgType{fct_Q, struct_fctQ_ListDecimal38Q}, true}, } { t.Run(td.name, func(t *testing.T) { if td.expectedHasSyncParams { diff --git a/functions/types.go b/functions/types.go index c4a6266..3de9d06 100644 --- a/functions/types.go +++ b/functions/types.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package functions import ( @@ -6,7 +8,6 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" ) var ( @@ -84,18 +85,18 @@ func getTypeWithParameters(typ types.Type, parameters []int32) (types.Type, erro if len(parameters) != 2 { return nil, substraitgo.ErrInvalidType } - return &types.DecimalType{Precision: parameter_types.LeafIntParamConcreteType(parameters[0]), Scale: parameter_types.LeafIntParamConcreteType(parameters[1])}, nil + return &types.DecimalType{Precision: parameters[0], Scale: parameters[1]}, nil case *types.FixedBinaryType, *types.FixedCharType, *types.VarCharType: if len(parameters) != 1 { return nil, substraitgo.ErrInvalidType } switch typ.(type) { case *types.FixedBinaryType: - return &types.FixedBinaryType{Length: parameter_types.LeafIntParamConcreteType(parameters[0])}, nil + return &types.FixedBinaryType{Length: parameters[0]}, nil case *types.FixedCharType: - return &types.FixedCharType{Length: parameter_types.LeafIntParamConcreteType(parameters[0])}, nil + return &types.FixedCharType{Length: parameters[0]}, nil case *types.VarCharType: - return &types.VarCharType{Length: parameter_types.LeafIntParamConcreteType(parameters[0])}, nil + return &types.VarCharType{Length: parameters[0]}, nil } default: if len(parameters) != 0 { @@ -148,14 +149,14 @@ type typeInfo struct { func (ti *typeInfo) getLongName() string { switch ti.typ.(type) { - case types.ParameterizedConcreteType: - 
return ti.typ.(types.ParameterizedConcreteType).BaseString() + case types.CompositeType: + return ti.typ.(types.CompositeType).BaseString() } return ti.typ.String() } func (ti *typeInfo) getLocalTypeString(input types.Type, enclosure typeEnclosure) string { - if paramType, ok := input.(types.ParameterizedConcreteType); ok { + if paramType, ok := input.(types.CompositeType); ok { return ti.localName + enclosure.containerStart() + paramType.ParameterString() + enclosure.containerEnd() } return ti.localName diff --git a/literal/utils_test.go b/literal/utils_test.go index 9a3fefc..c6848c7 100644 --- a/literal/utils_test.go +++ b/literal/utils_test.go @@ -12,7 +12,6 @@ import ( "github.com/substrait-io/substrait-go/expr" "github.com/substrait-io/substrait-go/proto" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" ) func TestNewBool(t *testing.T) { @@ -90,8 +89,8 @@ func createDecimalLiteral(value []byte, precision int32, scale int32, isNullable Value: value[:16], Type: &types.DecimalType{ Nullability: nullability, - Precision: parameter_types.LeafIntParamConcreteType(precision), - Scale: parameter_types.LeafIntParamConcreteType(scale), + Precision: precision, + Scale: scale, }, } } @@ -675,7 +674,7 @@ func createVarCharLiteral(value string) *expr.ProtoLiteral { Value: value, Type: &types.VarCharType{ Nullability: proto.Type_NULLABILITY_REQUIRED, - Length: parameter_types.LeafIntParamConcreteType(len(value)), + Length: int32(len(value)), }, } } diff --git a/types/any_type.go b/types/any_type.go index 43aad3b..f6561ae 100644 --- a/types/any_type.go +++ b/types/any_type.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( @@ -12,24 +14,21 @@ type AnyType struct { Nullability Nullability } -func (AnyType) isRootRef() {} -func (m AnyType) WithNullability(nullability Nullability) Type { - m.Nullability = nullability - return m -} -func (m AnyType) GetType() Type { return m } -func (m 
AnyType) GetNullability() Nullability { - return m.Nullability +func (t AnyType) SetNullability(n Nullability) FuncDefArgType { + t.Nullability = n + return t } -func (m AnyType) GetTypeVariationReference() uint32 { - return m.TypeVariationRef + +func (t AnyType) String() string { + return fmt.Sprintf("%s%s", t.Name, strFromNullability(t.Nullability)) } -func (AnyType) Equals(rhs Type) bool { - // equal to every other type - return true + +func (s AnyType) HasParameterizedParam() bool { + // primitive type doesn't have abstract parameters + return false } -func (t AnyType) ShortString() string { return t.Name } -func (t AnyType) String() string { - return fmt.Sprintf("%s%s", t.Name, strNullable(t)) +func (s AnyType) GetParameterizedParams() []interface{} { + // any type doesn't have any abstract parameters + return nil } diff --git a/types/any_type_test.go b/types/any_type_test.go index dec44cb..1e8f26a 100644 --- a/types/any_type_test.go +++ b/types/any_type_test.go @@ -24,13 +24,8 @@ func TestAnyType(t *testing.T) { Name: td.argName, Nullability: td.nullability, } - anyType := arg.WithNullability(td.nullability) + anyType := arg.SetNullability(td.nullability) require.Equal(t, td.expectedString, anyType.String()) - require.Equal(t, td.nullability, anyType.GetNullability()) - require.Equal(t, td.argName, anyType.ShortString()) - // any type should be equal to any other type including itself - require.True(t, anyType.Equals(anyType)) - require.True(t, anyType.Equals(&types.Int8Type{})) }) } } diff --git a/types/leaf_parameters/concrete_int_param.go b/types/leaf_parameters/concrete_int_param.go new file mode 100644 index 0000000..30f1e32 --- /dev/null +++ b/types/leaf_parameters/concrete_int_param.go @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 + +package leaf_parameters + +import "fmt" + +// ConcreteIntParam represents a single integer concrete parameter for a concrete type +// Example: VARCHAR(6) -> 6 is an ConcreteIntParam +// DECIMAL --> 0 Is an 
ConcreteIntParam but P not +type ConcreteIntParam int32 + +func NewConcreteIntParam(v int32) LeafParameter { + m := ConcreteIntParam(v) + return &m +} + +func (m *ConcreteIntParam) IsCompatible(o LeafParameter) bool { + if t, ok := o.(*ConcreteIntParam); ok { + return t == m + } + return false +} + +func (m *ConcreteIntParam) String() string { + return fmt.Sprintf("%d", *m) +} diff --git a/types/parameter_types/leaf_parameter_type.go b/types/leaf_parameters/leaf_parameter_type.go similarity index 53% rename from types/parameter_types/leaf_parameter_type.go rename to types/leaf_parameters/leaf_parameter_type.go index d6e4504..ba2692b 100644 --- a/types/parameter_types/leaf_parameter_type.go +++ b/types/leaf_parameters/leaf_parameter_type.go @@ -1,4 +1,8 @@ -package parameter_types +// SPDX-License-Identifier: Apache-2.0 + +package leaf_parameters + +import "fmt" // LeafParameter represents a parameter type // parameter can of concrete (38) or abstract type (P) @@ -7,12 +11,5 @@ type LeafParameter interface { // IsCompatible is type compatible with other // compatible is other can be used in place of this type IsCompatible(other LeafParameter) bool - String() string -} - -// AbstractParameterType represents a parameter type which is abstract. 
-// it can be a leaf parameter (LeafIntParamAbstractType) -// or another abstract type like "DECIMAL" -type AbstractParameterType interface { - GetAbstractParamName() string + fmt.Stringer } diff --git a/types/parameter_types/leaf_parameter_type_test.go b/types/leaf_parameters/leaf_parameter_type_test.go similarity index 62% rename from types/parameter_types/leaf_parameter_type_test.go rename to types/leaf_parameters/leaf_parameter_type_test.go index 4ef799c..5a7a61f 100644 --- a/types/parameter_types/leaf_parameter_type_test.go +++ b/types/leaf_parameters/leaf_parameter_type_test.go @@ -1,24 +1,26 @@ -package parameter_types_test +// SPDX-License-Identifier: Apache-2.0 + +package leaf_parameters_test import ( "testing" "github.com/stretchr/testify/require" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) func TestConcreteParameterType(t *testing.T) { - concreteType1 := parameter_types.LeafIntParamConcreteType(1) + concreteType1 := leaf_parameters.ConcreteIntParam(1) require.Equal(t, "1", concreteType1.String()) } func TestLeafParameterType(t *testing.T) { - var concreteType1, concreteType2, abstractType1 parameter_types.LeafParameter + var concreteType1, concreteType2, abstractType1 leaf_parameters.LeafParameter - concreteType1 = parameter_types.LeafIntParamConcreteType(1) - concreteType2 = parameter_types.LeafIntParamConcreteType(2) + concreteType1 = leaf_parameters.NewConcreteIntParam(1) + concreteType2 = leaf_parameters.NewConcreteIntParam(2) - abstractType1 = parameter_types.LeafIntParamAbstractType("P") + abstractType1 = leaf_parameters.NewVariableIntParam("P") // verify string val require.Equal(t, "1", concreteType1.String()) diff --git a/types/leaf_parameters/variable_int_param.go b/types/leaf_parameters/variable_int_param.go new file mode 100644 index 0000000..abc3612 --- /dev/null +++ b/types/leaf_parameters/variable_int_param.go @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: 
Apache-2.0 + +package leaf_parameters + +// VariableIntParam represents an integer parameter for a parameterized type +// Example: VARCHAR(L1) -> L1 is an VariableIntParam +// DECIMAL --> P Is an VariableIntParam +type VariableIntParam string + +func NewVariableIntParam(s string) LeafParameter { + m := VariableIntParam(s) + return &m +} + +func (m *VariableIntParam) IsCompatible(o LeafParameter) bool { + switch o.(type) { + case *VariableIntParam, *ConcreteIntParam: + return true + default: + return false + } +} + +func (m *VariableIntParam) String() string { + return string(*m) +} + +func (m *VariableIntParam) GetAbstractParamName() string { + return string(*m) +} diff --git a/types/parameter_types/abstract_parameter_type.go b/types/parameter_types/abstract_parameter_type.go deleted file mode 100644 index 3e056c5..0000000 --- a/types/parameter_types/abstract_parameter_type.go +++ /dev/null @@ -1,23 +0,0 @@ -package parameter_types - -// LeafIntParamAbstractType represents an integer parameter for a parameterized type -// Example: VARCHAR(L1) -> L1 is an LeafIntParamAbstractType -// DECIMAL --> P Is an LeafIntParamAbstractType -type LeafIntParamAbstractType string - -func (m LeafIntParamAbstractType) IsCompatible(o LeafParameter) bool { - switch o.(type) { - case LeafIntParamAbstractType, LeafIntParamConcreteType: - return true - default: - return false - } -} - -func (m LeafIntParamAbstractType) String() string { - return string(m) -} - -func (m LeafIntParamAbstractType) GetAbstractParamName() string { - return string(m) -} diff --git a/types/parameter_types/concrete_parameter_type.go b/types/parameter_types/concrete_parameter_type.go deleted file mode 100644 index 9ea6603..0000000 --- a/types/parameter_types/concrete_parameter_type.go +++ /dev/null @@ -1,23 +0,0 @@ -package parameter_types - -import "fmt" - -// LeafIntParamConcreteType represents a single integer concrete parameter for a concrete type -// Example: VARCHAR(6) -> 6 is an LeafIntParamConcreteType 
-// DECIMAL --> 0 Is an LeafIntParamConcreteType but P not -type LeafIntParamConcreteType int32 - -func (m LeafIntParamConcreteType) IsCompatible(o LeafParameter) bool { - if t, ok := o.(LeafIntParamConcreteType); ok { - return t == m - } - return false -} - -func (m LeafIntParamConcreteType) String() string { - return fmt.Sprintf("%d", m) -} - -func (m LeafIntParamConcreteType) ToProtoVal() int32 { - return int32(m) -} diff --git a/types/parameterized_decimal_type.go b/types/parameterized_decimal_type.go index 05897e4..0e31c82 100644 --- a/types/parameterized_decimal_type.go +++ b/types/parameterized_decimal_type.go @@ -1,66 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( "fmt" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) -// ParameterizedDecimalType is a decimal type with at least one of precision and scale parameters of string type -// example: Decimal or Decimal. -// Note concrete types e.g. 
Decimal(10, 2) are not represented by this type -// Concrete type is represented by DecimalType +// ParameterizedDecimalType is a decimal type which to hold function arguments +// example: Decimal or Decimal or Decimal(10, 2) type ParameterizedDecimalType struct { Nullability Nullability TypeVariationRef uint32 - Precision parameter_types.LeafParameter - Scale parameter_types.LeafParameter + Precision leaf_parameters.LeafParameter + Scale leaf_parameters.LeafParameter } -func (*ParameterizedDecimalType) isRootRef() {} -func (m *ParameterizedDecimalType) WithNullability(n Nullability) Type { +func (m *ParameterizedDecimalType) SetNullability(n Nullability) FuncDefArgType { m.Nullability = n return m } -func (m *ParameterizedDecimalType) GetType() Type { return m } -func (m *ParameterizedDecimalType) GetNullability() Nullability { return m.Nullability } -func (m *ParameterizedDecimalType) GetTypeVariationReference() uint32 { - return m.TypeVariationRef -} -func (m *ParameterizedDecimalType) Equals(rhs Type) bool { - if o, ok := rhs.(*ParameterizedDecimalType); ok { - return *o == *m - } - return false -} - -func (m *ParameterizedDecimalType) ShortString() string { - t := DecimalType{} - return t.ShortString() -} - func (m *ParameterizedDecimalType) String() string { t := DecimalType{} parameterString := fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) - return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) +} + +func (m *ParameterizedDecimalType) HasParameterizedParam() bool { + _, ok1 := m.Precision.(*leaf_parameters.VariableIntParam) + _, ok2 := m.Scale.(*leaf_parameters.VariableIntParam) + return ok1 || ok2 } -// GetAbstractParameters returns the abstract parameter names -// this implements interface ParameterizedAbstractType -func (m *ParameterizedDecimalType) GetAbstractParameters() []parameter_types.AbstractParameterType { - 
var params []parameter_types.AbstractParameterType - if p, ok := m.Precision.(parameter_types.AbstractParameterType); ok { +func (m *ParameterizedDecimalType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + var params []interface{} + if p, ok := m.Precision.(*leaf_parameters.VariableIntParam); ok { params = append(params, p) } - if p, ok := m.Scale.(parameter_types.AbstractParameterType); ok { + if p, ok := m.Scale.(*leaf_parameters.VariableIntParam); ok { params = append(params, p) } return params } - -// GetAbstractParamName this implements interface AbstractParameterType -// to indicate ParameterizedDecimalType itself can be used as a parameter of abstract type too -func (m *ParameterizedDecimalType) GetAbstractParamName() string { - return m.String() -} diff --git a/types/parameterized_decimal_type_test.go b/types/parameterized_decimal_type_test.go index 87990af..63afc09 100644 --- a/types/parameterized_decimal_type_test.go +++ b/types/parameterized_decimal_type_test.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types_test import ( @@ -5,55 +7,34 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) func TestParameterizedDecimalType(t *testing.T) { + precision_P := leaf_parameters.NewVariableIntParam("P") + scale_S := leaf_parameters.NewVariableIntParam("S") + precision_38 := leaf_parameters.NewConcreteIntParam(38) + scale_5 := leaf_parameters.NewConcreteIntParam(5) for _, td := range []struct { - name string - precision string - scale string - nullability types.Nullability - expectedString string - expectedShortString string + name string + precision leaf_parameters.LeafParameter + scale leaf_parameters.LeafParameter + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + 
expectedParameterizedParams []interface{} }{ - {"nullable decimal", "P", "S", types.NullabilityNullable, "decimal?", "dec"}, - {"non nullable decimal", "P", "S", types.NullabilityRequired, "decimal", "dec"}, + {"both parameterized", precision_P, scale_S, "decimal?", "decimal", true, []interface{}{precision_P, scale_S}}, + {"precision concrete", precision_38, scale_S, "decimal?<38,S>", "decimal<38,S>", true, []interface{}{scale_S}}, + {"scale concrete", precision_P, scale_5, "decimal?", "decimal", true, []interface{}{precision_P}}, + {"both concrete", precision_38, scale_5, "decimal?<38,5>", "decimal<38,5>", false, nil}, } { t.Run(td.name, func(t *testing.T) { - precision := parameter_types.LeafIntParamAbstractType(td.precision) - scale := parameter_types.LeafIntParamAbstractType(td.scale) - pd := &types.ParameterizedDecimalType{Precision: precision, Scale: scale} - pdType := pd.WithNullability(td.nullability) - require.Equal(t, td.expectedString, pdType.String()) - require.Equal(t, td.expectedShortString, pdType.ShortString()) - require.True(t, pdType.Equals(pdType)) - - pdAbsParamType, ok := pdType.(parameter_types.AbstractParameterType) - require.True(t, ok) - require.Equal(t, td.expectedString, pdAbsParamType.GetAbstractParamName()) - pdAbstractType, ok := pdType.(types.ParameterizedAbstractType) - require.True(t, ok) - require.Len(t, pdAbstractType.GetAbstractParameters(), 2) + pd := &types.ParameterizedDecimalType{Precision: td.precision, Scale: td.scale} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) }) } } - -func TestParameterizedDecimalSingleAbstractParam(t *testing.T) { - precision := parameter_types.LeafIntParamConcreteType(38) - scale 
:= parameter_types.LeafIntParamAbstractType("S") - - pd := &types.ParameterizedDecimalType{Precision: precision, Scale: scale} - pdType := pd.WithNullability(types.NullabilityNullable) - require.Equal(t, "decimal?<38,S>", pdType.String()) - require.Equal(t, "dec", pdType.ShortString()) - require.True(t, pdType.Equals(pdType)) - - pdAbsParamType, ok := pdType.(parameter_types.AbstractParameterType) - require.True(t, ok) - require.Equal(t, "decimal?<38,S>", pdAbsParamType.GetAbstractParamName()) - pdAbstractType, ok := pdType.(types.ParameterizedAbstractType) - require.True(t, ok) - // only one abstract param - require.Len(t, pdAbstractType.GetAbstractParameters(), 1) -} diff --git a/types/parameterized_list_type.go b/types/parameterized_list_type.go index 10a13b0..e37da33 100644 --- a/types/parameterized_list_type.go +++ b/types/parameterized_list_type.go @@ -1,9 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( "fmt" - - "github.com/substrait-io/substrait-go/types/parameter_types" ) // ParameterizedListType is a list type having parameter of ParameterizedAbstractType @@ -13,47 +13,27 @@ import ( type ParameterizedListType struct { Nullability Nullability TypeVariationRef uint32 - Type ParameterizedAbstractType + Type FuncDefArgType } -func (*ParameterizedListType) isRootRef() {} -func (m *ParameterizedListType) WithNullability(n Nullability) Type { +func (m *ParameterizedListType) SetNullability(n Nullability) FuncDefArgType { m.Nullability = n return m } -func (m *ParameterizedListType) GetType() Type { return m } -func (m *ParameterizedListType) GetNullability() Nullability { return m.Nullability } -func (m *ParameterizedListType) GetTypeVariationReference() uint32 { - return m.TypeVariationRef -} -func (m *ParameterizedListType) Equals(rhs Type) bool { - if o, ok := rhs.(*ParameterizedListType); ok { - return m.Nullability == o.Nullability && m.TypeVariationRef == o.TypeVariationRef && - m.Type.Equals(o.Type) - } - return false -} - -func 
(m *ParameterizedListType) ShortString() string { - t := ListType{} - return t.ShortString() -} - func (m *ParameterizedListType) String() string { t := ListType{} parameterString := fmt.Sprintf("<%s>", m.Type) - return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) } -// GetAbstractParameters returns the abstract parameter names -// this implements interface ParameterizedAbstractType -func (m *ParameterizedListType) GetAbstractParameters() []parameter_types.AbstractParameterType { - return []parameter_types.AbstractParameterType{m.Type.(parameter_types.AbstractParameterType)} +func (m *ParameterizedListType) HasParameterizedParam() bool { + return m.Type.HasParameterizedParam() } -// GetAbstractParamName this implements interface AbstractParameterType -// to indicate ParameterizedListType itself can be used as a parameter of abstract type too -func (m *ParameterizedListType) GetAbstractParamName() string { - return m.String() +func (m *ParameterizedListType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + return []interface{}{m.Type} } diff --git a/types/parameterized_list_type_test.go b/types/parameterized_list_type_test.go index fc0a7b5..82166c2 100644 --- a/types/parameterized_list_type_test.go +++ b/types/parameterized_list_type_test.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types_test import ( @@ -5,37 +7,33 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) func TestParameterizedListType(t *testing.T) { decimalType := &types.ParameterizedDecimalType{ - Precision: parameter_types.LeafIntParamAbstractType("P"), - Scale: parameter_types.LeafIntParamAbstractType("S"), + Precision: 
leaf_parameters.NewVariableIntParam("P"), + Scale: leaf_parameters.NewVariableIntParam("S"), Nullability: types.NullabilityRequired, } + int8Type := &types.Int8Type{} for _, td := range []struct { - name string - typ types.ParameterizedAbstractType - nullability types.Nullability - expectedString string - expectedShortString string + name string + param types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + expectedParameterizedParams []interface{} }{ - {"list", decimalType, types.NullabilityNullable, "list?>", "list"}, + {"parameterized param", decimalType, "list?>", "list>", true, []interface{}{decimalType}}, + {"concrete param", int8Type, "list?", "list", false, nil}, } { t.Run(td.name, func(t *testing.T) { - pl := &types.ParameterizedListType{Type: td.typ} - plType := pl.WithNullability(td.nullability) - require.Equal(t, td.expectedString, plType.String()) - require.Equal(t, td.expectedShortString, plType.ShortString()) - require.True(t, plType.Equals(plType)) - - plAbsParamType, ok := plType.(parameter_types.AbstractParameterType) - require.True(t, ok) - require.Equal(t, td.expectedString, plAbsParamType.GetAbstractParamName()) - plAbstractType, ok := plType.(types.ParameterizedAbstractType) - require.True(t, ok) - require.Len(t, plAbstractType.GetAbstractParameters(), 1) + pd := &types.ParameterizedListType{Type: td.param} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) }) } } diff --git a/types/parameterized_map_type.go b/types/parameterized_map_type.go index 93c32fa..1c79c14 100644 --- a/types/parameterized_map_type.go +++ b/types/parameterized_map_type.go @@ -1,9 
+1,9 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( "fmt" - - "github.com/substrait-io/substrait-go/types/parameter_types" ) // ParameterizedMapType is a struct having at least one of key or value of type ParameterizedAbstractType @@ -11,55 +11,35 @@ import ( type ParameterizedMapType struct { Nullability Nullability TypeVariationRef uint32 - Key Type - Value Type + Key FuncDefArgType + Value FuncDefArgType } -func (*ParameterizedMapType) isRootRef() {} -func (m *ParameterizedMapType) WithNullability(n Nullability) Type { +func (m *ParameterizedMapType) SetNullability(n Nullability) FuncDefArgType { m.Nullability = n return m } -func (m *ParameterizedMapType) GetType() Type { return m } -func (m *ParameterizedMapType) GetNullability() Nullability { return m.Nullability } -func (m *ParameterizedMapType) GetTypeVariationReference() uint32 { - return m.TypeVariationRef -} -func (m *ParameterizedMapType) Equals(rhs Type) bool { - if o, ok := rhs.(*ParameterizedMapType); ok { - return m.Nullability == o.Nullability && m.TypeVariationRef == o.TypeVariationRef && - m.Key.Equals(o.Key) && m.Value.Equals(o.Value) - } - return false -} - -func (m *ParameterizedMapType) ShortString() string { - t := MapType{} - return t.ShortString() -} - func (m *ParameterizedMapType) String() string { t := MapType{} parameterString := fmt.Sprintf("<%s, %s>", m.Key.String(), m.Value.String()) - return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) +} + +func (m *ParameterizedMapType) HasParameterizedParam() bool { + return m.Key.HasParameterizedParam() || m.Value.HasParameterizedParam() } -// GetAbstractParameters returns the abstract parameter names -// this implements interface ParameterizedAbstractType -func (m *ParameterizedMapType) GetAbstractParameters() []parameter_types.AbstractParameterType { - var abstractParams 
[]parameter_types.AbstractParameterType - if abs, ok := m.Key.(parameter_types.AbstractParameterType); ok { - abstractParams = append(abstractParams, abs) +func (m *ParameterizedMapType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil } - if abs, ok := m.Value.(parameter_types.AbstractParameterType); ok { - abstractParams = append(abstractParams, abs) + var abstractParams []interface{} + if m.Key.HasParameterizedParam() { + abstractParams = append(abstractParams, m.Key) + } + if m.Value.HasParameterizedParam() { + abstractParams = append(abstractParams, m.Value) } return abstractParams } - -// GetAbstractParamName this implements interface AbstractParameterType -// to indicate ParameterizedStructType itself can be used as a parameter of abstract type too -func (m *ParameterizedMapType) GetAbstractParamName() string { - return m.String() -} diff --git a/types/parameterized_map_type_test.go b/types/parameterized_map_type_test.go index d882186..42cb04e 100644 --- a/types/parameterized_map_type_test.go +++ b/types/parameterized_map_type_test.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types_test import ( @@ -5,42 +7,37 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) func TestParameterizedMapType(t *testing.T) { decimalType := &types.ParameterizedDecimalType{ - Precision: parameter_types.LeafIntParamAbstractType("P"), - Scale: parameter_types.LeafIntParamAbstractType("S"), + Precision: leaf_parameters.NewVariableIntParam("P"), + Scale: leaf_parameters.NewVariableIntParam("S"), Nullability: types.NullabilityRequired, } int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} listType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} for _, td := range []struct { - name string - Key 
types.Type - Value types.Type - nullability types.Nullability - expectedString string - expectedShortString string - expectedNrAbstractParam int + name string + Key types.FuncDefArgType + Value types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + expectedParameterizedParams []interface{} }{ - {"single abstract param", decimalType, int8Type, types.NullabilityNullable, "map?, i8?>", "map", 1}, - {"both abstract param", decimalType, listType, types.NullabilityNullable, "map?, list?>>", "map", 2}, + {"parameterized kv", decimalType, listType, "map?, list?>>", "map, list?>>", true, []interface{}{decimalType, listType}}, + {"concrete key", int8Type, listType, "map?>>", "map>>", true, []interface{}{listType}}, + {"concrete value", decimalType, int8Type, "map?, i8?>", "map, i8?>", true, []interface{}{decimalType}}, + {"no parameterized param", int8Type, int8Type, "map?", "map", false, nil}, } { t.Run(td.name, func(t *testing.T) { - pm := &types.ParameterizedMapType{Key: td.Key, Value: td.Value} - pmType := pm.WithNullability(td.nullability) - require.Equal(t, td.expectedString, pmType.String()) - require.Equal(t, td.expectedShortString, pmType.ShortString()) - require.True(t, pmType.Equals(pmType)) - - pmAbsParamType, ok := pmType.(parameter_types.AbstractParameterType) - require.True(t, ok) - require.Equal(t, td.expectedString, pmAbsParamType.GetAbstractParamName()) - pmAbstractType, ok := pmType.(types.ParameterizedAbstractType) - require.True(t, ok) - require.Len(t, pmAbstractType.GetAbstractParameters(), td.expectedNrAbstractParam) + pd := &types.ParameterizedMapType{Key: td.Key, Value: td.Value} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + 
require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) }) } } diff --git a/types/parameterized_single_integer_param_type.go b/types/parameterized_single_integer_param_type.go index a417b1f..d5c215b 100644 --- a/types/parameterized_single_integer_param_type.go +++ b/types/parameterized_single_integer_param_type.go @@ -1,86 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( "fmt" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) // parameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter type parameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct { Nullability Nullability TypeVariationRef uint32 - IntegerOption parameter_types.LeafIntParamAbstractType + IntegerOption leaf_parameters.LeafParameter } -func (m parameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption parameter_types.LeafIntParamAbstractType) Type { - m.IntegerOption = integerOption - return m -} - -func (parameterizedTypeSingleIntegerParam[T]) isRootRef() {} -func (m parameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type { +func (m *parameterizedTypeSingleIntegerParam[T]) SetNullability(n Nullability) FuncDefArgType { m.Nullability = n return m } -func (m parameterizedTypeSingleIntegerParam[T]) GetType() Type { return m } -func (m parameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability } -func (m parameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 { - return m.TypeVariationRef -} -func (m parameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool { - if o, ok := rhs.(parameterizedTypeSingleIntegerParam[T]); ok { - return o == m - } - return false -} - -func (m parameterizedTypeSingleIntegerParam[T]) ShortString() string { - switch 
any(m).(type) { - case ParameterizedVarCharType: - t := VarCharType{} - return t.ShortString() - case ParameterizedFixedCharType: - t := FixedCharType{} - return t.ShortString() - case ParameterizedFixedBinaryType: - t := FixedBinaryType{} - return t.ShortString() - case ParameterizedPrecisionTimestampType: - t := PrecisionTimestampType{} - return t.ShortString() - case ParameterizedPrecisionTimestampTzType: - t := PrecisionTimestampTzType{} - return t.ShortString() - default: - panic("unknown type") - } -} - -func (m parameterizedTypeSingleIntegerParam[T]) String() string { - return fmt.Sprintf("%s%s%s", m.baseString(), strNullable(m), m.parameterString()) +func (m *parameterizedTypeSingleIntegerParam[T]) String() string { + return fmt.Sprintf("%s%s%s", m.baseString(), strFromNullability(m.Nullability), m.parameterString()) } -func (m parameterizedTypeSingleIntegerParam[T]) parameterString() string { - return fmt.Sprintf("<%s>", m.IntegerOption.GetAbstractParamName()) +func (m *parameterizedTypeSingleIntegerParam[T]) parameterString() string { + return fmt.Sprintf("<%s>", m.IntegerOption.String()) } -func (m parameterizedTypeSingleIntegerParam[T]) baseString() string { +func (m *parameterizedTypeSingleIntegerParam[T]) baseString() string { switch any(m).(type) { - case ParameterizedVarCharType: + case *ParameterizedVarCharType: t := VarCharType{} return t.BaseString() - case ParameterizedFixedCharType: + case *ParameterizedFixedCharType: t := FixedCharType{} return t.BaseString() - case ParameterizedFixedBinaryType: + case *ParameterizedFixedBinaryType: t := FixedBinaryType{} return t.BaseString() - case ParameterizedPrecisionTimestampType: + case *ParameterizedPrecisionTimestampType: t := PrecisionTimestampType{} return t.BaseString() - case ParameterizedPrecisionTimestampTzType: + case *ParameterizedPrecisionTimestampTzType: t := PrecisionTimestampTzType{} return t.BaseString() default: @@ -88,14 +50,14 @@ func (m parameterizedTypeSingleIntegerParam[T]) 
baseString() string { } } -// GetAbstractParameters returns the abstract parameter names -// this implements interface ParameterizedAbstractType -func (m parameterizedTypeSingleIntegerParam[T]) GetAbstractParameters() []parameter_types.AbstractParameterType { - return []parameter_types.AbstractParameterType{m.IntegerOption} +func (m *parameterizedTypeSingleIntegerParam[T]) HasParameterizedParam() bool { + _, ok1 := m.IntegerOption.(*leaf_parameters.VariableIntParam) + return ok1 } -// GetAbstractParamName this implements interface AbstractParameterType -// basically, this type itself can be used as a parameter of abstract type too -func (m parameterizedTypeSingleIntegerParam[T]) GetAbstractParamName() string { - return m.String() +func (m *parameterizedTypeSingleIntegerParam[T]) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + return []interface{}{m.IntegerOption} } diff --git a/types/parameterized_single_integer_param_type_test.go b/types/parameterized_single_integer_param_type_test.go index 2412c0c..9f54d3f 100644 --- a/types/parameterized_single_integer_param_type_test.go +++ b/types/parameterized_single_integer_param_type_test.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types_test import ( @@ -5,42 +7,36 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) -// a type to indicate all single integer type. 
-// helpful in initializing different single type integer type to the same interface -type parameterizedSingleIntegerType interface { - types.Type - WithIntegerOption(param parameter_types.LeafIntParamAbstractType) types.Type -} - func TestParameterizedSingleIntegerType(t *testing.T) { + abstractLeafParam_L1 := leaf_parameters.NewVariableIntParam("L1") + concreteLeafParam_38 := leaf_parameters.NewConcreteIntParam(38) for _, td := range []struct { - name string - typ parameterizedSingleIntegerType - nullability types.Nullability - integerOption parameter_types.LeafIntParamAbstractType - expectedString string - expectedBaseString string - expectedShortString string + name string + typ types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedIsParameterized bool + expectedAbstractParams []interface{} }{ - {"nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "varchar?", "varchar", "vchar"}, - {"non nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "varchar", "varchar", "vchar"}, - {"nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "char?", "char", "fchar"}, - {"non nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "char", "char", "fchar"}, - {"nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "fixedbinary?", "fixedbinary", "fbin"}, - {"non nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "fixedbinary", "fixedbinary", "fbin"}, - {"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityNullable, 
parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp?", "precision_timestamp", "prets"}, - {"non nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp", "precision_timestamp", "prets"}, - {"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityNullable, parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp_tz?", "precision_timestamp_tz", "pretstz"}, - {"non nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityRequired, parameter_types.LeafIntParamAbstractType("L1"), "precision_timestamp_tz", "precision_timestamp_tz", "pretstz"}, + {"nullable parameterized varchar", &types.ParameterizedVarCharType{IntegerOption: abstractLeafParam_L1}, "varchar?", "varchar", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete varchar", &types.ParameterizedVarCharType{IntegerOption: concreteLeafParam_38}, "varchar?<38>", "varchar<38>", false, nil}, + {"nullable fixChar", &types.ParameterizedFixedCharType{IntegerOption: abstractLeafParam_L1}, "char?", "char", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete fixChar", &types.ParameterizedFixedCharType{IntegerOption: concreteLeafParam_38}, "char?<38>", "char<38>", false, nil}, + {"nullable fixBinary", &types.ParameterizedFixedBinaryType{IntegerOption: abstractLeafParam_L1}, "fixedbinary?", "fixedbinary", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete fixBinary", &types.ParameterizedFixedBinaryType{IntegerOption: concreteLeafParam_38}, "fixedbinary?<38>", "fixedbinary<38>", false, nil}, + {"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{IntegerOption: abstractLeafParam_L1}, "precision_timestamp?", "precision_timestamp", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete precisionTimeStamp", 
&types.ParameterizedPrecisionTimestampType{IntegerOption: concreteLeafParam_38}, "precision_timestamp?<38>", "precision_timestamp<38>", false, nil}, + {"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: abstractLeafParam_L1}, "precision_timestamp_tz?", "precision_timestamp_tz", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: concreteLeafParam_38}, "precision_timestamp_tz?<38>", "precision_timestamp_tz<38>", false, nil}, } { t.Run(td.name, func(t *testing.T) { - pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability) - require.Equal(t, td.expectedString, pt.String()) - require.Equal(t, td.expectedShortString, pt.ShortString()) - require.True(t, pt.Equals(pt)) + require.Equal(t, td.expectedNullableString, td.typ.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, td.typ.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedIsParameterized, td.typ.HasParameterizedParam()) + require.Equal(t, td.expectedAbstractParams, td.typ.GetParameterizedParams()) }) } } diff --git a/types/parameterized_struct_type.go b/types/parameterized_struct_type.go index 6c66efc..30688cb 100644 --- a/types/parameterized_struct_type.go +++ b/types/parameterized_struct_type.go @@ -1,56 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( "fmt" "strings" - - "github.com/substrait-io/substrait-go/types/parameter_types" ) -// ParameterizedStructType is a struct having at least one parameter of type ParameterizedAbstractType -// example: Struct. -// If All arguments are concrete they are represented by StructType +// ParameterizedStructType is a parameter type struct +// example: Struct or Struct. 
type ParameterizedStructType struct { Nullability Nullability TypeVariationRef uint32 - Type []Type + Types []FuncDefArgType } -func (*ParameterizedStructType) isRootRef() {} -func (m *ParameterizedStructType) WithNullability(n Nullability) Type { +func (m *ParameterizedStructType) SetNullability(n Nullability) FuncDefArgType { m.Nullability = n return m } -func (m *ParameterizedStructType) GetType() Type { return m } -func (m *ParameterizedStructType) GetNullability() Nullability { return m.Nullability } -func (m *ParameterizedStructType) GetTypeVariationReference() uint32 { - return m.TypeVariationRef -} -func (m *ParameterizedStructType) Equals(rhs Type) bool { - if o, ok := rhs.(*ParameterizedStructType); ok { - if m.Nullability != o.Nullability || len(m.Type) != len(o.Type) || - m.TypeVariationRef != o.TypeVariationRef { - return false - } - for i := range m.Type { - if !m.Type[i].Equals(o.Type[i]) { - return false - } - } - return true - } - return false -} - -func (m *ParameterizedStructType) ShortString() string { - t := StructType{} - return t.ShortString() -} - func (m *ParameterizedStructType) String() string { sb := strings.Builder{} - for i, typ := range m.Type { + for i, typ := range m.Types { if i != 0 { sb.WriteString(", ") } @@ -58,23 +30,27 @@ func (m *ParameterizedStructType) String() string { } t := StructType{} parameterString := fmt.Sprintf("<%s>", sb.String()) - return fmt.Sprintf("%s%s%s", t.BaseString(), strNullable(m), parameterString) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) } -// GetAbstractParameters returns the abstract parameter names -// this implements interface ParameterizedAbstractType -func (m *ParameterizedStructType) GetAbstractParameters() []parameter_types.AbstractParameterType { - var abstractParams []parameter_types.AbstractParameterType - for _, typ := range m.Type { - if abs, ok := typ.(parameter_types.AbstractParameterType); ok { - abstractParams = 
append(abstractParams, abs) +func (m *ParameterizedStructType) HasParameterizedParam() bool { + for _, typ := range m.Types { + if typ.HasParameterizedParam() { + return true } } - return abstractParams + return false } -// GetAbstractParamName this implements interface AbstractParameterType -// to indicate ParameterizedStructType itself can be used as a parameter of abstract type too -func (m *ParameterizedStructType) GetAbstractParamName() string { - return m.String() +func (m *ParameterizedStructType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + var abstractParams []interface{} + for _, typ := range m.Types { + if typ.HasParameterizedParam() { + abstractParams = append(abstractParams, typ) + } + } + return abstractParams } diff --git a/types/parameterized_struct_type_test.go b/types/parameterized_struct_type_test.go index 0e70b6f..ac88c6b 100644 --- a/types/parameterized_struct_type_test.go +++ b/types/parameterized_struct_type_test.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types_test import ( @@ -5,41 +7,35 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) func TestParameterizedStructType(t *testing.T) { decimalType := &types.ParameterizedDecimalType{ - Precision: parameter_types.LeafIntParamAbstractType("P"), - Scale: parameter_types.LeafIntParamAbstractType("S"), + Precision: leaf_parameters.NewVariableIntParam("P"), + Scale: leaf_parameters.NewVariableIntParam("S"), Nullability: types.NullabilityRequired, } int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} listType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} for _, td := range []struct { - name string - types []types.Type - nullability types.Nullability - expectedString string - expectedShortString 
string - expectedNrAbstractParam int + name string + params []types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + expectedParameterizedParams []interface{} }{ - {"single abstract param", []types.Type{decimalType}, types.NullabilityNullable, "struct?>", "struct", 1}, - {"multiple abstract param", []types.Type{decimalType, int8Type, listType}, types.NullabilityRequired, "struct, i8?, list?>>", "struct", 2}, + {"all parameterized param", []types.FuncDefArgType{decimalType, listType}, "struct?, list?>>", "struct, list?>>", true, []interface{}{decimalType, listType}}, + {"mix parameterized concrete param", []types.FuncDefArgType{decimalType, int8Type, listType}, "struct?, i8?, list?>>", "struct, i8?, list?>>", true, []interface{}{decimalType, listType}}, + {"all concrete param", []types.FuncDefArgType{int8Type, int8Type, int8Type}, "struct?", "struct", false, nil}, } { t.Run(td.name, func(t *testing.T) { - ps := &types.ParameterizedStructType{Type: td.types} - psType := ps.WithNullability(td.nullability) - require.Equal(t, td.expectedString, psType.String()) - require.Equal(t, td.expectedShortString, psType.ShortString()) - require.True(t, psType.Equals(psType)) - - psAbsParamType, ok := psType.(parameter_types.AbstractParameterType) - require.True(t, ok) - require.Equal(t, td.expectedString, psAbsParamType.GetAbstractParamName()) - psAbstractType, ok := psType.(types.ParameterizedAbstractType) - require.True(t, ok) - require.Len(t, psAbstractType.GetAbstractParameters(), td.expectedNrAbstractParam) + pd := &types.ParameterizedStructType{Types: td.params} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + require.Equal(t, td.expectedParameterizedParams, 
pd.GetParameterizedParams()) }) } } diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index 4e5abd0..b32f9c4 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -12,7 +12,7 @@ import ( "github.com/alecthomas/participle/v2/lexer" substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" ) var defaultParser *Parser @@ -23,12 +23,12 @@ type TypeExpression struct { func (t TypeExpression) String() string { return t.Expr.String() } -func (t TypeExpression) Type() (types.Type, error) { +func (t TypeExpression) Type() (types.FuncDefArgType, error) { typeDef, ok := t.Expr.(Def) if !ok { return nil, errors.New("type expression doesn't represent type") } - return typeDef.Type() + return typeDef.ArgType() } func (t TypeExpression) MarshalYAML() (interface{}, error) { @@ -95,15 +95,23 @@ func (t *Type) String() string { return t.TypeDef.String() } -func (t *Type) Type() (types.Type, error) { - return t.TypeDef.Type() +func (t *Type) ArgType() (types.FuncDefArgType, error) { + return t.TypeDef.ArgType() +} + +func (t *Type) RetType() (types.Type, error) { + return t.TypeDef.RetType() } type Def interface { String() string ShortType() string - Type() (types.Type, error) + // ArgType indicates argument type + ArgType() (types.FuncDefArgType, error) Optional() bool + // TODO RetType indicates return type + // This should be replaced with TypeDerivation method. 
Currently it just returns concrete type + RetType() (types.Type, error) } type typename string @@ -133,7 +141,7 @@ func (t *nonParamType) ShortType() string { return types.GetShortTypeName(types.TypeName(t.TypeName)) } -func (t *nonParamType) Type() (types.Type, error) { +func (t *nonParamType) RetType() (types.Type, error) { var n types.Nullability if t.Nullability { n = types.NullabilityNullable @@ -147,6 +155,21 @@ func (t *nonParamType) Type() (types.Type, error) { return nil, err } +func (t *nonParamType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if t.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + typ, err := types.SimpleTypeNameToType(types.TypeName(t.TypeName)) + if err != nil { + return nil, err + } + funcArgType := typ.(types.FuncDefArgType) + return funcArgType.SetNullability(n), nil +} + type listType struct { Nullability bool `parser:"'list' @'?'?"` ElemType TypeExpression `parser:"'<' @@ '>'"` @@ -164,7 +187,7 @@ func (l *listType) String() string { func (l *listType) Optional() bool { return false } -func (l *listType) Type() (types.Type, error) { +func (l *listType) RetType() (types.Type, error) { var n types.Nullability if l.Nullability { n = types.NullabilityNullable @@ -172,21 +195,36 @@ func (l *listType) Type() (types.Type, error) { n = types.NullabilityRequired } if t, ok := l.ElemType.Expr.(*Type); ok { - ret, err := t.Type() + ret, err := t.RetType() if err != nil { return nil, err } - if abstractParam, ok1 := ret.(types.ParameterizedAbstractType); ok1 { - return &types.ParameterizedListType{ - Nullability: n, - Type: abstractParam, - }, nil - } return &types.ListType{ Nullability: n, Type: ret, }, nil } + + return nil, substraitgo.ErrNotImplemented +} + +func (l *listType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if l.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + if t, ok := 
l.ElemType.Expr.(*Type); ok { + ret, err := t.ArgType() + if err != nil { + return nil, err + } + return &types.ParameterizedListType{ + Nullability: n, + Type: ret, + }, nil + } return nil, substraitgo.ErrNotImplemented } @@ -213,61 +251,51 @@ func (p *lengthType) String() string { func (p *lengthType) Optional() bool { return false } -func (p *lengthType) Type() (types.Type, error) { +func (p *lengthType) RetType() (types.Type, error) { var n types.Nullability - - var typ types.Type - var err error - switch t := p.NumericParam.Expr.(type) { - case *IntegerLiteral: - typ, err = getFixedTypeFromConcreteParam(p.TypeName, t) - case *ParamName: - typ, err = getParameterizedTypeSingleParam(p.TypeName, t) - default: + lit, ok := p.NumericParam.Expr.(*IntegerLiteral) + if !ok { return nil, substraitgo.ErrNotImplemented } + + typ, err := types.FixedTypeNameToType(types.TypeName(p.TypeName)) if err != nil { return nil, err } - return typ.WithNullability(n), nil + return typ.WithLength(lit.Value).WithNullability(n), nil } -func getFixedTypeFromConcreteParam(name string, param *IntegerLiteral) (types.Type, error) { - typeName := types.TypeName(name) - switch typeName { - case types.TypeNamePrecisionTimestamp: - precision, err := types.ProtoToTimePrecision(param.Value) - if err != nil { - return nil, err - } - return types.NewPrecisionTimestampType(precision), nil - case types.TypeNamePrecisionTimestampTz: - precision, err := types.ProtoToTimePrecision(param.Value) - if err != nil { - return nil, err - } - return types.NewPrecisionTimestampTzType(precision), nil +func (p *lengthType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + + var leafParam leaf_parameters.LeafParameter + switch t := p.NumericParam.Expr.(type) { + case *IntegerLiteral: + leafParam = leaf_parameters.NewConcreteIntParam(t.Value) + case *ParamName: + leafParam = leaf_parameters.NewVariableIntParam(t.Name) + default: + return nil, substraitgo.ErrNotImplemented } - typ, err := 
types.FixedTypeNameToType(typeName) + typ, err := getParameterizedTypeSingleParam(p.TypeName, leafParam, n) if err != nil { return nil, err } - return typ.WithLength(param.Value), nil + return typ, nil } -func getParameterizedTypeSingleParam(typeName string, param *ParamName) (types.Type, error) { - intParam := parameter_types.LeafIntParamAbstractType(param.Name) +func getParameterizedTypeSingleParam(typeName string, leafParam leaf_parameters.LeafParameter, n types.Nullability) (types.FuncDefArgType, error) { switch types.TypeName(typeName) { case types.TypeNameVarChar: - return types.ParameterizedVarCharType{IntegerOption: intParam}, nil + return &types.ParameterizedVarCharType{IntegerOption: leafParam, Nullability: n}, nil case types.TypeNameFixedChar: - return types.ParameterizedFixedCharType{IntegerOption: intParam}, nil + return &types.ParameterizedFixedCharType{IntegerOption: leafParam, Nullability: n}, nil case types.TypeNameFixedBinary: - return types.ParameterizedFixedBinaryType{IntegerOption: intParam}, nil + return &types.ParameterizedFixedBinaryType{IntegerOption: leafParam, Nullability: n}, nil case types.TypeNamePrecisionTimestamp: - return types.ParameterizedPrecisionTimestampType{IntegerOption: intParam}, nil + return &types.ParameterizedPrecisionTimestampType{IntegerOption: leafParam, Nullability: n}, nil case types.TypeNamePrecisionTimestampTz: - return types.ParameterizedPrecisionTimestampTzType{IntegerOption: intParam}, nil + return &types.ParameterizedPrecisionTimestampTzType{IntegerOption: leafParam, Nullability: n}, nil default: return nil, substraitgo.ErrNotImplemented } @@ -291,49 +319,55 @@ func (d *decimalType) String() string { func (d *decimalType) Optional() bool { return d.Nullability } -func (d *decimalType) Type() (types.Type, error) { +func (d *decimalType) ArgType() (types.FuncDefArgType, error) { var n types.Nullability if d.Nullability { n = types.NullabilityNullable } else { n = types.NullabilityRequired } - pi, 
isPrecisionConcrete := d.Precision.Expr.(*IntegerLiteral) - si, isScaleConcrete := d.Scale.Expr.(*IntegerLiteral) - if isPrecisionConcrete && isScaleConcrete { - // concrete decimal param - return &types.DecimalType{ - Nullability: n, - Precision: parameter_types.LeafIntParamConcreteType(pi.Value), - Scale: parameter_types.LeafIntParamConcreteType(si.Value), - }, nil + var precision leaf_parameters.LeafParameter + if pi, ok := d.Precision.Expr.(*IntegerLiteral); ok { + precision = leaf_parameters.NewConcreteIntParam(pi.Value) + } else { + ps := d.Precision.Expr.(*ParamName) + precision = leaf_parameters.NewVariableIntParam(ps.String()) } - // there is at least one abstract param, so it is parameterized type - - ps, isPrecisionAbstract := d.Precision.Expr.(*ParamName) - ss, isScaleAbstract := d.Scale.Expr.(*ParamName) - if isPrecisionAbstract && isScaleAbstract { - // both abstract param - return &types.ParameterizedDecimalType{ - Nullability: n, - Precision: parameter_types.LeafIntParamAbstractType(ps.Name), - Scale: parameter_types.LeafIntParamAbstractType(ss.Name), - }, nil + var scale leaf_parameters.LeafParameter + if si, ok := d.Scale.Expr.(*IntegerLiteral); ok { + scale = leaf_parameters.NewConcreteIntParam(si.Value) + } else { + ss := d.Scale.Expr.(*ParamName) + scale = leaf_parameters.NewVariableIntParam(ss.String()) } - // one abstract and one concrete - if isPrecisionConcrete { - return &types.ParameterizedDecimalType{ - Nullability: n, - Precision: parameter_types.LeafIntParamConcreteType(pi.Value), - Scale: parameter_types.LeafIntParamAbstractType(ss.Name), - }, nil - } return &types.ParameterizedDecimalType{ Nullability: n, - Precision: parameter_types.LeafIntParamAbstractType(ps.Name), - Scale: parameter_types.LeafIntParamConcreteType(si.Value), + Precision: precision, + Scale: scale, + }, nil +} + +func (d *decimalType) RetType() (types.Type, error) { + var n types.Nullability + if d.Nullability { + n = types.NullabilityNullable + } else { + n = 
types.NullabilityRequired + } + p, ok := d.Precision.Expr.(*IntegerLiteral) + if !ok { + return nil, substraitgo.ErrNotImplemented + } + s, ok := d.Scale.Expr.(*IntegerLiteral) + if !ok { + return nil, substraitgo.ErrNotImplemented + } + return &types.DecimalType{ + Nullability: n, + Precision: p.Value, + Scale: s.Value, }, nil } @@ -363,7 +397,7 @@ func (s *structType) String() string { func (t *structType) Optional() bool { return t.Nullability } -func (t *structType) Type() (types.Type, error) { +func (t *structType) RetType() (types.Type, error) { var n types.Nullability if t.Nullability { n = types.NullabilityNullable @@ -372,25 +406,15 @@ func (t *structType) Type() (types.Type, error) { } var err error typeList := make([]types.Type, len(t.Types)) - anyAbstractParamPresent := false for i, typ := range t.Types { tp, ok := typ.Expr.(*Type) if !ok { return nil, substraitgo.ErrNotImplemented } - if typeList[i], err = tp.Type(); err != nil { + if typeList[i], err = tp.RetType(); err != nil { return nil, err } - if _, ok1 := typeList[i].(types.ParameterizedAbstractType); ok1 { - anyAbstractParamPresent = true - } - } - if anyAbstractParamPresent { - return &types.ParameterizedStructType{ - Nullability: n, - Type: typeList, - }, nil } return &types.StructType{ Nullability: n, @@ -398,6 +422,31 @@ func (t *structType) Type() (types.Type, error) { }, nil } +func (t *structType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if t.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + var err error + typeList := make([]types.FuncDefArgType, len(t.Types)) + for i, typ := range t.Types { + tp, ok := typ.Expr.(*Type) + if !ok { + return nil, substraitgo.ErrNotImplemented + } + + if typeList[i], err = tp.ArgType(); err != nil { + return nil, err + } + } + return &types.ParameterizedStructType{ + Nullability: n, + Types: typeList, + }, nil +} + type mapType struct { Nullability bool `parser:"'map' @'?'?"` Key 
TypeExpression `parser:"'<' @@"` @@ -416,7 +465,7 @@ func (m *mapType) String() string { func (m *mapType) Optional() bool { return m.Nullability } -func (m *mapType) Type() (types.Type, error) { +func (m *mapType) RetType() (types.Type, error) { var n types.Nullability if m.Nullability { n = types.NullabilityNullable @@ -434,32 +483,51 @@ func (m *mapType) Type() (types.Type, error) { return nil, substraitgo.ErrNotImplemented } - key, err := k.Type() + key, err := k.RetType() if err != nil { return nil, err } - value, err := v.Type() + value, err := v.RetType() if err != nil { return nil, err } + return &types.MapType{ + Key: key, + Value: value, + Nullability: n, + }, nil +} + +func (m *mapType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if m.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + + k, ok := m.Key.Expr.(*Type) + if !ok { + return nil, substraitgo.ErrNotImplemented + } - anyAbstractParamPresent := false - if _, ok1 := key.(types.ParameterizedAbstractType); ok1 { - anyAbstractParamPresent = true + v, ok := m.Value.Expr.(*Type) + if !ok { + return nil, substraitgo.ErrNotImplemented } - if _, ok1 := value.(types.ParameterizedAbstractType); ok1 { - anyAbstractParamPresent = true + + key, err := k.ArgType() + if err != nil { + return nil, err } - if anyAbstractParamPresent { - return &types.ParameterizedMapType{ - Key: key, - Value: value, - Nullability: n, - }, nil + + value, err := v.ArgType() + if err != nil { + return nil, err } - return &types.MapType{ + return &types.ParameterizedMapType{ Key: key, Value: value, Nullability: n, @@ -489,7 +557,7 @@ func (t anyType) ShortType() string { return string(t.TypeName) } -func (t anyType) Type() (types.Type, error) { +func (t anyType) ArgType() (types.FuncDefArgType, error) { var n types.Nullability if t.Nullability { n = types.NullabilityNullable @@ -503,6 +571,10 @@ func (t anyType) Type() (types.Type, error) { return types.AnyType{Name: 
typeName, Nullability: n}, nil } +func (t anyType) RetType() (types.Type, error) { + panic("any type can't be in return type") +} + var ( def = lexer.MustSimple([]lexer.SimpleRule{ {Name: "whitespace", Pattern: `[ \t]+`}, diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index a33abe7..ddc600b 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -9,42 +9,49 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parameter_types" + "github.com/substrait-io/substrait-go/types/leaf_parameters" "github.com/substrait-io/substrait-go/types/parser" ) func TestParser(t *testing.T) { + parameterLeaf_L1 := leaf_parameters.NewVariableIntParam("L1") + parameterLeaf_P := leaf_parameters.NewVariableIntParam("P") + parameterLeaf_S := leaf_parameters.NewVariableIntParam("S") + concreteLeaf_5 := leaf_parameters.NewConcreteIntParam(5) + concreteLeaf_38 := leaf_parameters.NewConcreteIntParam(38) + concreteLeaf_10 := leaf_parameters.NewConcreteIntParam(10) + concreteLeaf_EMinus5 := leaf_parameters.NewConcreteIntParam(int32(types.PrecisionEMinus5Seconds)) tests := []struct { expr string expected string shortName string - expectedTyp types.Type + expectedTyp types.FuncDefArgType }{ {"2", "2", "", nil}, {"-2", "-2", "", nil}, {"i16?", "i16?", "i16", &types.Int16Type{Nullability: types.NullabilityNullable}}, {"boolean", "boolean", "bool", &types.BooleanType{Nullability: types.NullabilityRequired}}, - {"fixedchar<5>", "fixedchar<5>", "fchar", &types.FixedCharType{Length: 5}}, - {"decimal<10,5>", "decimal<10,5>", "dec", &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}}, - {"list>", "list>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, - {"list?>", "list?>", 
"list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, - {"struct", "struct", "struct", &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, - {"map>", "map>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, - {"map?>", "map?>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, - {"precision_timestamp<5>", "precision_timestamp<5>", "prets", &types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}, - {"precision_timestamp_tz<5>", "precision_timestamp_tz<5>", "pretstz", &types.PrecisionTimestampTzType{PrecisionTimestampType: types.PrecisionTimestampType{Precision: types.PrecisionEMinus5Seconds}}}, - {"varchar", "varchar", "vchar", types.ParameterizedVarCharType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, - {"fixedchar", "fixedchar", "fchar", types.ParameterizedFixedCharType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, - {"fixedbinary", "fixedbinary", "fbin", types.ParameterizedFixedBinaryType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, - {"precision_timestamp", 
"precision_timestamp", "prets", types.ParameterizedPrecisionTimestampType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, - {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", types.ParameterizedPrecisionTimestampTzType{IntegerOption: parameter_types.LeafIntParamAbstractType("L1")}}, - {"decimal", "decimal", "dec", &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}}, - {"decimal<38,S>", "decimal<38,S>", "dec", &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamConcreteType(38), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}}, + {"fixedchar<5>", "fixedchar<5>", "fchar", &types.ParameterizedFixedCharType{IntegerOption: concreteLeaf_5}}, + {"decimal<10,5>", "decimal<10,5>", "dec", &types.ParameterizedDecimalType{Precision: concreteLeaf_10, Scale: concreteLeaf_5, Nullability: types.NullabilityRequired}}, + {"list>", "list>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: concreteLeaf_10, Scale: concreteLeaf_5, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"list?>", "list?>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: concreteLeaf_10, Scale: concreteLeaf_5, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + {"struct", "struct", "struct", &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, + {"map>", "map>", "map", &types.ParameterizedMapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.Int16Type{Nullability: 
types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, + {"map?>", "map?>", "map", &types.ParameterizedMapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + {"precision_timestamp<5>", "precision_timestamp<5>", "prets", &types.ParameterizedPrecisionTimestampType{IntegerOption: concreteLeaf_EMinus5}}, + {"precision_timestamp_tz<5>", "precision_timestamp_tz<5>", "pretstz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: concreteLeaf_EMinus5}}, + {"varchar", "varchar", "vchar", &types.ParameterizedVarCharType{IntegerOption: parameterLeaf_L1}}, + {"fixedchar", "fixedchar", "fchar", &types.ParameterizedFixedCharType{IntegerOption: parameterLeaf_L1}}, + {"fixedbinary", "fixedbinary", "fbin", &types.ParameterizedFixedBinaryType{IntegerOption: parameterLeaf_L1}}, + {"precision_timestamp", "precision_timestamp", "prets", &types.ParameterizedPrecisionTimestampType{IntegerOption: parameterLeaf_L1}}, + {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: parameterLeaf_L1}}, + {"decimal", "decimal", "dec", &types.ParameterizedDecimalType{Precision: parameterLeaf_S, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}}, + {"decimal<38,S>", "decimal<38,S>", "dec", &types.ParameterizedDecimalType{Precision: concreteLeaf_38, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}}, {"any", "any", "any", types.AnyType{Nullability: types.NullabilityRequired}}, {"any1?", 
"any1?", "any", types.AnyType{Nullability: types.NullabilityNullable}}, - {"list>", "list>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, - {"struct>, i16>", "struct>, i16>", "struct", &types.ParameterizedStructType{Type: []types.Type{&types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}, Nullability: types.NullabilityNullable}, &types.Int16Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, - {"map, i16>", "map, i16>", "map", &types.ParameterizedMapType{Key: &types.ParameterizedDecimalType{Precision: parameter_types.LeafIntParamAbstractType("P"), Scale: parameter_types.LeafIntParamAbstractType("S"), Nullability: types.NullabilityRequired}, Value: &types.Int16Type{Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"list>", "list>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameterLeaf_P, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"struct>, i16>", "struct>, i16>", "struct", &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameterLeaf_P, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityNullable}, &types.Int16Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, + {"map, i16>", "map, i16>", "map", &types.ParameterizedMapType{Key: &types.ParameterizedDecimalType{Precision: parameterLeaf_P, Scale: parameterLeaf_S, Nullability: 
types.NullabilityRequired}, Value: &types.Int16Type{Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, } p, err := parser.New() @@ -57,10 +64,9 @@ func TestParser(t *testing.T) { assert.Equal(t, td.expected, d.Expr.String()) if td.shortName != "" { assert.Equal(t, td.shortName, d.Expr.(*parser.Type).ShortType()) - typ, err := d.Expr.(*parser.Type).Type() + typ, err := d.Expr.(*parser.Type).ArgType() assert.NoError(t, err) assert.Equal(t, reflect.TypeOf(td.expectedTyp), reflect.TypeOf(typ)) - assert.True(t, td.expectedTyp.Equals(typ)) } }) } diff --git a/types/precison_timestamp_types.go b/types/precison_timestamp_types.go index e88d656..0512ced 100644 --- a/types/precison_timestamp_types.go +++ b/types/precison_timestamp_types.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( diff --git a/types/precison_timestamp_types_test.go b/types/precison_timestamp_types_test.go index d606be9..687fe34 100644 --- a/types/precison_timestamp_types_test.go +++ b/types/precison_timestamp_types_test.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( diff --git a/types/types.go b/types/types.go index 689629f..c68e9ef 100644 --- a/types/types.go +++ b/types/types.go @@ -12,7 +12,6 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/proto" - "github.com/substrait-io/substrait-go/types/parameter_types" ) type Version = proto.Version @@ -186,7 +185,7 @@ type ( // TypeFromProto returns the appropriate Type object from a protobuf // type message. 
-func TypeFromProto(t *proto.Type) FuncArgType { +func TypeFromProto(t *proto.Type) Type { switch t := t.Kind.(type) { case *proto.Type_Bool: return &BooleanType{ @@ -272,26 +271,26 @@ func TypeFromProto(t *proto.Type) FuncArgType { return &FixedBinaryType{ Nullability: t.FixedBinary.Nullability, TypeVariationRef: t.FixedBinary.TypeVariationReference, - Length: parameter_types.LeafIntParamConcreteType(t.FixedBinary.Length), + Length: t.FixedBinary.Length, } case *proto.Type_FixedChar_: return &FixedCharType{ Nullability: t.FixedChar.Nullability, TypeVariationRef: t.FixedChar.TypeVariationReference, - Length: parameter_types.LeafIntParamConcreteType(t.FixedChar.Length), + Length: t.FixedChar.Length, } case *proto.Type_Varchar: return &VarCharType{ Nullability: t.Varchar.Nullability, TypeVariationRef: t.Varchar.TypeVariationReference, - Length: parameter_types.LeafIntParamConcreteType(t.Varchar.Length), + Length: t.Varchar.Length, } case *proto.Type_Decimal_: return &DecimalType{ Nullability: t.Decimal.Nullability, TypeVariationRef: t.Decimal.TypeVariationReference, - Scale: parameter_types.LeafIntParamConcreteType(t.Decimal.Scale), - Precision: parameter_types.LeafIntParamConcreteType(t.Decimal.Precision), + Scale: t.Decimal.Scale, + Precision: t.Decimal.Precision, } case *proto.Type_Struct_: fields := make([]Type, len(t.Struct.Types)) @@ -356,14 +355,10 @@ type ( fmt.Stringer } - // FuncArgType this represents a type which can be a function argument - FuncArgType interface { - FuncArg - Type - } // Type corresponds to the proto.Type message and represents - // a specific type. + // a specific type. 
These are types which can be present in plan (are serializable) Type interface { + FuncArg isRootRef() fmt.Stringer ShortString() string @@ -376,21 +371,24 @@ type ( WithNullability(Nullability) Type } - // ParameterizedConcreteType this represents a concrete type with parameters - ParameterizedConcreteType interface { + // CompositeType this represents a concrete type having components + CompositeType interface { Type ParameterString() string BaseString() string } - // ParameterizedAbstractType this represents a type which has at least one abstract parameter - ParameterizedAbstractType interface { - Type - GetAbstractParameters() []parameter_types.AbstractParameterType + // FuncDefArgType this represents a type used in function argument + // These type can't be present in plan (not serializable) + FuncDefArgType interface { + fmt.Stringer + SetNullability(Nullability) FuncDefArgType + HasParameterizedParam() bool + GetParameterizedParams() []interface{} } FixedType interface { - ParameterizedConcreteType + CompositeType WithLength(int32) FixedType } ) @@ -482,19 +480,19 @@ func TypeToProto(t Type) *proto.Type { case *FixedCharType: return &proto.Type{Kind: &proto.Type_FixedChar_{ FixedChar: &proto.Type_FixedChar{ - Length: t.Length.ToProtoVal(), + Length: t.Length, Nullability: t.Nullability, TypeVariationReference: t.TypeVariationRef}}} case *VarCharType: return &proto.Type{Kind: &proto.Type_Varchar{ Varchar: &proto.Type_VarChar{ - Length: t.Length.ToProtoVal(), + Length: t.Length, Nullability: t.Nullability, TypeVariationReference: t.TypeVariationRef}}} case *FixedBinaryType: return &proto.Type{Kind: &proto.Type_FixedBinary_{ FixedBinary: &proto.Type_FixedBinary{ - Length: t.Length.ToProtoVal(), + Length: t.Length, Nullability: t.Nullability, TypeVariationReference: t.TypeVariationRef}}} case *DecimalType: @@ -564,7 +562,11 @@ var shortNames = map[reflect.Type]string{ } func strNullable(t Type) string { - if t.GetNullability() == NullabilityNullable { + return 
strFromNullability(t.GetNullability()) +} + +func strFromNullability(nullability Nullability) string { + if nullability == NullabilityNullable { return "?" } return "" @@ -618,6 +620,21 @@ func (s *PrimitiveType[T]) String() string { return reflect.TypeOf(z).Elem().Name() + strNullable(s) } +func (s *PrimitiveType[T]) HasParameterizedParam() bool { + // primitive type doesn't have abstract parameters + return false +} + +func (s *PrimitiveType[T]) GetParameterizedParams() []interface{} { + // primitive type doesn't have any abstract parameters + return nil +} + +func (s *PrimitiveType[T]) SetNullability(n Nullability) FuncDefArgType { + s.Nullability = n + return s +} + // create type aliases to the generic structs type ( BooleanType = PrimitiveType[bool] @@ -651,7 +668,7 @@ type ( type FixedLenType[T FixedChar | VarChar | FixedBinary] struct { Nullability Nullability TypeVariationRef uint32 - Length parameter_types.LeafIntParamConcreteType + Length int32 } func (*FixedLenType[T]) isRootRef() {} @@ -700,7 +717,7 @@ func (s *FixedLenType[T]) BaseString() string { func (s *FixedLenType[T]) WithLength(length int32) FixedType { out := *s - out.Length = parameter_types.LeafIntParamConcreteType(length) + out.Length = length return &out } @@ -708,7 +725,7 @@ func (s *FixedLenType[T]) WithLength(length int32) FixedType { type DecimalType struct { Nullability Nullability TypeVariationRef uint32 - Scale, Precision parameter_types.LeafIntParamConcreteType + Scale, Precision int32 } func (*DecimalType) isRootRef() {} @@ -738,7 +755,7 @@ func (s *DecimalType) ToProtoFuncArg() *proto.FunctionArgument { func (s *DecimalType) ToProto() *proto.Type { return &proto.Type{Kind: &proto.Type_Decimal_{ Decimal: &proto.Type_Decimal{ - Scale: s.Scale.ToProtoVal(), Precision: s.Precision.ToProtoVal(), + Scale: s.Scale, Precision: s.Precision, Nullability: s.Nullability, TypeVariationReference: s.TypeVariationRef}}} } From 7beaaebb6668bb20808900b04e99847ad245e9ad Mon Sep 17 00:00:00 2001 
From: Anshul Data Date: Thu, 12 Sep 2024 15:30:23 +0530 Subject: [PATCH 6/6] Address review comments --- extensions/simple_extension.go | 8 ++-- extensions/variants.go | 25 ++++++------- extensions/variants_test.go | 8 ++-- .../concrete_int_param.go | 6 +-- .../integer_parameter_type.go} | 8 ++-- .../integer_parameter_type_test.go} | 14 +++---- .../variable_int_param.go | 6 +-- types/parameterized_decimal_type.go | 14 +++---- types/parameterized_decimal_type_test.go | 14 +++---- types/parameterized_list_type_test.go | 6 +-- types/parameterized_map_type_test.go | 6 +-- ...parameterized_single_integer_param_type.go | 37 ++++++++----------- ...eterized_single_integer_param_type_test.go | 6 +-- types/parameterized_struct_type_test.go | 6 +-- types/parser/type_parser.go | 22 +++++------ types/parser/type_parser_test.go | 16 ++++---- types/types.go | 19 +++++++--- 17 files changed, 110 insertions(+), 111 deletions(-) rename types/{leaf_parameters => integer_parameters}/concrete_int_param.go (76%) rename types/{leaf_parameters/leaf_parameter_type.go => integer_parameters/integer_parameter_type.go} (65%) rename types/{leaf_parameters/leaf_parameter_type_test.go => integer_parameters/integer_parameter_type_test.go} (63%) rename types/{leaf_parameters => integer_parameters}/variable_int_param.go (78%) diff --git a/extensions/simple_extension.go b/extensions/simple_extension.go index f055d73..d9cc4ec 100644 --- a/extensions/simple_extension.go +++ b/extensions/simple_extension.go @@ -57,7 +57,7 @@ type TypeVariation struct { type Argument interface { toTypeString() string - marker() // unexported marker method + argumentMarker() // unexported marker method } type EnumArg struct { @@ -70,7 +70,7 @@ func (EnumArg) toTypeString() string { return "req" } -func (v EnumArg) marker() {} +func (v EnumArg) argumentMarker() {} type ValueArg struct { Name string `yaml:",omitempty"` @@ -83,7 +83,7 @@ func (v ValueArg) toTypeString() string { return v.Value.Expr.(*parser.Type).ShortType() 
} -func (v ValueArg) marker() {} +func (v ValueArg) argumentMarker() {} type TypeArg struct { Name string `yaml:",omitempty"` @@ -93,7 +93,7 @@ type TypeArg struct { func (TypeArg) toTypeString() string { return "type" } -func (v TypeArg) marker() {} +func (v TypeArg) argumentMarker() {} type ArgumentList []Argument diff --git a/extensions/variants.go b/extensions/variants.go index 3ee85d7..31752fa 100644 --- a/extensions/variants.go +++ b/extensions/variants.go @@ -8,7 +8,7 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" "github.com/substrait-io/substrait-go/types/parser" ) @@ -379,23 +379,20 @@ func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowTy // HasSyncParams This API returns if params share a leaf param name func HasSyncParams(params []types.FuncDefArgType) bool { - // get list of parameters from Abstract parameter type - // if any of the parameter is common, it indicates parameters are same across parameters + // if any of the leaf parameters are same, it indicates parameters are same across parameters existingParamMap := make(map[string]bool) for _, p := range params { if !p.HasParameterizedParam() { // not a type which contains abstract parameters, so continue continue } - // get list of parameters for each abstract parameter type - // note, this can be more than one parameter because of nested abstract types - // e.g. Decimal or List, VARCHAR>> + // get list of parameterized parameters + // parameterized param can be a Leaf or another type. If another type we recurse to find leaf abstractParams := p.GetParameterizedParams() var leafParams []string for _, abstractParam := range abstractParams { - leafParams = append(leafParams, getLeafAbstractParams(abstractParam)...) 
+ leafParams = append(leafParams, getLeafParameterizedParams(abstractParam)...) } - // all leaf params for this parameters are found // if map contains any of the leaf params, parameters are synced for _, leafParam := range leafParams { if _, ok := existingParamMap[leafParam]; ok { @@ -412,23 +409,23 @@ func HasSyncParams(params []types.FuncDefArgType) bool { return false } -// from a parameter of abstract type, get the leaf parameters -// an abstract parameter can be a leaf type or a parameterized type itself +// from a parameterized type, get the leaf parameters +// an parameterized param can be a leaf type (e.g. P) or a parameterized type (e.g. VARCHAR) itself // if it is a leaf type, its param name is returned // if it is parameterized type, leaf type is found recursively -func getLeafAbstractParams(abstractTypes interface{}) []string { - if leaf, ok := abstractTypes.(leaf_parameters.LeafParameter); ok { +func getLeafParameterizedParams(abstractTypes interface{}) []string { + if leaf, ok := abstractTypes.(integer_parameters.IntegerParameter); ok { return []string{leaf.String()} } // if it is not a leaf type recurse if pat, ok := abstractTypes.(types.FuncDefArgType); ok { var outLeafParams []string for _, p := range pat.GetParameterizedParams() { - childLeafParams := getLeafAbstractParams(p) + childLeafParams := getLeafParameterizedParams(p) outLeafParams = append(outLeafParams, childLeafParams...) 
} return outLeafParams } - // for leaf type, return the param name + // invalid type panic("invalid non-leaf, non-parameterized type param") } diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 88b4a23..79280c9 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/extensions" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" "github.com/substrait-io/substrait-go/types/parser" ) @@ -70,9 +70,9 @@ func TestEvaluateTypeExpression(t *testing.T) { func TestHasSyncParams(t *testing.T) { - apt_P := leaf_parameters.NewVariableIntParam("P") - apt_Q := leaf_parameters.NewVariableIntParam("Q") - cpt_38 := leaf_parameters.NewConcreteIntParam(38) + apt_P := integer_parameters.NewVariableIntParam("P") + apt_Q := integer_parameters.NewVariableIntParam("Q") + cpt_38 := integer_parameters.NewConcreteIntParam(38) fct_P := &types.ParameterizedFixedCharType{IntegerOption: apt_P} fct_Q := &types.ParameterizedFixedCharType{IntegerOption: apt_Q} diff --git a/types/leaf_parameters/concrete_int_param.go b/types/integer_parameters/concrete_int_param.go similarity index 76% rename from types/leaf_parameters/concrete_int_param.go rename to types/integer_parameters/concrete_int_param.go index 30f1e32..dd3765a 100644 --- a/types/leaf_parameters/concrete_int_param.go +++ b/types/integer_parameters/concrete_int_param.go @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -package leaf_parameters +package integer_parameters import "fmt" @@ -9,12 +9,12 @@ import "fmt" // DECIMAL --> 0 Is an ConcreteIntParam but P not type ConcreteIntParam int32 -func NewConcreteIntParam(v int32) LeafParameter { +func NewConcreteIntParam(v int32) IntegerParameter { m := ConcreteIntParam(v) return &m } -func (m *ConcreteIntParam) IsCompatible(o 
LeafParameter) bool { +func (m *ConcreteIntParam) IsCompatible(o IntegerParameter) bool { if t, ok := o.(*ConcreteIntParam); ok { return t == m } diff --git a/types/leaf_parameters/leaf_parameter_type.go b/types/integer_parameters/integer_parameter_type.go similarity index 65% rename from types/leaf_parameters/leaf_parameter_type.go rename to types/integer_parameters/integer_parameter_type.go index ba2692b..8c562a1 100644 --- a/types/leaf_parameters/leaf_parameter_type.go +++ b/types/integer_parameters/integer_parameter_type.go @@ -1,15 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 -package leaf_parameters +package integer_parameters import "fmt" -// LeafParameter represents a parameter type +// IntegerParameter represents a parameter type // parameter can of concrete (38) or abstract type (P) // or another parameterized type like VARCHAR<"L1"> -type LeafParameter interface { +type IntegerParameter interface { // IsCompatible is type compatible with other // compatible is other can be used in place of this type - IsCompatible(other LeafParameter) bool + IsCompatible(other IntegerParameter) bool fmt.Stringer } diff --git a/types/leaf_parameters/leaf_parameter_type_test.go b/types/integer_parameters/integer_parameter_type_test.go similarity index 63% rename from types/leaf_parameters/leaf_parameter_type_test.go rename to types/integer_parameters/integer_parameter_type_test.go index 5a7a61f..0a47de2 100644 --- a/types/leaf_parameters/leaf_parameter_type_test.go +++ b/types/integer_parameters/integer_parameter_type_test.go @@ -1,26 +1,26 @@ // SPDX-License-Identifier: Apache-2.0 -package leaf_parameters_test +package integer_parameters_test import ( "testing" "github.com/stretchr/testify/require" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) func TestConcreteParameterType(t *testing.T) { - concreteType1 := leaf_parameters.ConcreteIntParam(1) + concreteType1 := 
integer_parameters.ConcreteIntParam(1) require.Equal(t, "1", concreteType1.String()) } func TestLeafParameterType(t *testing.T) { - var concreteType1, concreteType2, abstractType1 leaf_parameters.LeafParameter + var concreteType1, concreteType2, abstractType1 integer_parameters.IntegerParameter - concreteType1 = leaf_parameters.NewConcreteIntParam(1) - concreteType2 = leaf_parameters.NewConcreteIntParam(2) + concreteType1 = integer_parameters.NewConcreteIntParam(1) + concreteType2 = integer_parameters.NewConcreteIntParam(2) - abstractType1 = leaf_parameters.NewVariableIntParam("P") + abstractType1 = integer_parameters.NewVariableIntParam("P") // verify string val require.Equal(t, "1", concreteType1.String()) diff --git a/types/leaf_parameters/variable_int_param.go b/types/integer_parameters/variable_int_param.go similarity index 78% rename from types/leaf_parameters/variable_int_param.go rename to types/integer_parameters/variable_int_param.go index abc3612..76fd5c2 100644 --- a/types/leaf_parameters/variable_int_param.go +++ b/types/integer_parameters/variable_int_param.go @@ -1,18 +1,18 @@ // SPDX-License-Identifier: Apache-2.0 -package leaf_parameters +package integer_parameters // VariableIntParam represents an integer parameter for a parameterized type // Example: VARCHAR(L1) -> L1 is an VariableIntParam // DECIMAL --> P Is an VariableIntParam type VariableIntParam string -func NewVariableIntParam(s string) LeafParameter { +func NewVariableIntParam(s string) IntegerParameter { m := VariableIntParam(s) return &m } -func (m *VariableIntParam) IsCompatible(o LeafParameter) bool { +func (m *VariableIntParam) IsCompatible(o IntegerParameter) bool { switch o.(type) { case *VariableIntParam, *ConcreteIntParam: return true diff --git a/types/parameterized_decimal_type.go b/types/parameterized_decimal_type.go index 0e31c82..2e0fb37 100644 --- a/types/parameterized_decimal_type.go +++ b/types/parameterized_decimal_type.go @@ -5,7 +5,7 @@ package types import ( "fmt" - 
"github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) // ParameterizedDecimalType is a decimal type which to hold function arguments @@ -13,8 +13,8 @@ import ( type ParameterizedDecimalType struct { Nullability Nullability TypeVariationRef uint32 - Precision leaf_parameters.LeafParameter - Scale leaf_parameters.LeafParameter + Precision integer_parameters.IntegerParameter + Scale integer_parameters.IntegerParameter } func (m *ParameterizedDecimalType) SetNullability(n Nullability) FuncDefArgType { @@ -29,8 +29,8 @@ func (m *ParameterizedDecimalType) String() string { } func (m *ParameterizedDecimalType) HasParameterizedParam() bool { - _, ok1 := m.Precision.(*leaf_parameters.VariableIntParam) - _, ok2 := m.Scale.(*leaf_parameters.VariableIntParam) + _, ok1 := m.Precision.(*integer_parameters.VariableIntParam) + _, ok2 := m.Scale.(*integer_parameters.VariableIntParam) return ok1 || ok2 } @@ -39,10 +39,10 @@ func (m *ParameterizedDecimalType) GetParameterizedParams() []interface{} { return nil } var params []interface{} - if p, ok := m.Precision.(*leaf_parameters.VariableIntParam); ok { + if p, ok := m.Precision.(*integer_parameters.VariableIntParam); ok { params = append(params, p) } - if p, ok := m.Scale.(*leaf_parameters.VariableIntParam); ok { + if p, ok := m.Scale.(*integer_parameters.VariableIntParam); ok { params = append(params, p) } return params diff --git a/types/parameterized_decimal_type_test.go b/types/parameterized_decimal_type_test.go index 63afc09..5729dee 100644 --- a/types/parameterized_decimal_type_test.go +++ b/types/parameterized_decimal_type_test.go @@ -7,18 +7,18 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) func TestParameterizedDecimalType(t *testing.T) { - precision_P := 
leaf_parameters.NewVariableIntParam("P") - scale_S := leaf_parameters.NewVariableIntParam("S") - precision_38 := leaf_parameters.NewConcreteIntParam(38) - scale_5 := leaf_parameters.NewConcreteIntParam(5) + precision_P := integer_parameters.NewVariableIntParam("P") + scale_S := integer_parameters.NewVariableIntParam("S") + precision_38 := integer_parameters.NewConcreteIntParam(38) + scale_5 := integer_parameters.NewConcreteIntParam(5) for _, td := range []struct { name string - precision leaf_parameters.LeafParameter - scale leaf_parameters.LeafParameter + precision integer_parameters.IntegerParameter + scale integer_parameters.IntegerParameter expectedNullableString string expectedNullableRequiredString string expectedHasParameterizedParam bool diff --git a/types/parameterized_list_type_test.go b/types/parameterized_list_type_test.go index 82166c2..da54387 100644 --- a/types/parameterized_list_type_test.go +++ b/types/parameterized_list_type_test.go @@ -7,13 +7,13 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) func TestParameterizedListType(t *testing.T) { decimalType := &types.ParameterizedDecimalType{ - Precision: leaf_parameters.NewVariableIntParam("P"), - Scale: leaf_parameters.NewVariableIntParam("S"), + Precision: integer_parameters.NewVariableIntParam("P"), + Scale: integer_parameters.NewVariableIntParam("S"), Nullability: types.NullabilityRequired, } int8Type := &types.Int8Type{} diff --git a/types/parameterized_map_type_test.go b/types/parameterized_map_type_test.go index 42cb04e..62378ba 100644 --- a/types/parameterized_map_type_test.go +++ b/types/parameterized_map_type_test.go @@ -7,13 +7,13 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + 
"github.com/substrait-io/substrait-go/types/integer_parameters" ) func TestParameterizedMapType(t *testing.T) { decimalType := &types.ParameterizedDecimalType{ - Precision: leaf_parameters.NewVariableIntParam("P"), - Scale: leaf_parameters.NewVariableIntParam("S"), + Precision: integer_parameters.NewVariableIntParam("P"), + Scale: integer_parameters.NewVariableIntParam("S"), Nullability: types.NullabilityRequired, } int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} diff --git a/types/parameterized_single_integer_param_type.go b/types/parameterized_single_integer_param_type.go index d5c215b..4582aea 100644 --- a/types/parameterized_single_integer_param_type.go +++ b/types/parameterized_single_integer_param_type.go @@ -4,15 +4,20 @@ package types import ( "fmt" + "reflect" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) +type singleIntegerParamType interface { + BaseString() string +} + // parameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter -type parameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct { +type parameterizedTypeSingleIntegerParam[T singleIntegerParamType] struct { Nullability Nullability TypeVariationRef uint32 - IntegerOption leaf_parameters.LeafParameter + IntegerOption integer_parameters.IntegerParameter } func (m *parameterizedTypeSingleIntegerParam[T]) SetNullability(n Nullability) FuncDefArgType { @@ -29,29 +34,17 @@ func (m *parameterizedTypeSingleIntegerParam[T]) parameterString() string { } func (m *parameterizedTypeSingleIntegerParam[T]) baseString() string { - switch any(m).(type) { - case *ParameterizedVarCharType: - t := VarCharType{} - return t.BaseString() - case *ParameterizedFixedCharType: - t := FixedCharType{} - return t.BaseString() - case *ParameterizedFixedBinaryType: - t := 
FixedBinaryType{} - return t.BaseString() - case *ParameterizedPrecisionTimestampType: - t := PrecisionTimestampType{} - return t.BaseString() - case *ParameterizedPrecisionTimestampTzType: - t := PrecisionTimestampTzType{} - return t.BaseString() - default: - panic("unknown type") + var t T + tType := reflect.TypeOf(t) + if tType.Kind() == reflect.Ptr { + tType = tType.Elem() } + newInstance := reflect.New(tType).Interface().(T) + return newInstance.BaseString() } func (m *parameterizedTypeSingleIntegerParam[T]) HasParameterizedParam() bool { - _, ok1 := m.IntegerOption.(*leaf_parameters.VariableIntParam) + _, ok1 := m.IntegerOption.(*integer_parameters.VariableIntParam) return ok1 } diff --git a/types/parameterized_single_integer_param_type_test.go b/types/parameterized_single_integer_param_type_test.go index 9f54d3f..1c98491 100644 --- a/types/parameterized_single_integer_param_type_test.go +++ b/types/parameterized_single_integer_param_type_test.go @@ -7,12 +7,12 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) func TestParameterizedSingleIntegerType(t *testing.T) { - abstractLeafParam_L1 := leaf_parameters.NewVariableIntParam("L1") - concreteLeafParam_38 := leaf_parameters.NewConcreteIntParam(38) + abstractLeafParam_L1 := integer_parameters.NewVariableIntParam("L1") + concreteLeafParam_38 := integer_parameters.NewConcreteIntParam(38) for _, td := range []struct { name string typ types.FuncDefArgType diff --git a/types/parameterized_struct_type_test.go b/types/parameterized_struct_type_test.go index ac88c6b..68c3bf4 100644 --- a/types/parameterized_struct_type_test.go +++ b/types/parameterized_struct_type_test.go @@ -7,13 +7,13 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + 
"github.com/substrait-io/substrait-go/types/integer_parameters" ) func TestParameterizedStructType(t *testing.T) { decimalType := &types.ParameterizedDecimalType{ - Precision: leaf_parameters.NewVariableIntParam("P"), - Scale: leaf_parameters.NewVariableIntParam("S"), + Precision: integer_parameters.NewVariableIntParam("P"), + Scale: integer_parameters.NewVariableIntParam("S"), Nullability: types.NullabilityRequired, } int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index b32f9c4..4e05aa0 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -12,7 +12,7 @@ import ( "github.com/alecthomas/participle/v2/lexer" substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) var defaultParser *Parser @@ -268,12 +268,12 @@ func (p *lengthType) RetType() (types.Type, error) { func (p *lengthType) ArgType() (types.FuncDefArgType, error) { var n types.Nullability - var leafParam leaf_parameters.LeafParameter + var leafParam integer_parameters.IntegerParameter switch t := p.NumericParam.Expr.(type) { case *IntegerLiteral: - leafParam = leaf_parameters.NewConcreteIntParam(t.Value) + leafParam = integer_parameters.NewConcreteIntParam(t.Value) case *ParamName: - leafParam = leaf_parameters.NewVariableIntParam(t.Name) + leafParam = integer_parameters.NewVariableIntParam(t.Name) default: return nil, substraitgo.ErrNotImplemented } @@ -284,7 +284,7 @@ func (p *lengthType) ArgType() (types.FuncDefArgType, error) { return typ, nil } -func getParameterizedTypeSingleParam(typeName string, leafParam leaf_parameters.LeafParameter, n types.Nullability) (types.FuncDefArgType, error) { +func getParameterizedTypeSingleParam(typeName string, leafParam integer_parameters.IntegerParameter, n types.Nullability) 
(types.FuncDefArgType, error) { switch types.TypeName(typeName) { case types.TypeNameVarChar: return &types.ParameterizedVarCharType{IntegerOption: leafParam, Nullability: n}, nil @@ -326,20 +326,20 @@ func (d *decimalType) ArgType() (types.FuncDefArgType, error) { } else { n = types.NullabilityRequired } - var precision leaf_parameters.LeafParameter + var precision integer_parameters.IntegerParameter if pi, ok := d.Precision.Expr.(*IntegerLiteral); ok { - precision = leaf_parameters.NewConcreteIntParam(pi.Value) + precision = integer_parameters.NewConcreteIntParam(pi.Value) } else { ps := d.Precision.Expr.(*ParamName) - precision = leaf_parameters.NewVariableIntParam(ps.String()) + precision = integer_parameters.NewVariableIntParam(ps.String()) } - var scale leaf_parameters.LeafParameter + var scale integer_parameters.IntegerParameter if si, ok := d.Scale.Expr.(*IntegerLiteral); ok { - scale = leaf_parameters.NewConcreteIntParam(si.Value) + scale = integer_parameters.NewConcreteIntParam(si.Value) } else { ss := d.Scale.Expr.(*ParamName) - scale = leaf_parameters.NewVariableIntParam(ss.String()) + scale = integer_parameters.NewVariableIntParam(ss.String()) } return &types.ParameterizedDecimalType{ diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index ddc600b..ba09bd6 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -9,18 +9,18 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/leaf_parameters" + "github.com/substrait-io/substrait-go/types/integer_parameters" "github.com/substrait-io/substrait-go/types/parser" ) func TestParser(t *testing.T) { - parameterLeaf_L1 := leaf_parameters.NewVariableIntParam("L1") - parameterLeaf_P := leaf_parameters.NewVariableIntParam("P") - parameterLeaf_S := leaf_parameters.NewVariableIntParam("S") - concreteLeaf_5 := 
leaf_parameters.NewConcreteIntParam(5) - concreteLeaf_38 := leaf_parameters.NewConcreteIntParam(38) - concreteLeaf_10 := leaf_parameters.NewConcreteIntParam(10) - concreteLeaf_EMinus5 := leaf_parameters.NewConcreteIntParam(int32(types.PrecisionEMinus5Seconds)) + parameterLeaf_L1 := integer_parameters.NewVariableIntParam("L1") + parameterLeaf_P := integer_parameters.NewVariableIntParam("P") + parameterLeaf_S := integer_parameters.NewVariableIntParam("S") + concreteLeaf_5 := integer_parameters.NewConcreteIntParam(5) + concreteLeaf_38 := integer_parameters.NewConcreteIntParam(38) + concreteLeaf_10 := integer_parameters.NewConcreteIntParam(10) + concreteLeaf_EMinus5 := integer_parameters.NewConcreteIntParam(int32(types.PrecisionEMinus5Seconds)) tests := []struct { expr string expected string diff --git a/types/types.go b/types/types.go index c68e9ef..25f05f5 100644 --- a/types/types.go +++ b/types/types.go @@ -374,7 +374,11 @@ type ( // CompositeType this represents a concrete type having components CompositeType interface { Type + // ParameterString this returns parameter string + // for e.g. parameter decimal, ParameterString returns "P,S" ParameterString() string + // BaseString this returns long name for parameter string + // for e.g. 
parameter decimal, BaseString returns "decimal" BaseString() string } @@ -382,8 +386,13 @@ type ( // These type can't be present in plan (not serializable) FuncDefArgType interface { fmt.Stringer + //SetNullability set nullability as given argument SetNullability(Nullability) FuncDefArgType + // HasParameterizedParam returns true if the type has at least one parameterized parameters + // if all parameters are concrete then it returns false HasParameterizedParam() bool + // GetParameterizedParams returns all parameterized parameters + // it doesn't return concrete parameters GetParameterizedParams() []interface{} } @@ -656,11 +665,11 @@ type ( FixedCharType = FixedLenType[FixedChar] VarCharType = FixedLenType[VarChar] FixedBinaryType = FixedLenType[FixedBinary] - ParameterizedVarCharType = parameterizedTypeSingleIntegerParam[VarCharType] - ParameterizedFixedCharType = parameterizedTypeSingleIntegerParam[FixedCharType] - ParameterizedFixedBinaryType = parameterizedTypeSingleIntegerParam[FixedBinaryType] - ParameterizedPrecisionTimestampType = parameterizedTypeSingleIntegerParam[PrecisionTimestampType] - ParameterizedPrecisionTimestampTzType = parameterizedTypeSingleIntegerParam[PrecisionTimestampTzType] + ParameterizedVarCharType = parameterizedTypeSingleIntegerParam[*VarCharType] + ParameterizedFixedCharType = parameterizedTypeSingleIntegerParam[*FixedCharType] + ParameterizedFixedBinaryType = parameterizedTypeSingleIntegerParam[*FixedBinaryType] + ParameterizedPrecisionTimestampType = parameterizedTypeSingleIntegerParam[*PrecisionTimestampType] + ParameterizedPrecisionTimestampTzType = parameterizedTypeSingleIntegerParam[*PrecisionTimestampTzType] ) // FixedLenType is any of the types which also need to track their specific