Skip to content

Commit

Permalink
feat: Add Match() and MatchAt() to FunctionVariant (#54)
Browse files Browse the repository at this point in the history
* Also fixed default nullability for ScalarFunctionImpl
  • Loading branch information
anshuldata authored Sep 14, 2024
1 parent c48fb53 commit 1a800d9
Show file tree
Hide file tree
Showing 7 changed files with 735 additions and 7 deletions.
10 changes: 10 additions & 0 deletions extensions/extension_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path"
"sort"

"github.com/creasty/defaults"
"github.com/goccy/go-yaml"
substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/proto/extensions"
Expand Down Expand Up @@ -179,14 +180,23 @@ func (c *Collection) Load(uri string, r io.Reader) error {
simpleNames := make(map[string]string)

for _, f := range file.ScalarFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for scalar functions: %w", err)
}
addToMaps[*ScalarFunctionVariant](id, &f, c.scalarMap, simpleNames)
}

for _, f := range file.AggregateFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for aggregate functions: %w", err)
}
addToMaps[*AggregateFunctionVariant](id, &f, c.aggregateMap, simpleNames)
}

for _, f := range file.WindowFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for window functions: %w", err)
}
addToMaps[*WindowFunctionVariant](id, &f, c.windowMap, simpleNames)
}

