Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Parameterized type #52

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions expr/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,13 @@
func (w *WindowFunction) Decomposable() extensions.DecomposeType {
return w.declaration.Decomposability()
}
func (w *WindowFunction) Ordered() bool { return w.declaration.Ordered() }
func (w *WindowFunction) MaxSet() int { return w.declaration.MaxSet() }
func (w *WindowFunction) IntermediateType() (types.Type, error) { return w.declaration.Intermediate() }
func (w *WindowFunction) WindowType() extensions.WindowType { return w.declaration.WindowType() }
func (*WindowFunction) IsScalar() bool { return false }
func (w *WindowFunction) Ordered() bool { return w.declaration.Ordered() }
func (w *WindowFunction) MaxSet() int { return w.declaration.MaxSet() }
func (w *WindowFunction) IntermediateType() (types.FuncDefArgType, error) {
return w.declaration.Intermediate()

Check warning on line 484 in expr/functions.go

View check run for this annotation

Codecov / codecov/patch

expr/functions.go#L481-L484

Added lines #L481 - L484 were not covered by tests
}
func (w *WindowFunction) WindowType() extensions.WindowType { return w.declaration.WindowType() }
func (*WindowFunction) IsScalar() bool { return false }

Check warning on line 487 in expr/functions.go

View check run for this annotation

Codecov / codecov/patch

expr/functions.go#L486-L487

Added lines #L486 - L487 were not covered by tests

func (*WindowFunction) isRootRef() {}

Expand Down Expand Up @@ -773,7 +775,7 @@
}
func (a *AggregateFunction) Ordered() bool { return a.declaration.Ordered() }
func (a *AggregateFunction) MaxSet() int { return a.declaration.MaxSet() }
func (a *AggregateFunction) IntermediateType() (types.Type, error) {
func (a *AggregateFunction) IntermediateType() (types.FuncDefArgType, error) {

Check warning on line 778 in expr/functions.go

View check run for this annotation

Codecov / codecov/patch

expr/functions.go#L778

Added line #L778 was not covered by tests
return a.declaration.Intermediate()
}

Expand Down
2 changes: 1 addition & 1 deletion expr/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestLiteralToString(t *testing.T) {
Value: expr.NewFixedCharLiteral(types.FixedChar("bar"), false),
},
}, true),
}, true), "list?<map?<string,char<3>>>([map?<string,char<3>>([{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}])])"},
}, true), "list?<map?<string, char<3>>>([map?<string, char<3>>([{string(foo) char<3>(bar)} {string(baz) char<3>(bar)}])])"},
{MustLiteral(expr.NewLiteral(float32(1.5), false)), "fp32(1.5)"},
{MustLiteral(expr.NewLiteral(&types.VarChar{Value: "foobar", Length: 7}, true)), "varchar?<7>(foobar)"},
{expr.NewPrecisionTimestampLiteral(123456, types.PrecisionSeconds, types.NullabilityNullable), "precisiontimestamp?<0>(123456)"},
Expand Down
7 changes: 7 additions & 0 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@

type Argument interface {
toTypeString() string
argumentMarker() // unexported marker method
}

type EnumArg struct {
Expand All @@ -69,6 +70,8 @@
return "req"
}

func (v EnumArg) argumentMarker() {}

Check warning on line 73 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L73

Added line #L73 was not covered by tests

type ValueArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -80,6 +83,8 @@
return v.Value.Expr.(*parser.Type).ShortType()
}

func (v ValueArg) argumentMarker() {}

Check warning on line 86 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L86

Added line #L86 was not covered by tests

type TypeArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -88,6 +93,8 @@

func (TypeArg) toTypeString() string { return "type" }

func (v TypeArg) argumentMarker() {}

Check warning on line 96 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L96

Added line #L96 was not covered by tests

type ArgumentList []Argument

func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error {
Expand Down
64 changes: 59 additions & 5 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/types"
"github.com/substrait-io/substrait-go/types/integer_parameters"
"github.com/substrait-io/substrait-go/types/parser"
)

