Skip to content

Commit

Permalink
Support function parameter types
Browse files Browse the repository at this point in the history
This puts in place support for function parameter types, including check
them at certain position.
  • Loading branch information
DavePearce committed Dec 18, 2024
1 parent 051971a commit b5ad5b5
Show file tree
Hide file tree
Showing 13 changed files with 290 additions and 62 deletions.
65 changes: 50 additions & 15 deletions pkg/corset/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,49 @@ type FunctionBinding interface {
// they can accept any number of arguments. In contrast, a user-defined
// function may only accept a specific number of arguments, etc.
HasArity(uint) bool
// Apply a set of concreate arguments to this function. This substitutes
// them through the body of the function producing a single expression.
Apply([]Expr) Expr
// Get the declared return type of this function, or nil if no return type
// was declared.
ReturnType() Type
// Select the best fit signature based on the available parameter types.
// Observe that, for valid arities, this always returns a signature.
// However, that signature may not actually accept the provided parameters
// (in which case, an error should be reported). Furthermore, if no
// appropriate signature exists then this will return nil.
Select([]Type) *FunctionSignature
}

// FunctionSignature embodies a concrete function instance. It is necessary to
// separate bindings from signatures because, in corset, function overloading is
// supported. That is, we can have different definitions for a function of the
// same name and arity. The appropriate definition is then selected for the
// given parameter types.
type FunctionSignature struct {
// Parameter types for this function
parameters []Type
// Return type for this function
ret Type
// Body of this function
body Expr
}

// Return the (optional) return type for this signature. If no declared return
// type is given, then the intention is that it be inferred from the body.
func (p *FunctionSignature) Return() Type {
return p.ret
}

// Parameter returns the given parameter in this signature.
func (p *FunctionSignature) Parameter(index uint) Type {
return p.parameters[index]
}

// Apply a set of concreate arguments to this function. This substitutes
// them through the body of the function producing a single expression.
func (p *FunctionSignature) Apply(args []Expr) Expr {
mapping := make(map[uint]Expr)
// Setup the mapping
for i, e := range args {
mapping[uint(i)] = e
}
// Substitute through
return p.body.Substitute(mapping)
}

// ============================================================================
Expand Down Expand Up @@ -231,13 +268,11 @@ func (p *DefunBinding) Finalise(bodyType Type) {
p.bodyType = bodyType
}

