Skip to content

Commit

Permalink
Add Match and MatchAt API to FunctionVariant interface
Browse files Browse the repository at this point in the history
* Also fixed default nullability for ScalarFunctionImpl
  • Loading branch information
anshuldata committed Sep 12, 2024
1 parent 2240ec9 commit 94687c0
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 6 deletions.
4 changes: 4 additions & 0 deletions extensions/extension_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path"
"sort"

"github.com/creasty/defaults"
"github.com/goccy/go-yaml"
substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/proto/extensions"
Expand Down Expand Up @@ -179,6 +180,9 @@ func (c *Collection) Load(uri string, r io.Reader) error {
simpleNames := make(map[string]string)

for _, f := range file.ScalarFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for scalar functions: %w", err)

Check warning on line 184 in extensions/extension_mgr.go

View check run for this annotation

Codecov / codecov/patch

extensions/extension_mgr.go#L184

Added line #L184 was not covered by tests
}
addToMaps[*ScalarFunctionVariant](id, &f, c.scalarMap, simpleNames)
}

Expand Down
2 changes: 2 additions & 0 deletions extensions/extension_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) {
sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, scalarFunctions, sf)
// verify that default nullability is set to MIRROR
assert.Equal(t, extensions.MirrorNullability, sf.Nullability())
}
if tt.isAggregate {
af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
Expand Down
2 changes: 1 addition & 1 deletion extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ type ScalarFunctionImpl struct {
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
}
Expand Down
105 changes: 105 additions & 0 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type FunctionVariant interface {
URI() string
ResolveType(argTypes []types.Type) (types.Type, error)
Variadic() *VariadicBehavior
Match(argumentTypes []types.Type) (bool, error)
MatchAt(typ types.Type, pos int) (bool, error)
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, actualTypes []types.Type) (types.Type, error) {
Expand Down Expand Up @@ -84,6 +86,78 @@ func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeEx
return outType, nil
}

// TODO: Handle Variadic function
func matchArguments(nullability NullabilityHandling, paramTypeList ArgumentList, actualTypes []types.Type) (bool, error) {
if len(paramTypeList) != len(actualTypes) {
return false, nil

Check warning on line 92 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L90-L92

Added lines #L90 - L92 were not covered by tests
}
funcDefArgList, err := getFuncDefFromArgList(paramTypeList)
if err != nil {
return false, nil

Check warning on line 96 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L94-L96

Added lines #L94 - L96 were not covered by tests
}
for argPos := range paramTypeList {
match, err1 := matchArgumentAtCommon(actualTypes[argPos], argPos, nullability, funcDefArgList)
if err1 != nil {
return false, err1

Check warning on line 101 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L98-L101

Added lines #L98 - L101 were not covered by tests
}
if !match {
return false, nil

Check warning on line 104 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L103-L104

Added lines #L103 - L104 were not covered by tests
}
}
return true, nil

Check warning on line 107 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L107

Added line #L107 was not covered by tests
}

// TODO: Handle Variadic function
func matchArgumentAt(actualType types.Type, argPos int, nullability NullabilityHandling, paramTypeList ArgumentList) (bool, error) {
if argPos < 0 {
return false, fmt.Errorf("non-zero argument position")

Check warning on line 113 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L111-L113

Added lines #L111 - L113 were not covered by tests
}
if argPos >= len(paramTypeList) {
return false, fmt.Errorf("%w: argument position %d out of range", substraitgo.ErrNotFound, argPos)

Check warning on line 116 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L115-L116

Added lines #L115 - L116 were not covered by tests
}
funcDefArgList, err := getFuncDefFromArgList(paramTypeList)
if err != nil {
return false, nil

Check warning on line 120 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L118-L120

Added lines #L118 - L120 were not covered by tests
}
return matchArgumentAtCommon(actualType, argPos, nullability, funcDefArgList)

Check warning on line 122 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L122

Added line #L122 was not covered by tests
}

func matchArgumentAtCommon(actualType types.Type, argPos int, nullability NullabilityHandling, funcDefArgList []types.FuncDefArgType) (bool, error) {
if HasSyncParams(funcDefArgList) {
return false, fmt.Errorf("%w: function has sync params", substraitgo.ErrNotImplemented)

Check warning on line 127 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L125-L127

Added lines #L125 - L127 were not covered by tests
}
funcDefArg := funcDefArgList[argPos]
switch nullability {
case DiscreteNullability:
return funcDefArg.MatchWithNullability(actualType), nil
case MirrorNullability, DeclaredOutputNullability:
return funcDefArg.MatchWithoutNullability(actualType), nil

Check warning on line 134 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L129-L134

Added lines #L129 - L134 were not covered by tests
}
// unreachable case
return false, fmt.Errorf("invalid nullability type: %s", nullability)

Check warning on line 137 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L137

Added line #L137 was not covered by tests
}

func getFuncDefFromArgList(paramTypeList ArgumentList) ([]types.FuncDefArgType, error) {
var out []types.FuncDefArgType
for argPos, param := range paramTypeList {
switch paramType := param.(type) {
case ValueArg:
funcDefArgType, err := paramType.Value.Expr.(*parser.Type).ArgType()
if err != nil {
return nil, err

Check warning on line 147 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L140-L147

Added lines #L140 - L147 were not covered by tests
}
out = append(out, funcDefArgType)
case EnumArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
case TypeArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
default:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)

Check warning on line 155 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L149-L155

Added lines #L149 - L155 were not covered by tests
}
}
return out, nil

Check warning on line 158 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L158

Added line #L158 was not covered by tests
}

