Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Match and MatchAt API to FunctionVariant interface #54

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions extensions/extension_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"path"
"sort"

"github.com/creasty/defaults"
"github.com/goccy/go-yaml"
substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/proto/extensions"
Expand Down Expand Up @@ -179,14 +180,23 @@
simpleNames := make(map[string]string)

for _, f := range file.ScalarFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for scalar functions: %w", err)

Check warning on line 184 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L184

Added line #L184 was not covered by tests
}
addToMaps[*ScalarFunctionVariant](id, &f, c.scalarMap, simpleNames)
}

for _, f := range file.AggregateFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for aggregate functions: %w", err)

Check warning on line 191 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L191

Added line #L191 was not covered by tests
}
addToMaps[*AggregateFunctionVariant](id, &f, c.aggregateMap, simpleNames)
}

for _, f := range file.WindowFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for window functions: %w", err)

Check warning on line 198 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L198

Added line #L198 was not covered by tests
}
addToMaps[*WindowFunctionVariant](id, &f, c.windowMap, simpleNames)
}

Expand Down
2 changes: 2 additions & 0 deletions extensions/extension_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) {
sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, scalarFunctions, sf)
// verify that default nullability is set to MIRROR
assert.Equal(t, extensions.MirrorNullability, sf.Nullability())
}
if tt.isAggregate {
af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
Expand Down
4 changes: 2 additions & 2 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ const (
type VariadicBehavior struct {
Min int `yaml:",omitempty"`
Max int `yaml:",omitempty"`
ParameterConsistency ParameterConsistency `yaml:"parameterConsistency,omitempty"`
ParameterConsistency ParameterConsistency `yaml:"parameterConsistency,omitempty" default:"CONSISTENT"`
}

func (v *VariadicBehavior) IsValidArgumentCount(count int) bool {
Expand All @@ -205,7 +205,7 @@ type ScalarFunctionImpl struct {
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
}
Expand Down
136 changes: 136 additions & 0 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@
URI() string
ResolveType(argTypes []types.Type) (types.Type, error)
Variadic() *VariadicBehavior
// Match this function matches input arguments against this functions parameter list
// returns (true, nil) if all input argument can type replace the function definition argument
// returns (false, err) for invalid input argument. For e.g. if input argument nullability is not correctly
// set this function will return error
// returns (false, nil) valid input arguments and argument list type replace parameter list
Match(argumentTypes []types.Type) (bool, error)
// MatchAt this function matches input argument at position against definition of this
// functions argument at same position
// returns (true, nil) if input argument can type replace the function definition argument
// returns (false, err) for invalid input argument. For e.g. if input arg position is negative or
// argument nullability is not correctly set this function will return error
// returns (false, nil) valid input argument type and argument can't type replace parameter at argPos
MatchAt(typ types.Type, pos int) (bool, error)
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, actualTypes []types.Type) (types.Type, error) {
Expand Down Expand Up @@ -84,6 +97,109 @@
return outType, nil
}

func matchArguments(nullability NullabilityHandling, paramTypeList ArgumentList, variadicBehavior *VariadicBehavior, actualTypes []types.Type) (bool, error) {
if variadicBehavior == nil && len(actualTypes) != len(paramTypeList) {
return false, nil
} else if variadicBehavior != nil && !validateVariadicBehaviorForMatch(variadicBehavior, actualTypes) {
return false, nil

Check warning on line 104 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L100-L104

Added lines #L100 - L104 were not covered by tests
}
funcDefArgList, err := getFuncDefFromArgList(paramTypeList)
if err != nil {
return false, nil

Check warning on line 108 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L106-L108

Added lines #L106 - L108 were not covered by tests
}
// loop over actualTypes and not params since actualTypes can be more than params
// considering variadic type
for argPos := range actualTypes {
match, err1 := matchArgumentAtCommon(actualTypes[argPos], argPos, nullability, funcDefArgList, variadicBehavior)
if err1 != nil {
return false, err1

Check warning on line 115 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L112-L115

Added lines #L112 - L115 were not covered by tests
}
if !match {
return false, nil

Check warning on line 118 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L117-L118

Added lines #L117 - L118 were not covered by tests
}
}
return true, nil

Check warning on line 121 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L121

Added line #L121 was not covered by tests
}

func matchArgumentAt(actualType types.Type, argPos int, nullability NullabilityHandling, paramTypeList ArgumentList, variadicBehavior *VariadicBehavior) (bool, error) {
if argPos < 0 {
return false, fmt.Errorf("non-zero argument position")

Check warning on line 126 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L124-L126

Added lines #L124 - L126 were not covered by tests
}
funcDefArgList, err := getFuncDefFromArgList(paramTypeList)
if err != nil {
return false, nil

Check warning on line 130 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L128-L130

Added lines #L128 - L130 were not covered by tests
}
return matchArgumentAtCommon(actualType, argPos, nullability, funcDefArgList, variadicBehavior)

Check warning on line 132 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L132

Added line #L132 was not covered by tests
}

func matchArgumentAtCommon(actualType types.Type, argPos int, nullability NullabilityHandling, funcDefArgList []types.FuncDefArgType, variadicBehavior *VariadicBehavior) (bool, error) {

Check warning on line 135 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L135

Added line #L135 was not covered by tests
// check if argument out of range
if variadicBehavior == nil && argPos >= len(funcDefArgList) {
return false, nil
} else if variadicBehavior != nil && !variadicBehavior.IsValidArgumentCount(argPos+1) {

Check warning on line 139 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L137-L139

Added lines #L137 - L139 were not covered by tests
// this argument position can't be more than the max allowed by the variadic behavior
return false, nil

Check warning on line 141 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L141

Added line #L141 was not covered by tests
}

if HasSyncParams(funcDefArgList) {
return false, fmt.Errorf("%w: function has sync params", substraitgo.ErrNotImplemented)

Check warning on line 145 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L144-L145

Added lines #L144 - L145 were not covered by tests
}
// if argPos is >= len(funcDefArgList) than last funcDefArg type should be considered for type match
// already checked for parameter in range above (considering variadic) so no need to check again for variadic
var funcDefArg types.FuncDefArgType
if argPos < len(funcDefArgList) {
funcDefArg = funcDefArgList[argPos]
} else {
funcDefArg = funcDefArgList[len(funcDefArgList)-1]

Check warning on line 153 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L149-L153

Added lines #L149 - L153 were not covered by tests
}
switch nullability {
case DiscreteNullability:
return funcDefArg.MatchWithNullability(actualType), nil
case MirrorNullability, DeclaredOutputNullability:
return funcDefArg.MatchWithoutNullability(actualType), nil

Check warning on line 159 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L155-L159

Added lines #L155 - L159 were not covered by tests
}
// unreachable case
return false, fmt.Errorf("invalid nullability type: %s", nullability)

Check warning on line 162 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L162

Added line #L162 was not covered by tests
}

func validateVariadicBehaviorForMatch(variadicBehavior *VariadicBehavior, actualTypes []types.Type) bool {
if !variadicBehavior.IsValidArgumentCount(len(actualTypes)) {
return false

Check warning on line 167 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L165-L167

Added lines #L165 - L167 were not covered by tests
}
// verify consistency of variadic behavior
if variadicBehavior.ParameterConsistency == ConsistentParams {

Check warning on line 170 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L170

Added line #L170 was not covered by tests
// all concrete types must be equal for all variable arguments
firstVariadicArgIdx := variadicBehavior.Min - 1
for i := firstVariadicArgIdx; i < len(actualTypes)-1; i++ {
if !actualTypes[i].Equals(actualTypes[i+1]) {
return false

Check warning on line 175 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L172-L175

Added lines #L172 - L175 were not covered by tests
}
}
}
return true

Check warning on line 179 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L179

Added line #L179 was not covered by tests
}

func getFuncDefFromArgList(paramTypeList ArgumentList) ([]types.FuncDefArgType, error) {
var out []types.FuncDefArgType
for argPos, param := range paramTypeList {
switch paramType := param.(type) {
case ValueArg:
funcDefArgType, err := paramType.Value.Expr.(*parser.Type).ArgType()
if err != nil {
return nil, err

Check warning on line 189 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L182-L189

Added lines #L182 - L189 were not covered by tests
}
out = append(out, funcDefArgType)
case EnumArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
case TypeArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
default:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)

Check warning on line 197 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L191-L197

Added lines #L191 - L197 were not covered by tests
}
}
return out, nil

Check warning on line 200 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L200

Added line #L200 was not covered by tests
}

func parseFuncName(compoundName string) (name string, args ArgumentList) {
name, argsStr, _ := strings.Cut(compoundName, ":")
if len(argsStr) == 0 {
Expand Down Expand Up @@ -160,6 +276,14 @@
return ID{URI: s.uri, Name: s.CompoundName()}
}

func (s *ScalarFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
anshuldata marked this conversation as resolved.
Show resolved Hide resolved
return matchArguments(s.Nullability(), s.impl.Args, s.impl.Variadic, argumentTypes)

Check warning on line 280 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L279-L280

Added lines #L279 - L280 were not covered by tests
}

func (s *ScalarFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args, s.impl.Variadic)

Check warning on line 284 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L283-L284

Added lines #L283 - L284 were not covered by tests
}

// NewAggFuncVariant constructs a variant with the provided name and uri
// and uses the defaults for everything else.
//
Expand Down Expand Up @@ -268,6 +392,12 @@
}
func (s *AggregateFunctionVariant) Ordered() bool { return s.impl.Ordered }
func (s *AggregateFunctionVariant) MaxSet() int { return s.impl.MaxSet }
func (s *AggregateFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, s.impl.Variadic, argumentTypes)

Check warning on line 396 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L395-L396

Added lines #L395 - L396 were not covered by tests
}
func (s *AggregateFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args, s.impl.Variadic)

Check warning on line 399 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L398-L399

Added lines #L398 - L399 were not covered by tests
}

type WindowFunctionVariant struct {
name string
Expand Down Expand Up @@ -376,6 +506,12 @@
func (s *WindowFunctionVariant) Ordered() bool { return s.impl.Ordered }
func (s *WindowFunctionVariant) MaxSet() int { return s.impl.MaxSet }
func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowType }
func (s *WindowFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, s.impl.Variadic, argumentTypes)

Check warning on line 510 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L509-L510

Added lines #L509 - L510 were not covered by tests
}
func (s *WindowFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args, s.impl.Variadic)

Check warning on line 513 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L512-L513

Added lines #L512 - L513 were not covered by tests
}

// HasSyncParams This API returns if params share a leaf param name
func HasSyncParams(params []types.FuncDefArgType) bool {
Expand Down
Loading
Loading