// Apply a given set of arguments to this function binding.
func (p *DefunBinding) Apply(args []Expr) Expr {
mapping := make(map[uint]Expr)
// Setup the mapping
for i, e := range args {
mapping[uint(i)] = e
}
// Substitute through
return p.body.Substitute(mapping)
// Select the best fit signature based on the available parameter types.
// Observe that, for valid arities, this always returns a signature.
// However, that signature may not actually accept the provided parameters
// (in which case, an error should be reported). Furthermore, if no
// appropriate signature exists then this will return nil.
func (p *DefunBinding) Select(args []Type) *FunctionSignature {
return &FunctionSignature{p.paramTypes, p.returnType, p.body}
}
36 changes: 23 additions & 13 deletions pkg/corset/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,23 +513,21 @@ func (e *If) Dependencies() []Symbol {

// Invoke represents an attempt to invoke a given function.
type Invoke struct {
fn *VariableAccess
args []Expr
fn *VariableAccess
signature *FunctionSignature
args []Expr
}

// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *Invoke) AsConstant() *big.Int {
if e.fn.binding == nil {
if e.signature == nil {
panic("unresolved invocation")
} else if fn_binding, ok := e.fn.binding.(FunctionBinding); ok {
// Unroll body
body := fn_binding.Apply(e.args)
// Attempt to evaluate as constant
return body.AsConstant()
}
// Just fail
return nil
// Unroll body
body := e.signature.Apply(e.args)
// Attempt to evaluate as constant
return body.AsConstant()
}

// Args returns the arguments provided by this invocation to the function being
Expand Down Expand Up @@ -558,10 +556,15 @@ func (e *Invoke) Lisp() sexp.SExp {
return ListOfExpressions(e.fn.Lisp(), e.args)
}

// Finalise the signature for this invocation.
func (e *Invoke) Finalise(signature *FunctionSignature) {
e.signature = signature
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Invoke) Substitute(mapping map[uint]Expr) Expr {
return &Invoke{e.fn, SubstituteExpressions(e.args, mapping)}
return &Invoke{e.fn, e.signature, SubstituteExpressions(e.args, mapping)}
}

// Dependencies needed to signal declaration.
Expand Down Expand Up @@ -713,8 +716,9 @@ func (e *Normalise) Dependencies() []Symbol {

// Reduce reduces (i.e. folds) a list using a given binary function.
type Reduce struct {
fn *VariableAccess
arg Expr
fn *VariableAccess
signature *FunctionSignature
arg Expr
}

// AsConstant attempts to evaluate this expression as a constant (signed) value.
Expand Down Expand Up @@ -751,10 +755,16 @@ func (e *Reduce) Lisp() sexp.SExp {
func (e *Reduce) Substitute(mapping map[uint]Expr) Expr {
return &Reduce{
e.fn,
e.signature,
e.arg.Substitute(mapping),
}
}

// Finalise the signature for this reduction.
func (e *Reduce) Finalise(signature *FunctionSignature) {
e.signature = signature
}

// Dependencies needed to signal declaration.
func (e *Reduce) Dependencies() []Symbol {
deps := e.arg.Dependencies()
Expand Down
17 changes: 17 additions & 0 deletions pkg/corset/intrinsics.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ func (p *IntrinsicDefinition) HasArity(arity uint) bool {
return arity >= p.min_arity && arity <= p.max_arity
}

// Select the best fit signature based on the available parameter types.
// Observe that, for valid arities, this always returns a signature.
// However, that signature may not actually accept the provided parameters
// (in which case, an error should be reported). Furthermore, if no
// appropriate signature exists then this will return nil.
func (p *IntrinsicDefinition) Select(args []Type) *FunctionSignature {
// construct the body
body := p.constructor(uint(len(args)))
types := make([]Type, len(args))
//
for i := 0; i < len(types); i++ {
types[i] = NewFieldType()
}
// Allow return type to be inferred.
return &FunctionSignature{types, nil, body}
}

// Apply a given set of arguments to this function binding.
func (p *IntrinsicDefinition) Apply(args []Expr) Expr {
// First construct the body
Expand Down
34 changes: 27 additions & 7 deletions pkg/corset/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -788,14 +788,32 @@ func (p *Parser) parseFunctionNameReturn(element sexp.SExp) (string, Type, bool,
}

func (p *Parser) parseFunctionParameter(element sexp.SExp) (*DefParameter, []SyntaxError) {
list := element.AsList()
//
if isIdentifier(element) {
binding := NewLocalVariableBinding(element.AsSymbol().Value, NewFieldType())
return &DefParameter{binding}, nil
} else if list == nil || list.Len() != 2 || !isIdentifier(list.Get(0)) {
// Construct error message (for now)
err := p.translator.SyntaxError(element, "malformed parameter declaration")
//
return nil, []SyntaxError{*err}
}
// Parse the type
datatype, prove, err := p.parseType(list.Get(1))
//
if err != nil {
return nil, []SyntaxError{*err}
} else if prove {
// Parameters cannot be marked @prove
err := p.translator.SyntaxError(element, "malformed parameter declaration")
//
return nil, []SyntaxError{*err}
}
// Construct error message (for now)
err := p.translator.SyntaxError(element, "malformed parameter declaration")
// Done
binding := NewLocalVariableBinding(list.Get(0).AsSymbol().Value, datatype)
//
return nil, []SyntaxError{*err}
return &DefParameter{binding}, nil
}

// Parse a range declaration
Expand Down Expand Up @@ -979,8 +997,10 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] {
if len(errors) > 0 {
return nil, errors
}
// Construct binding
binding := NewLocalVariableBinding(indexVar.Value, nil)
// Construct binding. At this stage, its unclear what the best type to
// use for the index variable is here. Potentially, it could be refined
// based on the range of actual values, etc.
binding := NewLocalVariableBinding(indexVar.Value, NewFieldType())
//
return &For{binding, start, end, body}, nil
}
Expand Down Expand Up @@ -1050,7 +1070,7 @@ func reduceParserRule(p *Parser) sexp.ListRule[Expr] {
varaccess := &VariableAccess{nil, name.Value, true, nil}
p.mapSourceNode(name, varaccess)
//
return &Reduce{varaccess, body}, nil
return &Reduce{varaccess, nil, body}, nil
}
}

Expand Down Expand Up @@ -1169,7 +1189,7 @@ func invokeParserRule(p *Parser) sexp.ListRule[Expr] {
//
p.mapSourceNode(list.Get(0), varaccess)
//
return &Invoke{varaccess, args}, nil
return &Invoke{varaccess, nil, args}, nil
}
}

Expand Down
85 changes: 62 additions & 23 deletions pkg/corset/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,42 +600,66 @@ func (r *resolver) finaliseIfInModule(scope LocalScope, expr *If) (Type, []Synta
// turn, is contained within some module. Note, qualified accesses are only
// permitted in a global context.
func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) (Type, []SyntaxError) {
var (
errors []SyntaxError
argTypes []Type
)
// Resolve arguments
if _, errors := r.finaliseExpressionsInModule(scope, expr.Args()); errors != nil {
return nil, errors
}
argTypes, errors = r.finaliseExpressionsInModule(scope, expr.Args())
// Lookup the corresponding function definition.
if !expr.fn.IsResolved() && !scope.Bind(expr.fn) {
return nil, r.srcmap.SyntaxErrors(expr, "unknown function")
return nil, append(errors, *r.srcmap.SyntaxError(expr, "unknown function"))
}
// Following must be true if we get here.
binding := expr.fn.binding.(FunctionBinding)

// Check purity
if scope.IsPure() && !binding.IsPure() {
return nil, r.srcmap.SyntaxErrors(expr, "not permitted in pure context")
} else if !binding.HasArity(uint(len(expr.Args()))) {
errors = append(errors, *r.srcmap.SyntaxError(expr, "not permitted in pure context"))
}
// Check provide correct number of arguments
if !binding.HasArity(uint(len(expr.Args()))) {
msg := fmt.Sprintf("incorrect number of arguments (found %d)", len(expr.Args()))
return nil, r.srcmap.SyntaxErrors(expr, msg)
}
// Check whether need to infer return type
if binding.ReturnType() != nil {
// no need, it was provided
return binding.ReturnType(), nil
}
// TODO: this is potentially expensive, and it would likely be good if we
// could avoid it. Realistically, this is just about determining the right
// type information. Potentially, we could adjust the local scope to
// provide the required type information. Or we could have a separate pass
// which just determines the type.
body := binding.Apply(expr.Args())
// Dig out the type
return r.finaliseExpressionInModule(scope, body)
errors = append(errors, *r.srcmap.SyntaxError(expr, msg))
}
// Select best overloaded signature
if signature := binding.Select(argTypes); signature != nil {
// Check arguments are accepted, based on their type.
for i := 0; i < len(argTypes); i++ {
expected := signature.Parameter(uint(i))
actual := argTypes[i]
// subtype check
if !actual.SubtypeOf(expected) {
msg := fmt.Sprintf("expected type %s (found %s)", expected, actual)
errors = append(errors, *r.srcmap.SyntaxError(expr.args[i], msg))
}
}
//
expr.Finalise(signature)
//
if len(errors) != 0 {
return nil, errors
} else if signature.Return() != nil {
// no need, it was provided
return signature.Return(), nil
}
// TODO: this is potentially expensive, and it would likely be good if we
// could avoid it. Realistically, this is just about determining the right
// type information. Potentially, we could adjust the local scope to
// provide the required type information. Or we could have a separate pass
// which just determines the type.
body := signature.Apply(expr.Args())
// Dig out the type
return r.finaliseExpressionInModule(scope, body)
}
// ambiguous invocation
return nil, append(errors, *r.srcmap.SyntaxError(expr, "ambiguous invocation"))
}

// Resolve a specific invocation contained within some expression which, in
// turn, is contained within some module. Note, qualified accesses are only
// permitted in a global context.
func (r *resolver) finaliseReduceInModule(scope LocalScope, expr *Reduce) (Type, []SyntaxError) {
var signature *FunctionSignature
// Resolve arguments
body_t, errors := r.finaliseExpressionInModule(scope, expr.arg)
// Lookup the corresponding function definition.
Expand All @@ -650,12 +674,27 @@ func (r *resolver) finaliseReduceInModule(scope LocalScope, expr *Reduce) (Type,
} else if !binding.HasArity(2) {
msg := "incorrect number of arguments (expected 2)"
errors = append(errors, *r.srcmap.SyntaxError(expr, msg))
} else if signature = binding.Select([]Type{body_t, body_t}); signature == nil {
msg := "ambiguous reduction"
errors = append(errors, *r.srcmap.SyntaxError(expr, msg))
}
// Check left parameter type
if !body_t.SubtypeOf(signature.Parameter(0)) {
msg := fmt.Sprintf("expected type %s (found %s)", signature.Parameter(0), body_t)
errors = append(errors, *r.srcmap.SyntaxError(expr.arg, msg))
}
// Check right parameter type
if !body_t.SubtypeOf(signature.Parameter(1)) {
msg := fmt.Sprintf("expected type %s (found %s)", signature.Parameter(1), body_t)
errors = append(errors, *r.srcmap.SyntaxError(expr.arg, msg))
}
}
// Error check
if len(errors) > 0 {
return nil, errors
}
// Lock in signature
expr.Finalise(signature)
//
return body_t, nil
}
Expand Down Expand Up @@ -691,7 +730,7 @@ func (r *resolver) finaliseVariableInModule(scope LocalScope,
// Constant
return binding.datatype, nil
} else if binding, ok := expr.Binding().(*LocalVariableBinding); ok {
// Parameter
// Parameter or other local variable
return binding.datatype, nil
} else if _, ok := expr.Binding().(FunctionBinding); ok {
// Function doesn't makes sense here.
Expand Down
8 changes: 4 additions & 4 deletions pkg/corset/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,8 @@ func (t *translator) translateForInModule(expr *For, module string, shift int) (
}

func (t *translator) translateInvokeInModule(expr *Invoke, module string, shift int) (hir.Expr, []SyntaxError) {
if binding, ok := expr.fn.Binding().(FunctionBinding); ok {
body := binding.Apply(expr.Args())
if expr.signature != nil {
body := expr.signature.Apply(expr.Args())
return t.translateExpressionInModule(body, module, shift)
}
//
Expand All @@ -598,13 +598,13 @@ func (t *translator) translateInvokeInModule(expr *Invoke, module string, shift
func (t *translator) translateReduceInModule(expr *Reduce, module string, shift int) (hir.Expr, []SyntaxError) {
if list, ok := expr.arg.(*List); !ok {
return nil, t.srcmap.SyntaxErrors(expr.arg, "expected list")
} else if binding, ok := expr.fn.Binding().(FunctionBinding); !ok {
} else if sig := expr.signature; sig == nil {
return nil, t.srcmap.SyntaxErrors(expr.arg, "unbound function")
} else {
reduction := list.Args[0]
// Build reduction
for i := 1; i < len(list.Args); i++ {
reduction = binding.Apply([]Expr{reduction, list.Args[i]})
reduction = sig.Apply([]Expr{reduction, list.Args[i]})
}
// Translate reduction
return t.translateExpressionInModule(reduction, module, shift)
Expand Down
Loading

0 comments on commit b5ad5b5

Please sign in to comment.