From b5ad5b5cb61e5594dfdd770f38c53901c60821cb Mon Sep 17 00:00:00 2001 From: DavePearce Date: Wed, 18 Dec 2024 16:36:10 +1300 Subject: [PATCH] Support function parameter types This puts in place support for function parameter types, including checking them at certain positions. --- pkg/corset/binding.go | 65 ++++++++++++++++++------ pkg/corset/expression.go | 36 +++++++++----- pkg/corset/intrinsics.go | 17 +++++++ pkg/corset/parser.go | 34 ++++++++++--- pkg/corset/resolver.go | 85 +++++++++++++++++++++++--------- pkg/corset/translator.go | 8 +-- pkg/corset/type.go | 22 +++++++++ pkg/schema/type.go | 18 +++++++ pkg/test/invalid_corset_test.go | 5 ++ pkg/test/valid_corset_test.go | 4 ++ testdata/purefun_05.accepts | 48 ++++++++++++++++++ testdata/purefun_05.lisp | 5 ++ testdata/purefun_invalid_09.lisp | 5 ++ 13 files changed, 290 insertions(+), 62 deletions(-) create mode 100644 testdata/purefun_05.accepts create mode 100644 testdata/purefun_05.lisp create mode 100644 testdata/purefun_invalid_09.lisp diff --git a/pkg/corset/binding.go b/pkg/corset/binding.go index 100a362..1cf53bb 100644 --- a/pkg/corset/binding.go +++ b/pkg/corset/binding.go @@ -36,12 +36,49 @@ type FunctionBinding interface { // they can accept any number of arguments. In contrast, a user-defined // function may only accept a specific number of arguments, etc. HasArity(uint) bool - // Apply a set of concreate arguments to this function. This substitutes - // them through the body of the function producing a single expression. - Apply([]Expr) Expr - // Get the declared return type of this function, or nil if no return type - // was declared. - ReturnType() Type + // Select the best fit signature based on the available parameter types. + // Observe that, for valid arities, this always returns a signature. + // However, that signature may not actually accept the provided parameters + // (in which case, an error should be reported).
Furthermore, if no + // appropriate signature exists then this will return nil. + Select([]Type) *FunctionSignature +} + +// FunctionSignature embodies a concrete function instance. It is necessary to +// separate bindings from signatures because, in corset, function overloading is +// supported. That is, we can have different definitions for a function of the +// same name and arity. The appropriate definition is then selected for the +// given parameter types. +type FunctionSignature struct { + // Parameter types for this function + parameters []Type + // Return type for this function + ret Type + // Body of this function + body Expr +} + +// Return the (optional) return type for this signature. If no declared return +// type is given, then the intention is that it be inferred from the body. +func (p *FunctionSignature) Return() Type { + return p.ret +} + +// Parameter returns the given parameter in this signature. +func (p *FunctionSignature) Parameter(index uint) Type { + return p.parameters[index] +} + +// Apply a set of concrete arguments to this function. This substitutes +// them through the body of the function producing a single expression. +func (p *FunctionSignature) Apply(args []Expr) Expr { + mapping := make(map[uint]Expr) + // Setup the mapping + for i, e := range args { + mapping[uint(i)] = e + } + // Substitute through + return p.body.Substitute(mapping) } // ============================================================================ @@ -231,13 +268,11 @@ func (p *DefunBinding) Finalise(bodyType Type) { p.bodyType = bodyType } -// Apply a given set of arguments to this function binding. -func (p *DefunBinding) Apply(args []Expr) Expr { - mapping := make(map[uint]Expr) - // Setup the mapping - for i, e := range args { - mapping[uint(i)] = e - } - // Substitute through - return p.body.Substitute(mapping) +// Select the best fit signature based on the available parameter types. +// Observe that, for valid arities, this always returns a signature.
+// However, that signature may not actually accept the provided parameters +// (in which case, an error should be reported). Furthermore, if no +// appropriate signature exists then this will return nil. +func (p *DefunBinding) Select(args []Type) *FunctionSignature { + return &FunctionSignature{p.paramTypes, p.returnType, p.body} } diff --git a/pkg/corset/expression.go b/pkg/corset/expression.go index 4dabb46..7387aa4 100644 --- a/pkg/corset/expression.go +++ b/pkg/corset/expression.go @@ -513,23 +513,21 @@ func (e *If) Dependencies() []Symbol { // Invoke represents an attempt to invoke a given function. type Invoke struct { - fn *VariableAccess - args []Expr + fn *VariableAccess + signature *FunctionSignature + args []Expr } // AsConstant attempts to evaluate this expression as a constant (signed) value. // If this expression is not constant, then nil is returned. func (e *Invoke) AsConstant() *big.Int { - if e.fn.binding == nil { + if e.signature == nil { panic("unresolved invocation") - } else if fn_binding, ok := e.fn.binding.(FunctionBinding); ok { - // Unroll body - body := fn_binding.Apply(e.args) - // Attempt to evaluate as constant - return body.AsConstant() } - // Just fail - return nil + // Unroll body + body := e.signature.Apply(e.args) + // Attempt to evaluate as constant + return body.AsConstant() } // Args returns the arguments provided by this invocation to the function being @@ -558,10 +556,15 @@ func (e *Invoke) Lisp() sexp.SExp { return ListOfExpressions(e.fn.Lisp(), e.args) } +// Finalise the signature for this invocation. +func (e *Invoke) Finalise(signature *FunctionSignature) { + e.signature = signature +} + // Substitute all variables (such as for function parameters) arising in // this expression. func (e *Invoke) Substitute(mapping map[uint]Expr) Expr { - return &Invoke{e.fn, SubstituteExpressions(e.args, mapping)} + return &Invoke{e.fn, e.signature, SubstituteExpressions(e.args, mapping)} } // Dependencies needed to signal declaration. 
@@ -713,8 +716,9 @@ func (e *Normalise) Dependencies() []Symbol { // Reduce reduces (i.e. folds) a list using a given binary function. type Reduce struct { - fn *VariableAccess - arg Expr + fn *VariableAccess + signature *FunctionSignature + arg Expr } // AsConstant attempts to evaluate this expression as a constant (signed) value. @@ -751,10 +755,16 @@ func (e *Reduce) Lisp() sexp.SExp { func (e *Reduce) Substitute(mapping map[uint]Expr) Expr { return &Reduce{ e.fn, + e.signature, e.arg.Substitute(mapping), } } +// Finalise the signature for this reduction. +func (e *Reduce) Finalise(signature *FunctionSignature) { + e.signature = signature +} + // Dependencies needed to signal declaration. func (e *Reduce) Dependencies() []Symbol { deps := e.arg.Dependencies() diff --git a/pkg/corset/intrinsics.go b/pkg/corset/intrinsics.go index ceeb30d..e4fe1c8 100644 --- a/pkg/corset/intrinsics.go +++ b/pkg/corset/intrinsics.go @@ -67,6 +67,23 @@ func (p *IntrinsicDefinition) HasArity(arity uint) bool { return arity >= p.min_arity && arity <= p.max_arity } +// Select the best fit signature based on the available parameter types. +// Observe that, for valid arities, this always returns a signature. +// However, that signature may not actually accept the provided parameters +// (in which case, an error should be reported). Furthermore, if no +// appropriate signature exists then this will return nil. +func (p *IntrinsicDefinition) Select(args []Type) *FunctionSignature { + // construct the body + body := p.constructor(uint(len(args))) + types := make([]Type, len(args)) + // + for i := 0; i < len(types); i++ { + types[i] = NewFieldType() + } + // Allow return type to be inferred. + return &FunctionSignature{types, nil, body} +} + // Apply a given set of arguments to this function binding. 
func (p *IntrinsicDefinition) Apply(args []Expr) Expr { // First construct the body diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go index 2b2c936..2466aa9 100644 --- a/pkg/corset/parser.go +++ b/pkg/corset/parser.go @@ -788,14 +788,32 @@ func (p *Parser) parseFunctionNameReturn(element sexp.SExp) (string, Type, bool, } func (p *Parser) parseFunctionParameter(element sexp.SExp) (*DefParameter, []SyntaxError) { + list := element.AsList() + // if isIdentifier(element) { binding := NewLocalVariableBinding(element.AsSymbol().Value, NewFieldType()) return &DefParameter{binding}, nil + } else if list == nil || list.Len() != 2 || !isIdentifier(list.Get(0)) { + // Construct error message (for now) + err := p.translator.SyntaxError(element, "malformed parameter declaration") + // + return nil, []SyntaxError{*err} + } + // Parse the type + datatype, prove, err := p.parseType(list.Get(1)) + // + if err != nil { + return nil, []SyntaxError{*err} + } else if prove { + // Parameters cannot be marked @prove + err := p.translator.SyntaxError(element, "malformed parameter declaration") + // + return nil, []SyntaxError{*err} } - // Construct error message (for now) - err := p.translator.SyntaxError(element, "malformed parameter declaration") + // Done + binding := NewLocalVariableBinding(list.Get(0).AsSymbol().Value, datatype) // - return nil, []SyntaxError{*err} + return &DefParameter{binding}, nil } // Parse a range declaration @@ -979,8 +997,10 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] { if len(errors) > 0 { return nil, errors } - // Construct binding - binding := NewLocalVariableBinding(indexVar.Value, nil) + // Construct binding. At this stage, it's unclear what the best type to + // use for the index variable is here. Potentially, it could be refined + // based on the range of actual values, etc.
+ binding := NewLocalVariableBinding(indexVar.Value, NewFieldType()) // return &For{binding, start, end, body}, nil } @@ -1050,7 +1070,7 @@ func reduceParserRule(p *Parser) sexp.ListRule[Expr] { varaccess := &VariableAccess{nil, name.Value, true, nil} p.mapSourceNode(name, varaccess) // - return &Reduce{varaccess, body}, nil + return &Reduce{varaccess, nil, body}, nil } } @@ -1169,7 +1189,7 @@ func invokeParserRule(p *Parser) sexp.ListRule[Expr] { // p.mapSourceNode(list.Get(0), varaccess) // - return &Invoke{varaccess, args}, nil + return &Invoke{varaccess, nil, args}, nil } } diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go index ab1054c..54ce278 100644 --- a/pkg/corset/resolver.go +++ b/pkg/corset/resolver.go @@ -600,42 +600,66 @@ func (r *resolver) finaliseIfInModule(scope LocalScope, expr *If) (Type, []Synta // turn, is contained within some module. Note, qualified accesses are only // permitted in a global context. func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) (Type, []SyntaxError) { + var ( + errors []SyntaxError + argTypes []Type + ) // Resolve arguments - if _, errors := r.finaliseExpressionsInModule(scope, expr.Args()); errors != nil { - return nil, errors - } + argTypes, errors = r.finaliseExpressionsInModule(scope, expr.Args()) // Lookup the corresponding function definition. if !expr.fn.IsResolved() && !scope.Bind(expr.fn) { - return nil, r.srcmap.SyntaxErrors(expr, "unknown function") + return nil, append(errors, *r.srcmap.SyntaxError(expr, "unknown function")) } // Following must be true if we get here. 
binding := expr.fn.binding.(FunctionBinding) - + // Check purity if scope.IsPure() && !binding.IsPure() { - return nil, r.srcmap.SyntaxErrors(expr, "not permitted in pure context") - } else if !binding.HasArity(uint(len(expr.Args()))) { + errors = append(errors, *r.srcmap.SyntaxError(expr, "not permitted in pure context")) + } + // Check the correct number of arguments was provided + if !binding.HasArity(uint(len(expr.Args()))) { msg := fmt.Sprintf("incorrect number of arguments (found %d)", len(expr.Args())) - return nil, r.srcmap.SyntaxErrors(expr, msg) - } - // Check whether need to infer return type - if binding.ReturnType() != nil { - // no need, it was provided - return binding.ReturnType(), nil - } - // TODO: this is potentially expensive, and it would likely be good if we - // could avoid it. Realistically, this is just about determining the right - // type information. Potentially, we could adjust the local scope to - // provide the required type information. Or we could have a separate pass - // which just determines the type. - body := binding.Apply(expr.Args()) - // Dig out the type - return r.finaliseExpressionInModule(scope, body) + errors = append(errors, *r.srcmap.SyntaxError(expr, msg)) + } + // Select best overloaded signature + if signature := binding.Select(argTypes); signature != nil { + // Check arguments are accepted, based on their type. + for i := 0; i < len(argTypes); i++ { + expected := signature.Parameter(uint(i)) + actual := argTypes[i] + // subtype check + if !actual.SubtypeOf(expected) { + msg := fmt.Sprintf("expected type %s (found %s)", expected, actual) + errors = append(errors, *r.srcmap.SyntaxError(expr.args[i], msg)) + } + } + // + expr.Finalise(signature) + // + if len(errors) != 0 { + return nil, errors + } else if signature.Return() != nil { + // no need, it was provided + return signature.Return(), nil + } + // TODO: this is potentially expensive, and it would likely be good if we + // could avoid it.
Realistically, this is just about determining the right + // type information. Potentially, we could adjust the local scope to + // provide the required type information. Or we could have a separate pass + // which just determines the type. + body := signature.Apply(expr.Args()) + // Dig out the type + return r.finaliseExpressionInModule(scope, body) + } + // ambiguous invocation + return nil, append(errors, *r.srcmap.SyntaxError(expr, "ambiguous invocation")) } // Resolve a specific invocation contained within some expression which, in // turn, is contained within some module. Note, qualified accesses are only // permitted in a global context. func (r *resolver) finaliseReduceInModule(scope LocalScope, expr *Reduce) (Type, []SyntaxError) { + var signature *FunctionSignature // Resolve arguments body_t, errors := r.finaliseExpressionInModule(scope, expr.arg) // Lookup the corresponding function definition. @@ -650,12 +674,27 @@ func (r *resolver) finaliseReduceInModule(scope LocalScope, expr *Reduce) (Type, } else if !binding.HasArity(2) { msg := "incorrect number of arguments (expected 2)" errors = append(errors, *r.srcmap.SyntaxError(expr, msg)) + } else if signature = binding.Select([]Type{body_t, body_t}); signature == nil { + msg := "ambiguous reduction" + errors = append(errors, *r.srcmap.SyntaxError(expr, msg)) + } + // Check left parameter type + if !body_t.SubtypeOf(signature.Parameter(0)) { + msg := fmt.Sprintf("expected type %s (found %s)", signature.Parameter(0), body_t) + errors = append(errors, *r.srcmap.SyntaxError(expr.arg, msg)) + } + // Check right parameter type + if !body_t.SubtypeOf(signature.Parameter(1)) { + msg := fmt.Sprintf("expected type %s (found %s)", signature.Parameter(1), body_t) + errors = append(errors, *r.srcmap.SyntaxError(expr.arg, msg)) } } // Error check if len(errors) > 0 { return nil, errors } + // Lock in signature + expr.Finalise(signature) // return body_t, nil } @@ -691,7 +730,7 @@ func (r *resolver) 
finaliseVariableInModule(scope LocalScope, // Constant return binding.datatype, nil } else if binding, ok := expr.Binding().(*LocalVariableBinding); ok { - // Parameter + // Parameter or other local variable return binding.datatype, nil } else if _, ok := expr.Binding().(FunctionBinding); ok { // Function doesn't makes sense here. diff --git a/pkg/corset/translator.go b/pkg/corset/translator.go index cb9101f..68587c7 100644 --- a/pkg/corset/translator.go +++ b/pkg/corset/translator.go @@ -587,8 +587,8 @@ func (t *translator) translateForInModule(expr *For, module string, shift int) ( } func (t *translator) translateInvokeInModule(expr *Invoke, module string, shift int) (hir.Expr, []SyntaxError) { - if binding, ok := expr.fn.Binding().(FunctionBinding); ok { - body := binding.Apply(expr.Args()) + if expr.signature != nil { + body := expr.signature.Apply(expr.Args()) return t.translateExpressionInModule(body, module, shift) } // @@ -598,13 +598,13 @@ func (t *translator) translateInvokeInModule(expr *Invoke, module string, shift func (t *translator) translateReduceInModule(expr *Reduce, module string, shift int) (hir.Expr, []SyntaxError) { if list, ok := expr.arg.(*List); !ok { return nil, t.srcmap.SyntaxErrors(expr.arg, "expected list") - } else if binding, ok := expr.fn.Binding().(FunctionBinding); !ok { + } else if sig := expr.signature; sig == nil { return nil, t.srcmap.SyntaxErrors(expr.arg, "unbound function") } else { reduction := list.Args[0] // Build reduction for i := 1; i < len(list.Args); i++ { - reduction = binding.Apply([]Expr{reduction, list.Args[i]}) + reduction = sig.Apply([]Expr{reduction, list.Args[i]}) } // Translate reduction return t.translateExpressionInModule(reduction, module, shift) diff --git a/pkg/corset/type.go b/pkg/corset/type.go index 9a7af8b..ebb1227 100644 --- a/pkg/corset/type.go +++ b/pkg/corset/type.go @@ -25,6 +25,9 @@ type Type interface { // will panic if the type has already been given loobean semantics. 
WithBooleanSemantics() Type + // SubtypeOf determines whether or not this type is a subtype of another. + SubtypeOf(Type) bool + // Access an underlying representation of this type (should one exist). If // this doesn't exist, then nil is returned. AsUnderlying() sc.Type @@ -176,6 +179,16 @@ func (p *NativeType) AsUnderlying() sc.Type { return p.datatype } +// SubtypeOf determines whether or not this type is a subtype of another. +func (p *NativeType) SubtypeOf(other Type) bool { + if o, ok := other.(*NativeType); ok && p.datatype.SubtypeOf(o.datatype) { + // An interpreted type can flow into an uninterpreted type. + return (!o.loobean && !o.boolean) || p == o + } + // + return false +} + func (p *NativeType) String() string { if p.loobean { return fmt.Sprintf("%s@loob", p.datatype.String()) @@ -243,6 +256,15 @@ func (p *ArrayType) AsUnderlying() sc.Type { return nil } +// SubtypeOf determines whether or not this type is a subtype of another. +func (p *ArrayType) SubtypeOf(other Type) bool { + if o, ok := other.(*ArrayType); ok { + return p.element.SubtypeOf(o.element) + } + // + return false +} + func (p *ArrayType) String() string { return fmt.Sprintf("(%s)[%d]", p.element.String(), p.size) } diff --git a/pkg/schema/type.go b/pkg/schema/type.go index ce9b5f9..9f37cd7 100644 --- a/pkg/schema/type.go +++ b/pkg/schema/type.go @@ -23,6 +23,8 @@ type Type interface { ByteWidth() uint // Return the minimum number of bits required represent any element of this type. BitWidth() uint + // Check whether this type is a subtype of another + SubtypeOf(Type) bool // Produce a string representation of this type.
String() string } @@ -98,6 +100,17 @@ func (p *UintType) Bound() fr.Element { return p.bound } +// SubtypeOf checks whether this subtypes another +func (p *UintType) SubtypeOf(other Type) bool { + if other.AsField() != nil { + return true + } else if o, ok := other.(*UintType); ok { + return p.bound == o.bound + } + + return false +} + func (p *UintType) String() string { return fmt.Sprintf("u%d", p.nbits) } @@ -130,6 +143,11 @@ func (p *FieldType) BitWidth() uint { return p.ByteWidth() * 8 } +// SubtypeOf checks whether this subtypes another +func (p *FieldType) SubtypeOf(other Type) bool { + return other.AsField() != nil +} + // Accept determines whether a given value is an element of this type. In // fact, all field elements are members of this type. func (p *FieldType) Accept(val fr.Element) bool { diff --git a/pkg/test/invalid_corset_test.go b/pkg/test/invalid_corset_test.go index aede841..f72265c 100644 --- a/pkg/test/invalid_corset_test.go +++ b/pkg/test/invalid_corset_test.go @@ -479,6 +479,11 @@ func Test_Invalid_PureFun_08(t *testing.T) { CheckInvalid(t, "purefun_invalid_08") } +func Test_Invalid_PureFun_09(t *testing.T) { + // tricky one + CheckInvalid(t, "purefun_invalid_09") +} + // =================================================================== // For Loops // =================================================================== diff --git a/pkg/test/valid_corset_test.go b/pkg/test/valid_corset_test.go index 0a39141..406f222 100644 --- a/pkg/test/valid_corset_test.go +++ b/pkg/test/valid_corset_test.go @@ -628,6 +628,10 @@ func Test_PureFun_04(t *testing.T) { Check(t, false, "purefun_04") } +func Test_PureFun_05(t *testing.T) { + Check(t, false, "purefun_05") +} + // =================================================================== // For Loops // =================================================================== diff --git a/testdata/purefun_05.accepts b/testdata/purefun_05.accepts new file mode 100644 index 0000000..ff6cef4 --- /dev/null +++ 
b/testdata/purefun_05.accepts @@ -0,0 +1,48 @@ +{ "X": [], "Y": [] } +;; +{ "X": [0], "Y": [0] } +{ "X": [1], "Y": [0] } +{ "X": [1], "Y": [1] } +{ "X": [1], "Y": [2] } +{ "X": [2], "Y": [2] } +{ "X": [3], "Y": [3] } +;; +{ "X": [0,0], "Y": [0,0] } +{ "X": [0,1], "Y": [0,0] } +{ "X": [0,1], "Y": [0,1] } +{ "X": [0,1], "Y": [0,2] } +{ "X": [0,2], "Y": [0,2] } +{ "X": [0,3], "Y": [0,3] } +;; +{ "X": [1,0], "Y": [0,0] } +{ "X": [1,1], "Y": [0,0] } +{ "X": [1,1], "Y": [0,1] } +{ "X": [1,1], "Y": [0,2] } +{ "X": [1,2], "Y": [0,2] } +{ "X": [1,3], "Y": [0,3] } +{ "X": [1,0], "Y": [1,0] } +{ "X": [1,1], "Y": [1,0] } +{ "X": [1,1], "Y": [1,1] } +{ "X": [1,1], "Y": [1,2] } +{ "X": [1,2], "Y": [1,2] } +{ "X": [1,3], "Y": [1,3] } +{ "X": [1,0], "Y": [2,0] } +{ "X": [1,1], "Y": [2,0] } +{ "X": [1,1], "Y": [2,1] } +{ "X": [1,1], "Y": [2,2] } +{ "X": [1,2], "Y": [2,2] } +{ "X": [1,3], "Y": [2,3] } +;; +{ "X": [2,0], "Y": [2,0] } +{ "X": [2,1], "Y": [2,0] } +{ "X": [2,1], "Y": [2,1] } +{ "X": [2,1], "Y": [2,2] } +{ "X": [2,2], "Y": [2,2] } +{ "X": [2,3], "Y": [2,3] } +;; +{ "X": [3,0], "Y": [3,0] } +{ "X": [3,1], "Y": [3,0] } +{ "X": [3,1], "Y": [3,1] } +{ "X": [3,1], "Y": [3,2] } +{ "X": [3,2], "Y": [3,2] } +{ "X": [3,3], "Y": [3,3] } diff --git a/testdata/purefun_05.lisp b/testdata/purefun_05.lisp new file mode 100644 index 0000000..fb7784d --- /dev/null +++ b/testdata/purefun_05.lisp @@ -0,0 +1,5 @@ +(defpurefun ((eq :binary@loob) (x :binary) (y :binary)) (^ (- x y) 2)) +;; +(defcolumns (X :binary@loob) (Y :binary)) +;; X == 1 || X == Y +(defconstraint c1 () (* (- X 1) (eq X Y))) diff --git a/testdata/purefun_invalid_09.lisp b/testdata/purefun_invalid_09.lisp new file mode 100644 index 0000000..db54715 --- /dev/null +++ b/testdata/purefun_invalid_09.lisp @@ -0,0 +1,5 @@ +(defpurefun ((eq :binary@loob :force) (x :binary) (y :binary)) (^ (- x y) 2)) +;; +(defcolumns (X :binary@loob) Y (Z :i16)) +(defconstraint c1 () (* (- X 1) (eq X Y))) +(defconstraint c2 () (* (- X 1) (eq X 
Z)))