func parseFuncName(compoundName string) (name string, args ArgumentList) {
name, argsStr, _ := strings.Cut(compoundName, ":")
if len(argsStr) == 0 {
Expand All @@ -102,6 +176,25 @@ func parseFuncName(compoundName string) (name string, args ArgumentList) {
return name, args
}

// Match this function matches input arguments against definition of this functions argument list
// returns (true, nil) if all input argument can type replace the function definition argument
// returns (false, err) for invalid input argument. For e.g. if input argument nullability is not correctly
// set this function will return error
// returns (false, nil) valid input argument type and no match this function returns
func (s *ScalarFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, argumentTypes)

Check warning on line 185 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L184-L185

Added lines #L184 - L185 were not covered by tests
}

// MatchAt this function matches input argument at position against definition of this
// functions argument at same position
// returns (true, nil) if all input argument can type replace the function definition argument
// returns (false, err) for invalid input argument. For e.g. if input argument nullability is not correctly
// set this function will return error
// returns (false, nil) valid input argument type and no match this function returns
func (s *ScalarFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args)

Check warning on line 195 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L194-L195

Added lines #L194 - L195 were not covered by tests
}

// NewScalarFuncVariant constructs a variant with the provided name and uri
// and uses the defaults for everything else.
//
Expand Down Expand Up @@ -268,6 +361,12 @@ func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error)
}
func (s *AggregateFunctionVariant) Ordered() bool { return s.impl.Ordered }
func (s *AggregateFunctionVariant) MaxSet() int { return s.impl.MaxSet }
func (s *AggregateFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, argumentTypes)

Check warning on line 365 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L364-L365

Added lines #L364 - L365 were not covered by tests
}
func (s *AggregateFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args)

Check warning on line 368 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L367-L368

Added lines #L367 - L368 were not covered by tests
}

type WindowFunctionVariant struct {
name string
Expand Down Expand Up @@ -376,6 +475,12 @@ func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
func (s *WindowFunctionVariant) Ordered() bool { return s.impl.Ordered }
func (s *WindowFunctionVariant) MaxSet() int { return s.impl.MaxSet }
func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowType }
func (s *WindowFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, argumentTypes)

Check warning on line 479 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L478-L479

Added lines #L478 - L479 were not covered by tests
}
func (s *WindowFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args)

Check warning on line 482 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L481-L482

Added lines #L481 - L482 were not covered by tests
}

// HasSyncParams This API returns if params share a leaf param name
func HasSyncParams(params []types.FuncDefArgType) bool {
Expand Down
Loading

0 comments on commit 94687c0

Please sign in to comment.