Expand Down
2 changes: 2 additions & 0 deletions extensions/extension_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) {
sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, scalarFunctions, sf)
// verify that default nullability is set to MIRROR
assert.Equal(t, extensions.MirrorNullability, sf.Nullability())
}
if tt.isAggregate {
af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
Expand Down
4 changes: 2 additions & 2 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ const (
// VariadicBehavior holds the variadic-argument settings of a function
// implementation. The struct tags show it is (de)serialized as YAML.
type VariadicBehavior struct {
// Min is a bound on the argument count; validateVariadicBehaviorForMatch
// treats Min-1 as the index of the first variadic argument.
// NOTE(review): presumably the minimum number of variadic arguments — confirm
// against IsValidArgumentCount.
Min int `yaml:",omitempty"`
// Max is a bound on the argument count.
// NOTE(review): presumably the maximum number of variadic arguments, with 0
// possibly meaning "unbounded" — confirm against IsValidArgumentCount.
Max int `yaml:",omitempty"`
// ParameterConsistency is given the value "CONSISTENT" by the `default`
// struct tag when defaults are applied and no value was set.
ParameterConsistency ParameterConsistency `yaml:"parameterConsistency,omitempty" default:"CONSISTENT"`
}

func (v *VariadicBehavior) IsValidArgumentCount(count int) bool {
Expand All @@ -205,7 +205,7 @@ type ScalarFunctionImpl struct {
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
}
Expand Down
136 changes: 136 additions & 0 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ type FunctionVariant interface {
URI() string
ResolveType(argTypes []types.Type) (types.Type, error)
Variadic() *VariadicBehavior
// Match matches the input argument types against this function's parameter list.
// returns (true, nil) if every input argument can type-replace the corresponding function definition argument
// returns (false, err) for invalid input arguments, e.g. if an input argument's nullability is not
// correctly set
// returns (false, nil) for valid input arguments when the argument list can't type-replace the parameter list
Match(argumentTypes []types.Type) (bool, error)
// MatchAt matches the input argument at the given position against the definition of this
// function's argument at the same position.
// returns (true, nil) if the input argument can type-replace the function definition argument
// returns (false, err) for an invalid input argument, e.g. if the argument position is negative or
// the argument's nullability is not correctly set
// returns (false, nil) for a valid input argument type that can't type-replace the parameter at pos
MatchAt(typ types.Type, pos int) (bool, error)
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, actualTypes []types.Type) (types.Type, error) {
Expand Down Expand Up @@ -84,6 +97,109 @@ func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeEx
return outType, nil
}

// matchArguments reports whether actualTypes can be used to call a function
// declared with paramTypeList under the given nullability handling.
//
// Without variadic behavior the number of arguments must equal the number of
// parameters; with it, the arguments are vetted by
// validateVariadicBehaviorForMatch. A parameter list that
// getFuncDefFromArgList cannot convert is reported as a non-match
// (false, nil), not as an error. Errors raised while matching an individual
// argument are returned unchanged.
func matchArguments(nullability NullabilityHandling, paramTypeList ArgumentList, variadicBehavior *VariadicBehavior, actualTypes []types.Type) (bool, error) {
	if variadicBehavior == nil {
		if len(actualTypes) != len(paramTypeList) {
			return false, nil
		}
	} else if !validateVariadicBehaviorForMatch(variadicBehavior, actualTypes) {
		return false, nil
	}

	defArgs, convErr := getFuncDefFromArgList(paramTypeList)
	if convErr != nil {
		return false, nil
	}

	// Iterate over the actual arguments rather than the parameters: with a
	// variadic function there may be more arguments than declared parameters.
	for pos, actual := range actualTypes {
		ok, matchErr := matchArgumentAtCommon(actual, pos, nullability, defArgs, variadicBehavior)
		switch {
		case matchErr != nil:
			return false, matchErr
		case !ok:
			return false, nil
		}
	}
	return true, nil
}

// matchArgumentAt reports whether actualType can be used as the argument at
// position argPos of a function declared with paramTypeList.
//
// It returns an error when argPos is negative. A parameter list that
// getFuncDefFromArgList cannot convert is reported as a non-match
// (false, nil), mirroring matchArguments. Otherwise the result is whatever
// matchArgumentAtCommon returns for the converted parameter list.
func matchArgumentAt(actualType types.Type, argPos int, nullability NullabilityHandling, paramTypeList ArgumentList, variadicBehavior *VariadicBehavior) (bool, error) {
	if argPos < 0 {
		// Position zero is valid; only negative positions are rejected.
		return false, fmt.Errorf("negative argument position: %d", argPos)
	}
	funcDefArgList, err := getFuncDefFromArgList(paramTypeList)
	if err != nil {
		// A parameter list that cannot be converted can never match.
		return false, nil
	}
	return matchArgumentAtCommon(actualType, argPos, nullability, funcDefArgList, variadicBehavior)
}

// matchArgumentAtCommon reports whether actualType can be used as the argument
// at argPos of a function whose converted parameter list is funcDefArgList.
//
// It returns (false, nil) when argPos is out of range: beyond the parameter
// list for a non-variadic function, or rejected by
// variadicBehavior.IsValidArgumentCount(argPos+1) for a variadic one. It also
// returns (false, nil) when the parameter list is empty, since there is
// nothing to match against. It returns an error wrapping
// substraitgo.ErrNotImplemented when HasSyncParams reports true for the
// parameter list, and an error for an unrecognised nullability value.
//
// argPos is expected to be non-negative; callers validate it.
func matchArgumentAtCommon(actualType types.Type, argPos int, nullability NullabilityHandling, funcDefArgList []types.FuncDefArgType, variadicBehavior *VariadicBehavior) (bool, error) {
	// check if argument out of range
	if variadicBehavior == nil && argPos >= len(funcDefArgList) {
		return false, nil
	}
	if variadicBehavior != nil && !variadicBehavior.IsValidArgumentCount(argPos+1) {
		// this argument position can't be more than the max allowed by the variadic behavior
		return false, nil
	}

	if HasSyncParams(funcDefArgList) {
		return false, fmt.Errorf("%w: function has sync params", substraitgo.ErrNotImplemented)
	}

	// A variadic function declared without any parameter would otherwise index
	// funcDefArgList[-1] below and panic; treat it as a non-match instead.
	if len(funcDefArgList) == 0 {
		return false, nil
	}

	// Positions past the declared parameters are variadic repetitions and are
	// matched against the last declared parameter. The range checks above have
	// already established that such a position is allowed.
	funcDefArg := funcDefArgList[len(funcDefArgList)-1]
	if argPos < len(funcDefArgList) {
		funcDefArg = funcDefArgList[argPos]
	}

	switch nullability {
	case DiscreteNullability:
		return funcDefArg.MatchWithNullability(actualType), nil
	case MirrorNullability, DeclaredOutputNullability:
		return funcDefArg.MatchWithoutNullability(actualType), nil
	}
	// unreachable for the nullability values handled above
	return false, fmt.Errorf("invalid nullability type: %s", nullability)
}

// validateVariadicBehaviorForMatch reports whether actualTypes is acceptable
// for a function with the given variadic behavior.
//
// The argument count must satisfy variadicBehavior.IsValidArgumentCount. When
// ParameterConsistency is ConsistentParams, every argument from the first
// variadic position onwards must additionally be Equals to its successor.
func validateVariadicBehaviorForMatch(variadicBehavior *VariadicBehavior, actualTypes []types.Type) bool {
	if !variadicBehavior.IsValidArgumentCount(len(actualTypes)) {
		return false
	}
	// verify consistency of variadic behavior
	if variadicBehavior.ParameterConsistency == ConsistentParams {
		// all concrete types must be equal for all variable arguments
		firstVariadicArgIdx := variadicBehavior.Min - 1
		if firstVariadicArgIdx < 0 {
			// Min of zero (or an unset Min) would yield index -1 and make
			// actualTypes[i] panic; start the comparison at the first argument.
			firstVariadicArgIdx = 0
		}
		for i := firstVariadicArgIdx; i < len(actualTypes)-1; i++ {
			if !actualTypes[i].Equals(actualTypes[i+1]) {
				return false
			}
		}
	}
	return true
}

// getFuncDefFromArgList converts every entry of paramTypeList into its
// types.FuncDefArgType form for use by the match operations.
//
// Only ValueArg parameters whose type expression is a *parser.Type are
// supported. Any other argument kind, or any other expression kind, yields an
// error wrapping substraitgo.ErrInvalidType that names the offending
// position. An error returned by ArgType is passed through unchanged. The
// result is nil when paramTypeList is empty.
func getFuncDefFromArgList(paramTypeList ArgumentList) ([]types.FuncDefArgType, error) {
	var out []types.FuncDefArgType
	for argPos, param := range paramTypeList {
		valueArg, isValue := param.(ValueArg)
		if !isValue {
			// enum, type and any other argument kinds cannot take part in matching
			return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
		}
		// Checked assertion: an unexpected expression kind becomes an error
		// instead of a runtime panic.
		typeExpr, isType := valueArg.Value.Expr.(*parser.Type)
		if !isType {
			return nil, fmt.Errorf("%w: unsupported type expression at position %d for match operation", substraitgo.ErrInvalidType, argPos)
		}
		funcDefArgType, err := typeExpr.ArgType()
		if err != nil {
			return nil, err
		}
		out = append(out, funcDefArgType)
	}
	return out, nil
}

func parseFuncName(compoundName string) (name string, args ArgumentList) {
name, argsStr, _ := strings.Cut(compoundName, ":")
if len(argsStr) == 0 {
Expand Down Expand Up @@ -160,6 +276,14 @@ func (s *ScalarFunctionVariant) ID() ID {
return ID{URI: s.uri, Name: s.CompoundName()}
}

// Match reports whether argumentTypes match this scalar function's parameter
// list. It delegates to matchArguments using the variant's nullability
// handling, declared arguments and variadic behavior.
func (s *ScalarFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
	nullHandling := s.Nullability()
	params, variadic := s.impl.Args, s.impl.Variadic
	return matchArguments(nullHandling, params, variadic, argumentTypes)
}

// MatchAt reports whether typ matches this scalar function's parameter at
// position pos. It delegates to matchArgumentAt using the variant's
// nullability handling, declared arguments and variadic behavior.
func (s *ScalarFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
	nullHandling := s.Nullability()
	params, variadic := s.impl.Args, s.impl.Variadic
	return matchArgumentAt(typ, pos, nullHandling, params, variadic)
}

// NewAggFuncVariant constructs a variant with the provided name and uri
// and uses the defaults for everything else.
//
Expand Down Expand Up @@ -268,6 +392,12 @@ func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error)
}
// Ordered returns the Ordered value stored on the variant's implementation.
func (s *AggregateFunctionVariant) Ordered() bool {
	return s.impl.Ordered
}
// MaxSet returns the MaxSet value stored on the variant's implementation.
func (s *AggregateFunctionVariant) MaxSet() int {
	return s.impl.MaxSet
}
// Match reports whether argumentTypes match this aggregate function's
// parameter list. It delegates to matchArguments using the variant's
// nullability handling, declared arguments and variadic behavior.
func (s *AggregateFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
	nullHandling := s.Nullability()
	params, variadic := s.impl.Args, s.impl.Variadic
	return matchArguments(nullHandling, params, variadic, argumentTypes)
}
// MatchAt reports whether typ matches this aggregate function's parameter at
// position pos. It delegates to matchArgumentAt using the variant's
// nullability handling, declared arguments and variadic behavior.
func (s *AggregateFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
	nullHandling := s.Nullability()
	params, variadic := s.impl.Args, s.impl.Variadic
	return matchArgumentAt(typ, pos, nullHandling, params, variadic)
}

type WindowFunctionVariant struct {
name string
Expand Down Expand Up @@ -376,6 +506,12 @@ func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
// Ordered returns the Ordered value stored on the variant's implementation.
func (s *WindowFunctionVariant) Ordered() bool {
	return s.impl.Ordered
}
// MaxSet returns the MaxSet value stored on the variant's implementation.
func (s *WindowFunctionVariant) MaxSet() int {
	return s.impl.MaxSet
}
// WindowType returns the WindowType value stored on the variant's
// implementation.
func (s *WindowFunctionVariant) WindowType() WindowType {
	return s.impl.WindowType
}
// Match reports whether argumentTypes match this window function's parameter
// list. It delegates to matchArguments using the variant's nullability
// handling, declared arguments and variadic behavior.
func (s *WindowFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
	nullHandling := s.Nullability()
	params, variadic := s.impl.Args, s.impl.Variadic
	return matchArguments(nullHandling, params, variadic, argumentTypes)
}
// MatchAt reports whether typ matches this window function's parameter at
// position pos. It delegates to matchArgumentAt using the variant's
// nullability handling, declared arguments and variadic behavior.
func (s *WindowFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
	nullHandling := s.Nullability()
	params, variadic := s.impl.Args, s.impl.Variadic
	return matchArgumentAt(typ, pos, nullHandling, params, variadic)
}

// HasSyncParams This API returns if params share a leaf param name
func HasSyncParams(params []types.FuncDefArgType) bool {
Expand Down
Loading

0 comments on commit 1a800d9

Please sign in to comment.