diff --git a/expr/functions.go b/expr/functions.go index 9b16687..1132e78 100644 --- a/expr/functions.go +++ b/expr/functions.go @@ -478,11 +478,13 @@ func (w *WindowFunction) Invocation() types.AggregationInvocation { return w.inv func (w *WindowFunction) Decomposable() extensions.DecomposeType { return w.declaration.Decomposability() } -func (w *WindowFunction) Ordered() bool { return w.declaration.Ordered() } -func (w *WindowFunction) MaxSet() int { return w.declaration.MaxSet() } -func (w *WindowFunction) IntermediateType() (types.Type, error) { return w.declaration.Intermediate() } -func (w *WindowFunction) WindowType() extensions.WindowType { return w.declaration.WindowType() } -func (*WindowFunction) IsScalar() bool { return false } +func (w *WindowFunction) Ordered() bool { return w.declaration.Ordered() } +func (w *WindowFunction) MaxSet() int { return w.declaration.MaxSet() } +func (w *WindowFunction) IntermediateType() (types.FuncDefArgType, error) { + return w.declaration.Intermediate() +} +func (w *WindowFunction) WindowType() extensions.WindowType { return w.declaration.WindowType() } +func (*WindowFunction) IsScalar() bool { return false } func (*WindowFunction) isRootRef() {} @@ -773,7 +775,7 @@ func (a *AggregateFunction) Decomposable() extensions.DecomposeType { } func (a *AggregateFunction) Ordered() bool { return a.declaration.Ordered() } func (a *AggregateFunction) MaxSet() int { return a.declaration.MaxSet() } -func (a *AggregateFunction) IntermediateType() (types.Type, error) { +func (a *AggregateFunction) IntermediateType() (types.FuncDefArgType, error) { return a.declaration.Intermediate() } diff --git a/expr/string_test.go b/expr/string_test.go index 8f1e63b..3fe2acf 100644 --- a/expr/string_test.go +++ b/expr/string_test.go @@ -35,7 +35,7 @@ func TestLiteralToString(t *testing.T) { Value: expr.NewFixedCharLiteral(types.FixedChar("bar"), false), }, }, true), - }, true), "list?>>([map?>([{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}])])"}, + }, true), "list?>>([map?>([{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}])])"}, {MustLiteral(expr.NewLiteral(float32(1.5), false)), "fp32(1.5)"}, {MustLiteral(expr.NewLiteral(&types.VarChar{Value: "foobar", Length: 7}, true)), "varchar?<7>(foobar)"}, {expr.NewPrecisionTimestampLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "precisiontimestamp?<0>(123456)"}, diff --git a/extensions/simple_extension.go b/extensions/simple_extension.go index b191ea1..d9cc4ec 100644 --- a/extensions/simple_extension.go +++ b/extensions/simple_extension.go @@ -57,6 +57,7 @@ type TypeVariation struct { type Argument interface { toTypeString() string + argumentMarker() // unexported marker method } type EnumArg struct { @@ -69,6 +70,8 @@ func (EnumArg) toTypeString() string { return "req" } +func (v EnumArg) argumentMarker() {} + type ValueArg struct { Name string `yaml:",omitempty"` Description string `yaml:",omitempty"` @@ -80,6 +83,8 @@ func (v ValueArg) toTypeString() string { return v.Value.Expr.(*parser.Type).ShortType() } +func (v ValueArg) argumentMarker() {} + type TypeArg struct { Name string `yaml:",omitempty"` Description string `yaml:",omitempty"` @@ -88,6 +93,8 @@ type TypeArg struct { func (TypeArg) toTypeString() string { return "type" } +func (v TypeArg) argumentMarker() {} + type ArgumentList []Argument func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error { diff --git a/extensions/variants.go b/extensions/variants.go index 5619be8..31752fa 100644 --- a/extensions/variants.go +++ b/extensions/variants.go @@ -8,6 +8,7 @@ import ( substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" "github.com/substrait-io/substrait-go/types/parser" ) @@ -65,7 +66,7 @@ func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeEx var outType types.Type if t, ok := expr.Expr.(*parser.Type); ok { var err error - outType, err = t.Type() + outType, err = t.RetType() if err != nil { return nil, err } @@ -259,9 +260,9 @@ func (s *AggregateFunctionVariant) ID() ID { return ID{URI: s.uri, Name: s.CompoundName()} } func (s *AggregateFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable } -func (s *AggregateFunctionVariant) Intermediate() (types.Type, error) { +func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error) { if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok { - return t.Type() + return t.ArgType() } return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType) } @@ -366,12 +367,65 @@ func (s *WindowFunctionVariant) ID() ID { return ID{URI: s.uri, Name: s.CompoundName()} } func (s *WindowFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable } -func (s *WindowFunctionVariant) Intermediate() (types.Type, error) { +func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) { if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok { - return t.Type() + return t.ArgType() } return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType) } func (s *WindowFunctionVariant) Ordered() bool { return s.impl.Ordered } func (s *WindowFunctionVariant) MaxSet() int { return s.impl.MaxSet } func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowType } + +// HasSyncParams This API returns if params share a leaf param name +func HasSyncParams(params []types.FuncDefArgType) bool { + // if any of the leaf parameters are same, it indicates parameters are same across parameters + existingParamMap := make(map[string]bool) + for _, p := range params { + if !p.HasParameterizedParam() { + // not a type which contains abstract parameters, so continue + continue + } + // get list of parameterized parameters + // parameterized param can be a Leaf or another type. If another type we recurse to find leaf + abstractParams := p.GetParameterizedParams() + var leafParams []string + for _, abstractParam := range abstractParams { + leafParams = append(leafParams, getLeafParameterizedParams(abstractParam)...) + } + // if map contains any of the leaf params, parameters are synced + for _, leafParam := range leafParams { + if _, ok := existingParamMap[leafParam]; ok { + return true + } + } + // add all params to map, kindly note we can't add these params + // in previous loop to avoid having same leaf abstract type in same param + // e.g. Decimal has no sync param + for _, leafParam := range leafParams { + existingParamMap[leafParam] = true + } + } + return false +} + +// from a parameterized type, get the leaf parameters +// an parameterized param can be a leaf type (e.g. P) or a parameterized type (e.g. VARCHAR) itself +// if it is a leaf type, its param name is returned +// if it is parameterized type, leaf type is found recursively +func getLeafParameterizedParams(abstractTypes interface{}) []string { + if leaf, ok := abstractTypes.(integer_parameters.IntegerParameter); ok { + return []string{leaf.String()} + } + // if it is not a leaf type recurse + if pat, ok := abstractTypes.(types.FuncDefArgType); ok { + var outLeafParams []string + for _, p := range pat.GetParameterizedParams() { + childLeafParams := getLeafParameterizedParams(p) + outLeafParams = append(outLeafParams, childLeafParams...) + } + return outLeafParams + } + // invalid type + panic("invalid non-leaf, non-parameterized type param") +} diff --git a/extensions/variants_test.go b/extensions/variants_test.go index dee19e2..79280c9 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -6,8 +6,10 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/extensions" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" "github.com/substrait-io/substrait-go/types/parser" ) @@ -65,3 +67,42 @@ func TestEvaluateTypeExpression(t *testing.T) { }) } } + +func TestHasSyncParams(t *testing.T) { + + apt_P := integer_parameters.NewVariableIntParam("P") + apt_Q := integer_parameters.NewVariableIntParam("Q") + cpt_38 := integer_parameters.NewConcreteIntParam(38) + + fct_P := &types.ParameterizedFixedCharType{IntegerOption: apt_P} + fct_Q := &types.ParameterizedFixedCharType{IntegerOption: apt_Q} + decimal_PQ := &types.ParameterizedDecimalType{Precision: apt_P, Scale: apt_Q} + decimal_38_Q := &types.ParameterizedDecimalType{Precision: cpt_38, Scale: apt_Q} + list_decimal_38_Q := &types.ParameterizedListType{Type: decimal_38_Q} + map_fctQ_decimal38Q := &types.ParameterizedMapType{Key: fct_Q, Value: decimal_38_Q} + struct_fctQ_ListDecimal38Q := &types.ParameterizedStructType{Types: []types.FuncDefArgType{fct_Q, list_decimal_38_Q}} + for _, td := range []struct { + name string + params []types.FuncDefArgType + expectedHasSyncParams bool + }{ + {"No Abstract Type", []types.FuncDefArgType{&types.Int64Type{}}, false}, + {"No Sync Param P, Q", []types.FuncDefArgType{fct_P, fct_Q}, false}, + {"Sync Params P, P", []types.FuncDefArgType{fct_P, fct_P}, true}, + {"Sync Params P, ", []types.FuncDefArgType{fct_P, decimal_PQ}, true}, + {"No Sync Params P, <38, Q>", []types.FuncDefArgType{fct_P, decimal_38_Q}, false}, + {"Sync Params P, List>", []types.FuncDefArgType{fct_P, list_decimal_38_Q}, false}, + {"No Sync Params fct

, Map, decimal<38,Q>>", []types.FuncDefArgType{fct_P, map_fctQ_decimal38Q}, false}, + {"Sync Params fct, Map, decimal<38,Q>>", []types.FuncDefArgType{fct_Q, map_fctQ_decimal38Q}, true}, + {"No Sync Params fct

, struct, list<38,Q>>", []types.FuncDefArgType{fct_P, struct_fctQ_ListDecimal38Q}, false}, + {"Sync Params fct, struct, list<38,Q>>", []types.FuncDefArgType{fct_Q, struct_fctQ_ListDecimal38Q}, true}, + } { + t.Run(td.name, func(t *testing.T) { + if td.expectedHasSyncParams { + require.True(t, extensions.HasSyncParams(td.params)) + } else { + require.False(t, extensions.HasSyncParams(td.params)) + } + }) + } +} diff --git a/functions/types.go b/functions/types.go index e61e289..3de9d06 100644 --- a/functions/types.go +++ b/functions/types.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package functions import ( @@ -147,14 +149,14 @@ type typeInfo struct { func (ti *typeInfo) getLongName() string { switch ti.typ.(type) { - case types.ParameterizedType: - return ti.typ.(types.ParameterizedType).BaseString() + case types.CompositeType: + return ti.typ.(types.CompositeType).BaseString() } return ti.typ.String() } func (ti *typeInfo) getLocalTypeString(input types.Type, enclosure typeEnclosure) string { - if paramType, ok := input.(types.ParameterizedType); ok { + if paramType, ok := input.(types.CompositeType); ok { return ti.localName + enclosure.containerStart() + paramType.ParameterString() + enclosure.containerEnd() } return ti.localName diff --git a/types/any_type.go b/types/any_type.go new file mode 100644 index 0000000..f6561ae --- /dev/null +++ b/types/any_type.go @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "fmt" +) + +// AnyType to represent AnyType, this type is to indicate "any" type of argument +// This type is not used in function invocation. It is only used in function definition +type AnyType struct { + Name string + TypeVariationRef uint32 + Nullability Nullability +} + +func (t AnyType) SetNullability(n Nullability) FuncDefArgType { + t.Nullability = n + return t +} + +func (t AnyType) String() string { + return fmt.Sprintf("%s%s", t.Name, strFromNullability(t.Nullability)) +} + +func (s AnyType) HasParameterizedParam() bool { + // primitive type doesn't have abstract parameters + return false +} + +func (s AnyType) GetParameterizedParams() []interface{} { + // any type doesn't have any abstract parameters + return nil +} diff --git a/types/any_type_test.go b/types/any_type_test.go new file mode 100644 index 0000000..1e8f26a --- /dev/null +++ b/types/any_type_test.go @@ -0,0 +1,31 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" +) + +func TestAnyType(t *testing.T) { + for _, td := range []struct { + testName string + argName string + nullability types.Nullability + expectedString string + }{ + {"any", "any", types.NullabilityNullable, "any?"}, + {"anyrequired", "any", types.NullabilityRequired, "any"}, + {"anyOtherName", "any1", types.NullabilityNullable, "any1?"}, + {"T name", "T", types.NullabilityNullable, "T?"}, + } { + t.Run(td.testName, func(t *testing.T) { + arg := &types.AnyType{ + Name: td.argName, + Nullability: td.nullability, + } + anyType := arg.SetNullability(td.nullability) + require.Equal(t, td.expectedString, anyType.String()) + }) + } +} diff --git a/types/integer_parameters/concrete_int_param.go b/types/integer_parameters/concrete_int_param.go new file mode 100644 index 0000000..dd3765a --- /dev/null +++ b/types/integer_parameters/concrete_int_param.go @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 + +package integer_parameters + +import "fmt" + +// ConcreteIntParam represents a single integer concrete parameter for a concrete type +// Example: VARCHAR(6) -> 6 is an ConcreteIntParam +// DECIMAL --> 0 Is an ConcreteIntParam but P not +type ConcreteIntParam int32 + +func NewConcreteIntParam(v int32) IntegerParameter { + m := ConcreteIntParam(v) + return &m +} + +func (m *ConcreteIntParam) IsCompatible(o IntegerParameter) bool { + if t, ok := o.(*ConcreteIntParam); ok { + return t == m + } + return false +} + +func (m *ConcreteIntParam) String() string { + return fmt.Sprintf("%d", *m) +} diff --git a/types/integer_parameters/integer_parameter_type.go b/types/integer_parameters/integer_parameter_type.go new file mode 100644 index 0000000..8c562a1 --- /dev/null +++ b/types/integer_parameters/integer_parameter_type.go @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 + +package integer_parameters + +import "fmt" + +// IntegerParameter represents a parameter type +// parameter can of concrete (38) or abstract type (P) +// or another parameterized type like VARCHAR<"L1"> +type IntegerParameter interface { + // IsCompatible is type compatible with other + // compatible is other can be used in place of this type + IsCompatible(other IntegerParameter) bool + fmt.Stringer +} diff --git a/types/integer_parameters/integer_parameter_type_test.go b/types/integer_parameters/integer_parameter_type_test.go new file mode 100644 index 0000000..0a47de2 --- /dev/null +++ b/types/integer_parameters/integer_parameter_type_test.go @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 + +package integer_parameters_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +func TestConcreteParameterType(t *testing.T) { + concreteType1 := integer_parameters.ConcreteIntParam(1) + require.Equal(t, "1", concreteType1.String()) +} + +func TestLeafParameterType(t *testing.T) { + var concreteType1, concreteType2, abstractType1 integer_parameters.IntegerParameter + + concreteType1 = integer_parameters.NewConcreteIntParam(1) + concreteType2 = integer_parameters.NewConcreteIntParam(2) + + abstractType1 = integer_parameters.NewVariableIntParam("P") + + // verify string val + require.Equal(t, "1", concreteType1.String()) + require.Equal(t, "P", abstractType1.String()) + + // concrete type is only compatible with same type + require.True(t, concreteType1.IsCompatible(concreteType1)) + require.False(t, concreteType1.IsCompatible(concreteType2)) + + // abstract type is compatible with both abstract and concrete type + require.True(t, abstractType1.IsCompatible(abstractType1)) + require.True(t, abstractType1.IsCompatible(concreteType2)) +} diff --git a/types/integer_parameters/variable_int_param.go b/types/integer_parameters/variable_int_param.go new file mode 100644 index 0000000..76fd5c2 --- /dev/null +++ b/types/integer_parameters/variable_int_param.go @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 + +package integer_parameters + +// VariableIntParam represents an integer parameter for a parameterized type +// Example: VARCHAR(L1) -> L1 is an VariableIntParam +// DECIMAL --> P Is an VariableIntParam +type VariableIntParam string + +func NewVariableIntParam(s string) IntegerParameter { + m := VariableIntParam(s) + return &m +} + +func (m *VariableIntParam) IsCompatible(o IntegerParameter) bool { + switch o.(type) { + case *VariableIntParam, *ConcreteIntParam: + return true + default: + return false + } +} + +func (m *VariableIntParam) String() string { + return string(*m) +} + +func (m *VariableIntParam) GetAbstractParamName() string { + return string(*m) +} diff --git a/types/parameterized_decimal_type.go b/types/parameterized_decimal_type.go new file mode 100644 index 0000000..2e0fb37 --- /dev/null +++ b/types/parameterized_decimal_type.go @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "fmt" + + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +// ParameterizedDecimalType is a decimal type which to hold function arguments +// example: Decimal or Decimal or Decimal(10, 2) +type ParameterizedDecimalType struct { + Nullability Nullability + TypeVariationRef uint32 + Precision integer_parameters.IntegerParameter + Scale integer_parameters.IntegerParameter +} + +func (m *ParameterizedDecimalType) SetNullability(n Nullability) FuncDefArgType { + m.Nullability = n + return m +} + +func (m *ParameterizedDecimalType) String() string { + t := DecimalType{} + parameterString := fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String()) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) +} + +func (m *ParameterizedDecimalType) HasParameterizedParam() bool { + _, ok1 := m.Precision.(*integer_parameters.VariableIntParam) + _, ok2 := m.Scale.(*integer_parameters.VariableIntParam) + return ok1 || ok2 +} + +func (m *ParameterizedDecimalType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + var params []interface{} + if p, ok := m.Precision.(*integer_parameters.VariableIntParam); ok { + params = append(params, p) + } + if p, ok := m.Scale.(*integer_parameters.VariableIntParam); ok { + params = append(params, p) + } + return params +} diff --git a/types/parameterized_decimal_type_test.go b/types/parameterized_decimal_type_test.go new file mode 100644 index 0000000..5729dee --- /dev/null +++ b/types/parameterized_decimal_type_test.go @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +func TestParameterizedDecimalType(t *testing.T) { + precision_P := integer_parameters.NewVariableIntParam("P") + scale_S := integer_parameters.NewVariableIntParam("S") + precision_38 := integer_parameters.NewConcreteIntParam(38) + scale_5 := integer_parameters.NewConcreteIntParam(5) + for _, td := range []struct { + name string + precision integer_parameters.IntegerParameter + scale integer_parameters.IntegerParameter + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + expectedParameterizedParams []interface{} + }{ + {"both parameterized", precision_P, scale_S, "decimal?", "decimal", true, []interface{}{precision_P, scale_S}}, + {"precision concrete", precision_38, scale_S, "decimal?<38,S>", "decimal<38,S>", true, []interface{}{scale_S}}, + {"scale concrete", precision_P, scale_5, "decimal?", "decimal", true, []interface{}{precision_P}}, + {"both concrete", precision_38, scale_5, "decimal?<38,5>", "decimal<38,5>", false, nil}, + } { + t.Run(td.name, func(t *testing.T) { + pd := &types.ParameterizedDecimalType{Precision: td.precision, Scale: td.scale} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) + }) + } +} diff --git a/types/parameterized_list_type.go b/types/parameterized_list_type.go new file mode 100644 index 0000000..e37da33 --- /dev/null +++ b/types/parameterized_list_type.go @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "fmt" +) + +// ParameterizedListType is a list type having parameter of ParameterizedAbstractType +// basically a list of which type is another abstract parameter +// example: List. Kindly note concrete types List is not represented by this type +// Concrete type is represented by ListType +type ParameterizedListType struct { + Nullability Nullability + TypeVariationRef uint32 + Type FuncDefArgType +} + +func (m *ParameterizedListType) SetNullability(n Nullability) FuncDefArgType { + m.Nullability = n + return m +} + +func (m *ParameterizedListType) String() string { + t := ListType{} + parameterString := fmt.Sprintf("<%s>", m.Type) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) +} + +func (m *ParameterizedListType) HasParameterizedParam() bool { + return m.Type.HasParameterizedParam() +} + +func (m *ParameterizedListType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + return []interface{}{m.Type} +} diff --git a/types/parameterized_list_type_test.go b/types/parameterized_list_type_test.go new file mode 100644 index 0000000..da54387 --- /dev/null +++ b/types/parameterized_list_type_test.go @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +func TestParameterizedListType(t *testing.T) { + decimalType := &types.ParameterizedDecimalType{ + Precision: integer_parameters.NewVariableIntParam("P"), + Scale: integer_parameters.NewVariableIntParam("S"), + Nullability: types.NullabilityRequired, + } + int8Type := &types.Int8Type{} + for _, td := range []struct { + name string + param types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + expectedParameterizedParams []interface{} + }{ + {"parameterized param", decimalType, "list?>", "list>", true, []interface{}{decimalType}}, + {"concrete param", int8Type, "list?", "list", false, nil}, + } { + t.Run(td.name, func(t *testing.T) { + pd := &types.ParameterizedListType{Type: td.param} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) + }) + } +} diff --git a/types/parameterized_map_type.go b/types/parameterized_map_type.go new file mode 100644 index 0000000..1c79c14 --- /dev/null +++ b/types/parameterized_map_type.go @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "fmt" +) + +// ParameterizedMapType is a struct having at least one of key or value of type ParameterizedAbstractType +// If All arguments are concrete they are represented by MapType +type ParameterizedMapType struct { + Nullability Nullability + TypeVariationRef uint32 + Key FuncDefArgType + Value FuncDefArgType +} + +func (m *ParameterizedMapType) SetNullability(n Nullability) FuncDefArgType { + m.Nullability = n + return m +} + +func (m *ParameterizedMapType) String() string { + t := MapType{} + parameterString := fmt.Sprintf("<%s, %s>", m.Key.String(), m.Value.String()) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) +} + +func (m *ParameterizedMapType) HasParameterizedParam() bool { + return m.Key.HasParameterizedParam() || m.Value.HasParameterizedParam() +} + +func (m *ParameterizedMapType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + var abstractParams []interface{} + if m.Key.HasParameterizedParam() { + abstractParams = append(abstractParams, m.Key) + } + if m.Value.HasParameterizedParam() { + abstractParams = append(abstractParams, m.Value) + } + return abstractParams +} diff --git a/types/parameterized_map_type_test.go b/types/parameterized_map_type_test.go new file mode 100644 index 0000000..62378ba --- /dev/null +++ b/types/parameterized_map_type_test.go @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +func TestParameterizedMapType(t *testing.T) { + decimalType := &types.ParameterizedDecimalType{ + Precision: integer_parameters.NewVariableIntParam("P"), + Scale: integer_parameters.NewVariableIntParam("S"), + Nullability: types.NullabilityRequired, + } + int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} + listType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} + for _, td := range []struct { + name string + Key types.FuncDefArgType + Value types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + expectedParameterizedParams []interface{} + }{ + {"parameterized kv", decimalType, listType, "map?, list?>>", "map, list?>>", true, []interface{}{decimalType, listType}}, + {"concrete key", int8Type, listType, "map?>>", "map>>", true, []interface{}{listType}}, + {"concrete value", decimalType, int8Type, "map?, i8?>", "map, i8?>", true, []interface{}{decimalType}}, + {"no parameterized param", int8Type, int8Type, "map?", "map", false, nil}, + } { + t.Run(td.name, func(t *testing.T) { + pd := &types.ParameterizedMapType{Key: td.Key, Value: td.Value} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) + }) + } +} diff --git a/types/parameterized_single_integer_param_type.go b/types/parameterized_single_integer_param_type.go new file mode 100644 index 0000000..4582aea --- /dev/null +++ b/types/parameterized_single_integer_param_type.go @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "fmt" + "reflect" + + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +type singleIntegerParamType interface { + BaseString() string +} + +// parameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter +type parameterizedTypeSingleIntegerParam[T singleIntegerParamType] struct { + Nullability Nullability + TypeVariationRef uint32 + IntegerOption integer_parameters.IntegerParameter +} + +func (m *parameterizedTypeSingleIntegerParam[T]) SetNullability(n Nullability) FuncDefArgType { + m.Nullability = n + return m +} + +func (m *parameterizedTypeSingleIntegerParam[T]) String() string { + return fmt.Sprintf("%s%s%s", m.baseString(), strFromNullability(m.Nullability), m.parameterString()) +} + +func (m *parameterizedTypeSingleIntegerParam[T]) parameterString() string { + return fmt.Sprintf("<%s>", m.IntegerOption.String()) +} + +func (m *parameterizedTypeSingleIntegerParam[T]) baseString() string { + var t T + tType := reflect.TypeOf(t) + if tType.Kind() == reflect.Ptr { + tType = tType.Elem() + } + newInstance := reflect.New(tType).Interface().(T) + return newInstance.BaseString() +} + +func (m *parameterizedTypeSingleIntegerParam[T]) HasParameterizedParam() bool { + _, ok1 := m.IntegerOption.(*integer_parameters.VariableIntParam) + return ok1 +} + +func (m *parameterizedTypeSingleIntegerParam[T]) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + return []interface{}{m.IntegerOption} +} diff --git a/types/parameterized_single_integer_param_type_test.go b/types/parameterized_single_integer_param_type_test.go new file mode 100644 index 0000000..1c98491 --- /dev/null +++ b/types/parameterized_single_integer_param_type_test.go @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +func TestParameterizedSingleIntegerType(t *testing.T) { + abstractLeafParam_L1 := integer_parameters.NewVariableIntParam("L1") + concreteLeafParam_38 := integer_parameters.NewConcreteIntParam(38) + for _, td := range []struct { + name string + typ types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedIsParameterized bool + expectedAbstractParams []interface{} + }{ + {"nullable parameterized varchar", &types.ParameterizedVarCharType{IntegerOption: abstractLeafParam_L1}, "varchar?", "varchar", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete varchar", &types.ParameterizedVarCharType{IntegerOption: concreteLeafParam_38}, "varchar?<38>", "varchar<38>", false, nil}, + {"nullable fixChar", &types.ParameterizedFixedCharType{IntegerOption: abstractLeafParam_L1}, "char?", "char", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete fixChar", &types.ParameterizedFixedCharType{IntegerOption: concreteLeafParam_38}, "char?<38>", "char<38>", false, nil}, + {"nullable fixBinary", &types.ParameterizedFixedBinaryType{IntegerOption: abstractLeafParam_L1}, "fixedbinary?", "fixedbinary", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete fixBinary", &types.ParameterizedFixedBinaryType{IntegerOption: concreteLeafParam_38}, "fixedbinary?<38>", "fixedbinary<38>", false, nil}, + {"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{IntegerOption: abstractLeafParam_L1}, "precision_timestamp?", "precision_timestamp", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{IntegerOption: concreteLeafParam_38}, "precision_timestamp?<38>", "precision_timestamp<38>", false, nil}, + {"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: abstractLeafParam_L1}, "precision_timestamp_tz?", "precision_timestamp_tz", true, []interface{}{abstractLeafParam_L1}}, + {"nullable concrete precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: concreteLeafParam_38}, "precision_timestamp_tz?<38>", "precision_timestamp_tz<38>", false, nil}, + } { + t.Run(td.name, func(t *testing.T) { + require.Equal(t, td.expectedNullableString, td.typ.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, td.typ.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedIsParameterized, td.typ.HasParameterizedParam()) + require.Equal(t, td.expectedAbstractParams, td.typ.GetParameterizedParams()) + }) + } +} diff --git a/types/parameterized_struct_type.go b/types/parameterized_struct_type.go new file mode 100644 index 0000000..30688cb --- /dev/null +++ b/types/parameterized_struct_type.go @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "fmt" + "strings" +) + +// ParameterizedStructType is a parameter type struct +// example: Struct or Struct. +type ParameterizedStructType struct { + Nullability Nullability + TypeVariationRef uint32 + Types []FuncDefArgType +} + +func (m *ParameterizedStructType) SetNullability(n Nullability) FuncDefArgType { + m.Nullability = n + return m +} + +func (m *ParameterizedStructType) String() string { + sb := strings.Builder{} + for i, typ := range m.Types { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(typ.String()) + } + t := StructType{} + parameterString := fmt.Sprintf("<%s>", sb.String()) + return fmt.Sprintf("%s%s%s", t.BaseString(), strFromNullability(m.Nullability), parameterString) +} + +func (m *ParameterizedStructType) HasParameterizedParam() bool { + for _, typ := range m.Types { + if typ.HasParameterizedParam() { + return true + } + } + return false +} + +func (m *ParameterizedStructType) GetParameterizedParams() []interface{} { + if !m.HasParameterizedParam() { + return nil + } + var abstractParams []interface{} + for _, typ := range m.Types { + if typ.HasParameterizedParam() { + abstractParams = append(abstractParams, typ) + } + } + return abstractParams +} diff --git a/types/parameterized_struct_type_test.go b/types/parameterized_struct_type_test.go new file mode 100644 index 0000000..68c3bf4 --- /dev/null +++ b/types/parameterized_struct_type_test.go @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 + +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" +) + +func TestParameterizedStructType(t *testing.T) { + decimalType := &types.ParameterizedDecimalType{ + Precision: integer_parameters.NewVariableIntParam("P"), + Scale: integer_parameters.NewVariableIntParam("S"), + Nullability: types.NullabilityRequired, + } + int8Type := &types.Int8Type{Nullability: types.NullabilityNullable} + listType := &types.ParameterizedListType{Type: decimalType, Nullability: types.NullabilityNullable} + for _, td := range []struct { + name string + params []types.FuncDefArgType + expectedNullableString string + expectedNullableRequiredString string + expectedHasParameterizedParam bool + expectedParameterizedParams []interface{} + }{ + {"all parameterized param", []types.FuncDefArgType{decimalType, listType}, "struct?, list?>>", "struct, list?>>", true, []interface{}{decimalType, listType}}, + {"mix parameterized concrete param", []types.FuncDefArgType{decimalType, int8Type, listType}, "struct?, i8?, list?>>", "struct, i8?, list?>>", true, []interface{}{decimalType, listType}}, + {"all concrete param", []types.FuncDefArgType{int8Type, int8Type, int8Type}, "struct?", "struct", false, nil}, + } { + t.Run(td.name, func(t *testing.T) { + pd := &types.ParameterizedStructType{Types: td.params} + require.Equal(t, td.expectedNullableString, pd.SetNullability(types.NullabilityNullable).String()) + require.Equal(t, td.expectedNullableRequiredString, pd.SetNullability(types.NullabilityRequired).String()) + require.Equal(t, td.expectedHasParameterizedParam, pd.HasParameterizedParam()) + require.Equal(t, td.expectedParameterizedParams, pd.GetParameterizedParams()) + }) + } +} diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index ae5e4a2..4e05aa0 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -3,6 +3,7 @@ package parser import ( + "errors" "io" "strconv" "strings" @@ -11,6 +12,7 @@ import ( "github.com/alecthomas/participle/v2/lexer" substraitgo "github.com/substrait-io/substrait-go" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" ) var defaultParser *Parser @@ -21,6 +23,14 @@ type TypeExpression struct { func (t TypeExpression) String() string { return t.Expr.String() } +func (t TypeExpression) Type() (types.FuncDefArgType, error) { + typeDef, ok := t.Expr.(Def) + if !ok { + return nil, errors.New("type expression doesn't represent type") + } + return typeDef.ArgType() +} + func (t TypeExpression) MarshalYAML() (interface{}, error) { return t.Expr.String(), nil } @@ -85,15 +95,23 @@ func (t *Type) String() string { return t.TypeDef.String() } -func (t *Type) Type() (types.Type, error) { - return t.TypeDef.Type() +func (t *Type) ArgType() (types.FuncDefArgType, error) { + return t.TypeDef.ArgType() +} + +func (t *Type) RetType() (types.Type, error) { + return t.TypeDef.RetType() } type Def interface { String() string ShortType() string - Type() (types.Type, error) + // ArgType indicates argument type + ArgType() (types.FuncDefArgType, error) Optional() bool + // TODO RetType indicates return type + // This should be replaced with TypeDerivation method. Currently it just returns concrete type + RetType() (types.Type, error) } type typename string @@ -104,7 +122,7 @@ func (t *typename) Capture(values []string) error { } type nonParamType struct { - TypeName typename `parser:"@(AnyType | Template | IntType | Boolean | FPType | Temporal | BinaryType)"` + TypeName typename `parser:"@(IntType | Boolean | FPType | Temporal | BinaryType)"` Nullability bool `parser:"@'?'?"` // Variation int `parser:"'[' @\d+ ']'?"` } @@ -120,14 +138,10 @@ func (t *nonParamType) String() string { } func (t *nonParamType) ShortType() string { - if strings.HasPrefix(string(t.TypeName), "any") { - return "any" - } - return types.GetShortTypeName(types.TypeName(t.TypeName)) } -func (t *nonParamType) Type() (types.Type, error) { +func (t *nonParamType) RetType() (types.Type, error) { var n types.Nullability if t.Nullability { n = types.NullabilityNullable @@ -141,6 +155,21 @@ func (t *nonParamType) Type() (types.Type, error) { return nil, err } +func (t *nonParamType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if t.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + typ, err := types.SimpleTypeNameToType(types.TypeName(t.TypeName)) + if err != nil { + return nil, err + } + funcArgType := typ.(types.FuncDefArgType) + return funcArgType.SetNullability(n), nil +} + type listType struct { Nullability bool `parser:"'list' @'?'?"` ElemType TypeExpression `parser:"'<' @@ '>'"` @@ -158,7 +187,7 @@ func (l *listType) String() string { func (l *listType) Optional() bool { return false } -func (l *listType) Type() (types.Type, error) { +func (l *listType) RetType() (types.Type, error) { var n types.Nullability if l.Nullability { n = types.NullabilityNullable @@ -166,7 +195,7 @@ func (l *listType) Type() (types.Type, error) { n = types.NullabilityRequired } if t, ok := l.ElemType.Expr.(*Type); ok { - ret, err := t.Type() + ret, err := t.RetType() if err != nil { return nil, err } @@ -179,6 +208,26 @@ func (l *listType) Type() (types.Type, error) { return nil, substraitgo.ErrNotImplemented } +func (l *listType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if l.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + if t, ok := l.ElemType.Expr.(*Type); ok { + ret, err := t.ArgType() + if err != nil { + return nil, err + } + return &types.ParameterizedListType{ + Nullability: n, + Type: ret, + }, nil + } + return nil, substraitgo.ErrNotImplemented +} + type lengthType struct { TypeName string `parser:"@LengthType '<'"` NumericParam TypeExpression `parser:"@@ '>'"` @@ -188,6 +237,10 @@ func (p *lengthType) ShortType() string { switch p.TypeName { case "fixedchar", "varchar", "fixedbinary": return types.GetShortTypeName(types.TypeName(p.TypeName)) + case "precision_timestamp": + return "prets" + case "precision_timestamp_tz": + return "pretstz" } return "" } @@ -198,7 +251,7 @@ func (p *lengthType) String() string { func (p *lengthType) Optional() bool { return false } -func (p *lengthType) Type() (types.Type, error) { +func (p *lengthType) RetType() (types.Type, error) { var n types.Nullability lit, ok := p.NumericParam.Expr.(*IntegerLiteral) if !ok { @@ -212,6 +265,42 @@ func (p *lengthType) Type() (types.Type, error) { return typ.WithLength(lit.Value).WithNullability(n), nil } +func (p *lengthType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + + var leafParam integer_parameters.IntegerParameter + switch t := p.NumericParam.Expr.(type) { + case *IntegerLiteral: + leafParam = integer_parameters.NewConcreteIntParam(t.Value) + case *ParamName: + leafParam = integer_parameters.NewVariableIntParam(t.Name) + default: + return nil, substraitgo.ErrNotImplemented + } + typ, err := getParameterizedTypeSingleParam(p.TypeName, leafParam, n) + if err != nil { + return nil, err + } + return typ, nil +} + +func getParameterizedTypeSingleParam(typeName string, leafParam integer_parameters.IntegerParameter, n types.Nullability) (types.FuncDefArgType, error) { + switch types.TypeName(typeName) { + case types.TypeNameVarChar: + return &types.ParameterizedVarCharType{IntegerOption: leafParam, Nullability: n}, nil + case types.TypeNameFixedChar: + return &types.ParameterizedFixedCharType{IntegerOption: leafParam, Nullability: n}, nil + case types.TypeNameFixedBinary: + return &types.ParameterizedFixedBinaryType{IntegerOption: leafParam, Nullability: n}, nil + case types.TypeNamePrecisionTimestamp: + return &types.ParameterizedPrecisionTimestampType{IntegerOption: leafParam, Nullability: n}, nil + case types.TypeNamePrecisionTimestampTz: + return &types.ParameterizedPrecisionTimestampTzType{IntegerOption: leafParam, Nullability: n}, nil + default: + return nil, substraitgo.ErrNotImplemented + } +} + type decimalType struct { Nullability bool `parser:"'decimal' @'?'?"` Precision TypeExpression `parser:"'<' @@"` @@ -225,23 +314,56 @@ func (d *decimalType) String() string { if d.Nullability { opt = "?" } - return "decimal" + opt + "<" + d.Precision.Expr.String() + ", " + d.Scale.Expr.String() + ">" + return "decimal" + opt + "<" + d.Precision.Expr.String() + "," + d.Scale.Expr.String() + ">" } func (d *decimalType) Optional() bool { return d.Nullability } -func (d *decimalType) Type() (types.Type, error) { +func (d *decimalType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if d.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + var precision integer_parameters.IntegerParameter + if pi, ok := d.Precision.Expr.(*IntegerLiteral); ok { + precision = integer_parameters.NewConcreteIntParam(pi.Value) + } else { + ps := d.Precision.Expr.(*ParamName) + precision = integer_parameters.NewVariableIntParam(ps.String()) + } + + var scale integer_parameters.IntegerParameter + if si, ok := d.Scale.Expr.(*IntegerLiteral); ok { + scale = integer_parameters.NewConcreteIntParam(si.Value) + } else { + ss := d.Scale.Expr.(*ParamName) + scale = integer_parameters.NewVariableIntParam(ss.String()) + } + + return &types.ParameterizedDecimalType{ + Nullability: n, + Precision: precision, + Scale: scale, + }, nil +} + +func (d *decimalType) RetType() (types.Type, error) { var n types.Nullability + if d.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } p, ok := d.Precision.Expr.(*IntegerLiteral) if !ok { return nil, substraitgo.ErrNotImplemented } - s, ok := d.Scale.Expr.(*IntegerLiteral) if !ok { return nil, substraitgo.ErrNotImplemented } - return &types.DecimalType{ Nullability: n, Precision: p.Value, @@ -275,7 +397,7 @@ func (s *structType) String() string { func (t *structType) Optional() bool { return t.Nullability } -func (t *structType) Type() (types.Type, error) { +func (t *structType) RetType() (types.Type, error) { var n types.Nullability if t.Nullability { n = types.NullabilityNullable @@ -290,7 +412,7 @@ func (t *structType) Type() (types.Type, error) { return nil, substraitgo.ErrNotImplemented } - if typeList[i], err = tp.Type(); err != nil { + if typeList[i], err = tp.RetType(); err != nil { return nil, err } } @@ -300,6 +422,31 @@ func (t *structType) Type() (types.Type, error) { }, nil } +func (t *structType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if t.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + var err error + typeList := make([]types.FuncDefArgType, len(t.Types)) + for i, typ := range t.Types { + tp, ok := typ.Expr.(*Type) + if !ok { + return nil, substraitgo.ErrNotImplemented + } + + if typeList[i], err = tp.ArgType(); err != nil { + return nil, err + } + } + return &types.ParameterizedStructType{ + Nullability: n, + Types: typeList, + }, nil +} + type mapType struct { Nullability bool `parser:"'map' @'?'?"` Key TypeExpression `parser:"'<' @@"` @@ -313,12 +460,12 @@ func (m *mapType) String() string { if m.Nullability { opt = "?" } - return "map" + opt + "<" + m.Key.Expr.String() + "," + m.Value.Expr.String() + ">" + return "map" + opt + "<" + m.Key.Expr.String() + ", " + m.Value.Expr.String() + ">" } func (m *mapType) Optional() bool { return m.Nullability } -func (m *mapType) Type() (types.Type, error) { +func (m *mapType) RetType() (types.Type, error) { var n types.Nullability if m.Nullability { n = types.NullabilityNullable @@ -336,12 +483,12 @@ func (m *mapType) Type() (types.Type, error) { return nil, substraitgo.ErrNotImplemented } - key, err := k.Type() + key, err := k.RetType() if err != nil { return nil, err } - value, err := v.Type() + value, err := v.RetType() if err != nil { return nil, err } @@ -352,6 +499,82 @@ func (m *mapType) Type() (types.Type, error) { }, nil } +func (m *mapType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if m.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + + k, ok := m.Key.Expr.(*Type) + if !ok { + return nil, substraitgo.ErrNotImplemented + } + + v, ok := m.Value.Expr.(*Type) + if !ok { + return nil, substraitgo.ErrNotImplemented + } + + key, err := k.ArgType() + if err != nil { + return nil, err + } + + value, err := v.ArgType() + if err != nil { + return nil, err + } + + return &types.ParameterizedMapType{ + Key: key, + Value: value, + Nullability: n, + }, nil +} + +// parser token for any +type anyType struct { + TypeName typename `parser:"@(AnyType|Template)"` + Nullability bool `parser:"@'?'?"` +} + +func (t anyType) Optional() bool { return t.Nullability } + +func (t anyType) String() string { + opt := string(t.TypeName) + if t.Nullability { + opt += "?" + } + return opt +} + +func (t anyType) ShortType() string { + if strings.HasPrefix(string(t.TypeName), "any") { + return "any" + } + return string(t.TypeName) +} + +func (t anyType) ArgType() (types.FuncDefArgType, error) { + var n types.Nullability + if t.Nullability { + n = types.NullabilityNullable + } else { + n = types.NullabilityRequired + } + typeName := string(t.TypeName) + if strings.HasPrefix(typeName, "any") { + return types.AnyType{Name: "any", Nullability: n}, nil + } + return types.AnyType{Name: typeName, Nullability: n}, nil +} + +func (t anyType) RetType() (types.Type, error) { + panic("any type can't be in return type") +} + var ( def = lexer.MustSimple([]lexer.SimpleRule{ {Name: "whitespace", Pattern: `[ \t]+`}, @@ -362,7 +585,7 @@ var ( {Name: "FPType", Pattern: `fp(32|64)`}, {Name: "Temporal", Pattern: `timestamp(_tz)?|date|time|interval_day|interval_year`}, {Name: "BinaryType", Pattern: `string|binary|uuid`}, - {Name: "LengthType", Pattern: `fixedchar|varchar|fixedbinary`}, + {Name: "LengthType", Pattern: `fixedchar|varchar|fixedbinary|precision_timestamp_tz|precision_timestamp`}, {Name: "Int", Pattern: `[-+]?\d+`}, {Name: "ParamType", Pattern: `(?i)(struct|list|decimal|map)`}, {Name: "Identifier", Pattern: `[a-zA-Z_$][a-zA-Z_$0-9]*`}, @@ -389,7 +612,7 @@ func (p *Parser) ParseBytes(expr []byte) (*TypeExpression, error) { func New() (*Parser, error) { parser, err := participle.Build[TypeExpression]( participle.Union[Expression](&Type{}, &IntegerLiteral{}, &ParamName{}), - participle.Union[Def](&nonParamType{}, &mapType{}, &listType{}, &structType{}, &lengthType{}, &decimalType{}), + participle.Union[Def](&anyType{}, &nonParamType{}, &mapType{}, &listType{}, &structType{}, &lengthType{}, &decimalType{}), participle.CaseInsensitive("Boolean", "ParamType", "IntType", "FPType", "Temporal", "BinaryType", "LengthType"), participle.Lexer(def), participle.UseLookahead(3), diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index 0abc300..ba09bd6 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -3,47 +3,70 @@ package parser_test import ( + "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/integer_parameters" "github.com/substrait-io/substrait-go/types/parser" ) func TestParser(t *testing.T) { + parameterLeaf_L1 := integer_parameters.NewVariableIntParam("L1") + parameterLeaf_P := integer_parameters.NewVariableIntParam("P") + parameterLeaf_S := integer_parameters.NewVariableIntParam("S") + concreteLeaf_5 := integer_parameters.NewConcreteIntParam(5) + concreteLeaf_38 := integer_parameters.NewConcreteIntParam(38) + concreteLeaf_10 := integer_parameters.NewConcreteIntParam(10) + concreteLeaf_EMinus5 := integer_parameters.NewConcreteIntParam(int32(types.PrecisionEMinus5Seconds)) tests := []struct { - expr string - expected string - shortName string - typ types.Type + expr string + expected string + shortName string + expectedTyp types.FuncDefArgType }{ {"2", "2", "", nil}, {"-2", "-2", "", nil}, {"i16?", "i16?", "i16", &types.Int16Type{Nullability: types.NullabilityNullable}}, {"boolean", "boolean", "bool", &types.BooleanType{Nullability: types.NullabilityRequired}}, - {"fixedchar<5>", "fixedchar<5>", "fchar", &types.FixedCharType{Length: 5}}, - {"decimal<10,5>", "decimal<10, 5>", "dec", &types.DecimalType{Precision: 10, Scale: 5}}, - {"list>", "list>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5}, Nullability: types.NullabilityRequired}}, - {"list?>", "list?>", "list", &types.ListType{Type: &types.DecimalType{Precision: 10, Scale: 5}, Nullability: types.NullabilityNullable}}, - {"struct", "struct", "struct", &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, - {"map>", "map>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, - {"map?>", "map?>", "map", &types.MapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.StructType{Types: []types.Type{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + {"fixedchar<5>", "fixedchar<5>", "fchar", &types.ParameterizedFixedCharType{IntegerOption: concreteLeaf_5}}, + {"decimal<10,5>", "decimal<10,5>", "dec", &types.ParameterizedDecimalType{Precision: concreteLeaf_10, Scale: concreteLeaf_5, Nullability: types.NullabilityRequired}}, + {"list>", "list>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: concreteLeaf_10, Scale: concreteLeaf_5, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"list?>", "list?>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: concreteLeaf_10, Scale: concreteLeaf_5, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + {"struct", "struct", "struct", &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, + {"map>", "map>", "map", &types.ParameterizedMapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityRequired}}, + {"map?>", "map?>", "map", &types.ParameterizedMapType{Key: &types.BooleanType{Nullability: types.NullabilityNullable}, Value: &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.Int16Type{Nullability: types.NullabilityNullable}, &types.Int32Type{Nullability: types.NullabilityNullable}, &types.Int64Type{Nullability: types.NullabilityNullable}}, Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}}, + {"precision_timestamp<5>", "precision_timestamp<5>", "prets", &types.ParameterizedPrecisionTimestampType{IntegerOption: concreteLeaf_EMinus5}}, + {"precision_timestamp_tz<5>", "precision_timestamp_tz<5>", "pretstz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: concreteLeaf_EMinus5}}, + {"varchar", "varchar", "vchar", &types.ParameterizedVarCharType{IntegerOption: parameterLeaf_L1}}, + {"fixedchar", "fixedchar", "fchar", &types.ParameterizedFixedCharType{IntegerOption: parameterLeaf_L1}}, + {"fixedbinary", "fixedbinary", "fbin", &types.ParameterizedFixedBinaryType{IntegerOption: parameterLeaf_L1}}, + {"precision_timestamp", "precision_timestamp", "prets", &types.ParameterizedPrecisionTimestampType{IntegerOption: parameterLeaf_L1}}, + {"precision_timestamp_tz", "precision_timestamp_tz", "pretstz", &types.ParameterizedPrecisionTimestampTzType{IntegerOption: parameterLeaf_L1}}, + {"decimal", "decimal", "dec", &types.ParameterizedDecimalType{Precision: parameterLeaf_S, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}}, + {"decimal<38,S>", "decimal<38,S>", "dec", &types.ParameterizedDecimalType{Precision: concreteLeaf_38, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}}, + {"any", "any", "any", types.AnyType{Nullability: types.NullabilityRequired}}, + {"any1?", "any1?", "any", types.AnyType{Nullability: types.NullabilityNullable}}, + {"list>", "list>", "list", &types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameterLeaf_P, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, + {"struct>, i16>", "struct>, i16>", "struct", &types.ParameterizedStructType{Types: []types.FuncDefArgType{&types.ParameterizedListType{Type: &types.ParameterizedDecimalType{Precision: parameterLeaf_P, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}, Nullability: types.NullabilityNullable}, &types.Int16Type{Nullability: types.NullabilityRequired}}, Nullability: types.NullabilityRequired}}, + {"map, i16>", "map, i16>", "map", &types.ParameterizedMapType{Key: &types.ParameterizedDecimalType{Precision: parameterLeaf_P, Scale: parameterLeaf_S, Nullability: types.NullabilityRequired}, Value: &types.Int16Type{Nullability: types.NullabilityRequired}, Nullability: types.NullabilityRequired}}, } p, err := parser.New() require.NoError(t, err) - for _, tt := range tests { - t.Run(tt.expr, func(t *testing.T) { - d, err := p.ParseString(tt.expr) + for _, td := range tests { + t.Run(td.expr, func(t *testing.T) { + d, err := p.ParseString(td.expr) assert.NoError(t, err) - assert.Equal(t, tt.expected, d.Expr.String()) - if tt.shortName != "" { - assert.Equal(t, tt.shortName, d.Expr.(*parser.Type).ShortType()) - typ, err := d.Expr.(*parser.Type).Type() + assert.Equal(t, td.expected, d.Expr.String()) + if td.shortName != "" { + assert.Equal(t, td.shortName, d.Expr.(*parser.Type).ShortType()) + typ, err := d.Expr.(*parser.Type).ArgType() assert.NoError(t, err) - assert.True(t, tt.typ.Equals(typ)) + assert.Equal(t, reflect.TypeOf(td.expectedTyp), reflect.TypeOf(typ)) } }) } diff --git a/types/precison_timestamp_types.go b/types/precison_timestamp_types.go index e27dd0c..0512ced 100644 --- a/types/precison_timestamp_types.go +++ b/types/precison_timestamp_types.go @@ -1,7 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( "fmt" + "reflect" "github.com/substrait-io/substrait-go/proto" ) @@ -97,6 +100,14 @@ func (m *PrecisionTimestampType) String() string { m.Precision.ToProtoVal()) } +func (m *PrecisionTimestampType) ParameterString() string { + return fmt.Sprintf("%d", m.Precision.ToProtoVal()) +} + +func (m *PrecisionTimestampType) BaseString() string { + return typeNames[reflect.TypeOf(m)] +} + // PrecisionTimestampTzType this is used to represent a type of Precision timestamp with TimeZone type PrecisionTimestampTzType struct { PrecisionTimestampType @@ -145,3 +156,7 @@ func (m *PrecisionTimestampTzType) Equals(rhs Type) bool { return false } func (*PrecisionTimestampTzType) ShortString() string { return "pretstz" } + +func (m *PrecisionTimestampTzType) BaseString() string { + return typeNames[reflect.TypeOf(m)] +} diff --git a/types/precison_timestamp_types_test.go b/types/precison_timestamp_types_test.go index d606be9..687fe34 100644 --- a/types/precison_timestamp_types_test.go +++ b/types/precison_timestamp_types_test.go @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + package types import ( diff --git a/types/types.go b/types/types.go index a837860..25f05f5 100644 --- a/types/types.go +++ b/types/types.go @@ -44,10 +44,12 @@ const ( TypeNameIntervalDay TypeName = "interval_day" TypeNameUUID TypeName = "uuid" - TypeNameFixedBinary TypeName = "fixedbinary" - TypeNameFixedChar TypeName = "fixedchar" - TypeNameVarChar TypeName = "varchar" - TypeNameDecimal TypeName = "decimal" + TypeNameFixedBinary TypeName = "fixedbinary" + TypeNameFixedChar TypeName = "fixedchar" + TypeNameVarChar TypeName = "varchar" + TypeNameDecimal TypeName = "decimal" + TypeNamePrecisionTimestamp TypeName = "precision_timestamp" + TypeNamePrecisionTimestampTz TypeName = "precision_timestamp_tz" ) var simpleTypeNameMap = map[TypeName]Type{ @@ -354,7 +356,7 @@ type ( } // Type corresponds to the proto.Type message and represents - // a specific type. + // a specific type. These are types which can be present in plan (are serializable) Type interface { FuncArg isRootRef() @@ -369,14 +371,33 @@ type ( WithNullability(Nullability) Type } - ParameterizedType interface { + // CompositeType this represents a concrete type having components + CompositeType interface { Type + // ParameterString this returns parameter string + // for e.g. parameter decimal, ParameterString returns "P,S" ParameterString() string + // BaseString this returns long name for parameter string + // for e.g. parameter decimal, BaseString returns "decimal" BaseString() string } + // FuncDefArgType this represents a type used in function argument + // These type can't be present in plan (not serializable) + FuncDefArgType interface { + fmt.Stringer + //SetNullability set nullability as given argument + SetNullability(Nullability) FuncDefArgType + // HasParameterizedParam returns true if the type has at least one parameterized parameters + // if all parameters are concrete then it returns false + HasParameterizedParam() bool + // GetParameterizedParams returns all parameterized parameters + // it doesn't return concrete parameters + GetParameterizedParams() []interface{} + } + FixedType interface { - ParameterizedType + CompositeType WithLength(int32) FixedType } ) @@ -523,6 +544,8 @@ var typeNames = map[reflect.Type]string{ reflect.TypeOf(&FixedBinary{}): "fixedbinary", reflect.TypeOf(&emptyFixedChar): "char", reflect.TypeOf(&VarChar{}): "varchar", + reflect.TypeOf(&PrecisionTimestampType{}): "precision_timestamp", + reflect.TypeOf(&PrecisionTimestampTzType{}): "precision_timestamp_tz", } var shortNames = map[reflect.Type]string{ @@ -548,7 +571,11 @@ var shortNames = map[reflect.Type]string{ } func strNullable(t Type) string { - if t.GetNullability() == NullabilityNullable { + return strFromNullability(t.GetNullability()) +} + +func strFromNullability(nullability Nullability) string { + if nullability == NullabilityNullable { return "?" } return "" @@ -602,27 +629,47 @@ func (s *PrimitiveType[T]) String() string { return reflect.TypeOf(z).Elem().Name() + strNullable(s) } +func (s *PrimitiveType[T]) HasParameterizedParam() bool { + // primitive type doesn't have abstract parameters + return false +} + +func (s *PrimitiveType[T]) GetParameterizedParams() []interface{} { + // primitive type doesn't have any abstract parameters + return nil +} + +func (s *PrimitiveType[T]) SetNullability(n Nullability) FuncDefArgType { + s.Nullability = n + return s +} + // create type aliases to the generic structs type ( - BooleanType = PrimitiveType[bool] - Int8Type = PrimitiveType[int8] - Int16Type = PrimitiveType[int16] - Int32Type = PrimitiveType[int32] - Int64Type = PrimitiveType[int64] - Float32Type = PrimitiveType[float32] - Float64Type = PrimitiveType[float64] - StringType = PrimitiveType[string] - BinaryType = PrimitiveType[[]byte] - TimestampType = PrimitiveType[Timestamp] - DateType = PrimitiveType[Date] - TimeType = PrimitiveType[Time] - TimestampTzType = PrimitiveType[TimestampTz] - IntervalYearType = PrimitiveType[IntervalYearToMonth] - IntervalDayType = PrimitiveType[IntervalDayToSecond] - UUIDType = PrimitiveType[UUID] - FixedCharType = FixedLenType[FixedChar] - VarCharType = FixedLenType[VarChar] - FixedBinaryType = FixedLenType[FixedBinary] + BooleanType = PrimitiveType[bool] + Int8Type = PrimitiveType[int8] + Int16Type = PrimitiveType[int16] + Int32Type = PrimitiveType[int32] + Int64Type = PrimitiveType[int64] + Float32Type = PrimitiveType[float32] + Float64Type = PrimitiveType[float64] + StringType = PrimitiveType[string] + BinaryType = PrimitiveType[[]byte] + TimestampType = PrimitiveType[Timestamp] + DateType = PrimitiveType[Date] + TimeType = PrimitiveType[Time] + TimestampTzType = PrimitiveType[TimestampTz] + IntervalYearType = PrimitiveType[IntervalYearToMonth] + IntervalDayType = PrimitiveType[IntervalDayToSecond] + UUIDType = PrimitiveType[UUID] + FixedCharType = FixedLenType[FixedChar] + VarCharType = FixedLenType[VarChar] + FixedBinaryType = FixedLenType[FixedBinary] + ParameterizedVarCharType = parameterizedTypeSingleIntegerParam[*VarCharType] + ParameterizedFixedCharType = parameterizedTypeSingleIntegerParam[*FixedCharType] + ParameterizedFixedBinaryType = parameterizedTypeSingleIntegerParam[*FixedBinaryType] + ParameterizedPrecisionTimestampType = parameterizedTypeSingleIntegerParam[*PrecisionTimestampType] + ParameterizedPrecisionTimestampTzType = parameterizedTypeSingleIntegerParam[*PrecisionTimestampTzType] ) // FixedLenType is any of the types which also need to track their specific @@ -683,6 +730,7 @@ func (s *FixedLenType[T]) WithLength(length int32) FixedType { return &out } +// DecimalType is a decimal type with concrete precision and scale parameters, e.g. Decimal(10, 2). type DecimalType struct { Nullability Nullability TypeVariationRef uint32 @@ -808,6 +856,21 @@ func (t *StructType) String() string { return b.String() } +func (t *StructType) ParameterString() string { + sb := strings.Builder{} + for i, typ := range t.Types { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(typ.String()) + } + return sb.String() +} + +func (*StructType) BaseString() string { + return "struct" +} + type ListType struct { Nullability Nullability TypeVariationRef uint32 @@ -859,6 +922,14 @@ func (t *ListType) String() string { return "list" + strNullable(t) + "<" + t.Type.String() + ">" } +func (s *ListType) ParameterString() string { + return s.Type.String() +} + +func (*ListType) BaseString() string { + return "list" +} + type MapType struct { Nullability Nullability TypeVariationRef uint32 @@ -907,7 +978,15 @@ func (t *MapType) ToProtoFuncArg() *proto.FunctionArgument { func (t *MapType) ShortString() string { return "map" } func (t *MapType) String() string { - return "map" + strNullable(t) + "<" + t.Key.String() + "," + t.Value.String() + ">" + return "map" + strNullable(t) + "<" + t.Key.String() + ", " + t.Value.String() + ">" +} + +func (t *MapType) ParameterString() string { + return fmt.Sprintf("%s, %s", t.Key.String(), t.Value.String()) +} + +func (*MapType) BaseString() string { + return "map" } // TypeParam represents a type parameter for a user defined type diff --git a/types/types_test.go b/types/types_test.go index f705c8b..49e6e8c 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -42,7 +42,7 @@ func TestTypeToString(t *testing.T) { "struct?>", "struct"}, {&ListType{Type: &Int8Type{}}, "list", "list"}, {&MapType{Key: &StringType{}, Value: &DecimalType{Precision: 10, Scale: 2}}, - "map>", "map"}, + "map>", "map"}, } for _, tt := range tests {