Expand Down Expand Up @@ -65,7 +66,7 @@
var outType types.Type
if t, ok := expr.Expr.(*parser.Type); ok {
var err error
outType, err = t.Type()
outType, err = t.RetType()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -259,9 +260,9 @@
return ID{URI: s.uri, Name: s.CompoundName()}
}
func (s *AggregateFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable }
func (s *AggregateFunctionVariant) Intermediate() (types.Type, error) {
func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok {
return t.Type()
return t.ArgType()
}
return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType)
}
Expand Down Expand Up @@ -366,12 +367,65 @@
return ID{URI: s.uri, Name: s.CompoundName()}
}
func (s *WindowFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable }
func (s *WindowFunctionVariant) Intermediate() (types.Type, error) {
func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) {

Check warning on line 370 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L370

Added line #L370 was not covered by tests
if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok {
return t.Type()
return t.ArgType()

Check warning on line 372 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L372

Added line #L372 was not covered by tests
}
return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType)
}
func (s *WindowFunctionVariant) Ordered() bool { return s.impl.Ordered }
func (s *WindowFunctionVariant) MaxSet() int { return s.impl.MaxSet }
func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowType }

// HasSyncParams This API returns if params share a leaf param name
func HasSyncParams(params []types.FuncDefArgType) bool {
// if any of the leaf parameters are same, it indicates parameters are same across parameters
existingParamMap := make(map[string]bool)
for _, p := range params {
if !p.HasParameterizedParam() {
// not a type which contains abstract parameters, so continue
continue
}
// get list of parameterized parameters
// parameterized param can be a Leaf or another type. If another type we recurse to find leaf
abstractParams := p.GetParameterizedParams()
var leafParams []string
for _, abstractParam := range abstractParams {
leafParams = append(leafParams, getLeafParameterizedParams(abstractParam)...)
}
// if map contains any of the leaf params, parameters are synced
for _, leafParam := range leafParams {
if _, ok := existingParamMap[leafParam]; ok {
return true
}
}
// add all params to map, kindly note we can't add these params
// in previous loop to avoid having same leaf abstract type in same param
// e.g. Decimal<P, P> has no sync param
for _, leafParam := range leafParams {
existingParamMap[leafParam] = true
}
}
return false
}

// from a parameterized type, get the leaf parameters
// an parameterized param can be a leaf type (e.g. P) or a parameterized type (e.g. VARCHAR<L1>) itself
// if it is a leaf type, its param name is returned
// if it is parameterized type, leaf type is found recursively
func getLeafParameterizedParams(abstractTypes interface{}) []string {
if leaf, ok := abstractTypes.(integer_parameters.IntegerParameter); ok {
return []string{leaf.String()}
}
// if it is not a leaf type recurse
if pat, ok := abstractTypes.(types.FuncDefArgType); ok {
var outLeafParams []string
for _, p := range pat.GetParameterizedParams() {
childLeafParams := getLeafParameterizedParams(p)
outLeafParams = append(outLeafParams, childLeafParams...)
}
return outLeafParams
}
// invalid type
panic("invalid non-leaf, non-parameterized type param")

Check warning on line 430 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L430

Added line #L430 was not covered by tests
}
41 changes: 41 additions & 0 deletions extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/extensions"
"github.com/substrait-io/substrait-go/types"
"github.com/substrait-io/substrait-go/types/integer_parameters"
"github.com/substrait-io/substrait-go/types/parser"
)

Expand Down Expand Up @@ -65,3 +67,42 @@ func TestEvaluateTypeExpression(t *testing.T) {
})
}
}

func TestHasSyncParams(t *testing.T) {

apt_P := integer_parameters.NewVariableIntParam("P")
apt_Q := integer_parameters.NewVariableIntParam("Q")
cpt_38 := integer_parameters.NewConcreteIntParam(38)

fct_P := &types.ParameterizedFixedCharType{IntegerOption: apt_P}
fct_Q := &types.ParameterizedFixedCharType{IntegerOption: apt_Q}
decimal_PQ := &types.ParameterizedDecimalType{Precision: apt_P, Scale: apt_Q}
decimal_38_Q := &types.ParameterizedDecimalType{Precision: cpt_38, Scale: apt_Q}
list_decimal_38_Q := &types.ParameterizedListType{Type: decimal_38_Q}
map_fctQ_decimal38Q := &types.ParameterizedMapType{Key: fct_Q, Value: decimal_38_Q}
struct_fctQ_ListDecimal38Q := &types.ParameterizedStructType{Types: []types.FuncDefArgType{fct_Q, list_decimal_38_Q}}
for _, td := range []struct {
name string
params []types.FuncDefArgType
expectedHasSyncParams bool
}{
{"No Abstract Type", []types.FuncDefArgType{&types.Int64Type{}}, false},
{"No Sync Param P, Q", []types.FuncDefArgType{fct_P, fct_Q}, false},
{"Sync Params P, P", []types.FuncDefArgType{fct_P, fct_P}, true},
{"Sync Params P, <P, Q>", []types.FuncDefArgType{fct_P, decimal_PQ}, true},
{"No Sync Params P, <38, Q>", []types.FuncDefArgType{fct_P, decimal_38_Q}, false},
{"Sync Params P, List<Decimal<P, Q>>", []types.FuncDefArgType{fct_P, list_decimal_38_Q}, false},
{"No Sync Params fct<P>, Map<fct<Q>, decimal<38,Q>>", []types.FuncDefArgType{fct_P, map_fctQ_decimal38Q}, false},
{"Sync Params fct<Q>, Map<fct<Q>, decimal<38,Q>>", []types.FuncDefArgType{fct_Q, map_fctQ_decimal38Q}, true},
{"No Sync Params fct<P>, struct<fct<Q>, list<38,Q>>", []types.FuncDefArgType{fct_P, struct_fctQ_ListDecimal38Q}, false},
{"Sync Params fct<Q>, struct<fct<Q>, list<38,Q>>", []types.FuncDefArgType{fct_Q, struct_fctQ_ListDecimal38Q}, true},
} {
t.Run(td.name, func(t *testing.T) {
if td.expectedHasSyncParams {
require.True(t, extensions.HasSyncParams(td.params))
} else {
require.False(t, extensions.HasSyncParams(td.params))
}
})
}
}
8 changes: 5 additions & 3 deletions functions/types.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// SPDX-License-Identifier: Apache-2.0

package functions

import (
Expand Down Expand Up @@ -147,14 +149,14 @@ type typeInfo struct {

func (ti *typeInfo) getLongName() string {
switch ti.typ.(type) {
case types.ParameterizedType:
return ti.typ.(types.ParameterizedType).BaseString()
case types.CompositeType:
return ti.typ.(types.CompositeType).BaseString()
}
return ti.typ.String()
}

func (ti *typeInfo) getLocalTypeString(input types.Type, enclosure typeEnclosure) string {
if paramType, ok := input.(types.ParameterizedType); ok {
if paramType, ok := input.(types.CompositeType); ok {
return ti.localName + enclosure.containerStart() + paramType.ParameterString() + enclosure.containerEnd()
}
return ti.localName
Expand Down
34 changes: 34 additions & 0 deletions types/any_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-License-Identifier: Apache-2.0

package types
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

import (
"fmt"
)

// AnyType to represent AnyType, this type is to indicate "any" type of argument
// This type is not used in function invocation. It is only used in function definition
type AnyType struct {
Name string
TypeVariationRef uint32
Nullability Nullability
}

func (t AnyType) SetNullability(n Nullability) FuncDefArgType {
t.Nullability = n
return t
}

func (t AnyType) String() string {
return fmt.Sprintf("%s%s", t.Name, strFromNullability(t.Nullability))
}

func (s AnyType) HasParameterizedParam() bool {

Check warning on line 26 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L26

Added line #L26 was not covered by tests
// primitive type doesn't have abstract parameters
return false

Check warning on line 28 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L28

Added line #L28 was not covered by tests
}

func (s AnyType) GetParameterizedParams() []interface{} {

Check warning on line 31 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L31

Added line #L31 was not covered by tests
// any type doesn't have any abstract parameters
return nil

Check warning on line 33 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L33

Added line #L33 was not covered by tests
}
31 changes: 31 additions & 0 deletions types/any_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package types_test
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types"
)

func TestAnyType(t *testing.T) {
for _, td := range []struct {
testName string
argName string
nullability types.Nullability
expectedString string
}{
{"any", "any", types.NullabilityNullable, "any?"},
{"anyrequired", "any", types.NullabilityRequired, "any"},
{"anyOtherName", "any1", types.NullabilityNullable, "any1?"},
{"T name", "T", types.NullabilityNullable, "T?"},
} {
t.Run(td.testName, func(t *testing.T) {
arg := &types.AnyType{
Name: td.argName,
Nullability: td.nullability,
}
anyType := arg.SetNullability(td.nullability)
require.Equal(t, td.expectedString, anyType.String())
})
}
}
26 changes: 26 additions & 0 deletions types/integer_parameters/concrete_int_param.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: Apache-2.0

package integer_parameters

import "fmt"

// ConcreteIntParam represents a single integer concrete parameter for a concrete type
// Example: VARCHAR(6) -> 6 is an ConcreteIntParam
// DECIMAL<P, 0> --> 0 Is an ConcreteIntParam but P not
type ConcreteIntParam int32

func NewConcreteIntParam(v int32) IntegerParameter {
m := ConcreteIntParam(v)
return &m
}

func (m *ConcreteIntParam) IsCompatible(o IntegerParameter) bool {
if t, ok := o.(*ConcreteIntParam); ok {
return t == m
}
return false

Check warning on line 21 in types/integer_parameters/concrete_int_param.go

View check run for this annotation

Codecov / codecov/patch

types/integer_parameters/concrete_int_param.go#L21

Added line #L21 was not covered by tests
}

func (m *ConcreteIntParam) String() string {
return fmt.Sprintf("%d", *m)
}
15 changes: 15 additions & 0 deletions types/integer_parameters/integer_parameter_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-License-Identifier: Apache-2.0

package integer_parameters

import "fmt"

// IntegerParameter represents a parameter type
// parameter can of concrete (38) or abstract type (P)
// or another parameterized type like VARCHAR<"L1">
type IntegerParameter interface {
// IsCompatible is type compatible with other
// compatible is other can be used in place of this type
IsCompatible(other IntegerParameter) bool
fmt.Stringer
}
36 changes: 36 additions & 0 deletions types/integer_parameters/integer_parameter_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: Apache-2.0

package integer_parameters_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types/integer_parameters"
)

func TestConcreteParameterType(t *testing.T) {
concreteType1 := integer_parameters.ConcreteIntParam(1)
require.Equal(t, "1", concreteType1.String())
}

func TestLeafParameterType(t *testing.T) {
var concreteType1, concreteType2, abstractType1 integer_parameters.IntegerParameter

concreteType1 = integer_parameters.NewConcreteIntParam(1)
concreteType2 = integer_parameters.NewConcreteIntParam(2)

abstractType1 = integer_parameters.NewVariableIntParam("P")

// verify string val
require.Equal(t, "1", concreteType1.String())
require.Equal(t, "P", abstractType1.String())

// concrete type is only compatible with same type
require.True(t, concreteType1.IsCompatible(concreteType1))
require.False(t, concreteType1.IsCompatible(concreteType2))

// abstract type is compatible with both abstract and concrete type
require.True(t, abstractType1.IsCompatible(abstractType1))
require.True(t, abstractType1.IsCompatible(concreteType2))
}
Loading
Loading