diff --git a/.eslintignore b/.eslintignore index ebdeff259..ba91e95c6 100644 --- a/.eslintignore +++ b/.eslintignore @@ -4,8 +4,8 @@ coverage/ **/output/ src/test/**/output/ src/func/funcfiftlib.js -src/grammar/grammar.ohm*.ts -src/grammar/grammar.ohm*.js +**/grammar.ohm*.ts +**/grammar.ohm*.js jest.setup.js jest.teardown.js /docs diff --git a/.gitignore b/.gitignore index 997174b31..3152343ad 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ node_modules coverage dist output/ -src/grammar/grammar.ohm-bundle.js -src/grammar/grammar.ohm-bundle.d.ts +**/grammar.ohm-bundle.js +**/grammar.ohm-bundle.d.ts src/func/funcfiftlib.wasm.js src/test/contracts/pretty-printer-output diff --git a/src/abi/AbiFunction.ts b/src/abi/AbiFunction.ts index d043df202..00faa5f73 100644 --- a/src/abi/AbiFunction.ts +++ b/src/abi/AbiFunction.ts @@ -1,7 +1,8 @@ -import { AstExpression, SrcInfo } from "../grammar/ast"; +import { AstExpression } from "../grammar/ast"; import { CompilerContext } from "../context"; import { WriterContext } from "../generator/Writer"; import { TypeRef } from "../types/types"; +import { SrcInfo } from "../grammar"; export type AbiFunction = { name: string; diff --git a/src/abi/map.ts b/src/abi/map.ts index 23370bcb1..0a8c0ea3a 100644 --- a/src/abi/map.ts +++ b/src/abi/map.ts @@ -1,5 +1,5 @@ import { CompilerContext } from "../context"; -import { SrcInfo } from "../grammar/grammar"; +import { SrcInfo } from "../grammar"; import { TypeRef } from "../types/types"; import { WriterContext } from "../generator/Writer"; import { ops } from "../generator/writers/ops"; diff --git a/src/check.ts b/src/check.ts index b26b2b912..5bc55e948 100644 --- a/src/check.ts +++ b/src/check.ts @@ -1,5 +1,7 @@ import { featureEnable } from "./config/features"; import { CompilerContext } from "./context"; +import { getAstFactory } from "./grammar/ast"; +import { getParser } from "./grammar"; import files from "./imports/stdlib"; import { createVirtualFileSystem, TactError, VirtualFileSystem } from "./main"; import { precompile } from "./pipeline/precompile"; @@ -35,10 +37,13 @@ export function check(args: { ctx = featureEnable(ctx, "masterchain"); // Enable masterchain flag to avoid masterchain-specific errors ctx = featureEnable(ctx, "external"); // Enable external messages flag to avoid external-specific errors + const ast = getAstFactory(); + const parser = getParser(ast); + // Execute check const items: CheckResultItem[] = []; try { - precompile(ctx, args.project, stdlib, args.entrypoint); + precompile(ctx, args.project, stdlib, args.entrypoint, parser, ast); } catch (e) { if (e instanceof TactError) { items.push({ diff --git a/src/constEval.ts b/src/constEval.ts index 16e509da2..630c901a7 100644 --- a/src/constEval.ts +++ b/src/constEval.ts @@ -2,19 +2,13 @@ import { CompilerContext } from "./context"; import { AstBinaryOperation, AstExpression, - SrcInfo, AstUnaryOperation, AstValue, isValue, } from "./grammar/ast"; import { TactConstEvalError } from "./errors"; import { Value } from "./types/types"; -import { - extractValue, - makeValueExpression, - makeUnaryExpression, - makeBinaryExpression, -} from "./optimizer/util"; +import { AstUtil, extractValue } from "./optimizer/util"; import { ExpressionTransformer } from "./optimizer/types"; import { StandardOptimizer } from "./optimizer/standardOptimizer"; import { @@ -25,6 +19,7 @@ import { evalUnaryOp, throwNonFatalErrorConstEval, } from "./interpreter"; +import { SrcInfo } from "./grammar"; // Utility Exception class to interrupt the execution // of functions that cannot evaluate a tree fully into a value. @@ -37,93 +32,194 @@ class PartiallyEvaluatedTree extends Error { } } -// The optimizer that applies the rewriting rules during partial evaluation. -// For the moment we use an optimizer that respects overflows. -const optimizer: ExpressionTransformer = new StandardOptimizer(); +export const getOptimizer = (util: AstUtil) => { + // The optimizer that applies the rewriting rules during partial evaluation. + // For the moment we use an optimizer that respects overflows. + const optimizer: ExpressionTransformer = new StandardOptimizer(util); -function partiallyEvalUnaryOp( - op: AstUnaryOperation, - operand: AstExpression, - source: SrcInfo, - ctx: CompilerContext, -): AstExpression { - if (operand.kind === "number" && op === "-") { - // emulating negative integer literals - return makeValueExpression(ensureInt(-operand.value, source)); - } + function partiallyEvalUnaryOp( + op: AstUnaryOperation, + operand: AstExpression, + source: SrcInfo, + ctx: CompilerContext, + ): AstExpression { + if (operand.kind === "number" && op === "-") { + // emulating negative integer literals + return util.makeValueExpression(ensureInt(-operand.value, source)); + } - const simplOperand = partiallyEvalExpression(operand, ctx); + const simplOperand = partiallyEvalExpression(operand, ctx); - if (isValue(simplOperand)) { - const valueOperand = extractValue(simplOperand as AstValue); - const result = evalUnaryOp(op, valueOperand, simplOperand.loc, source); - // Wrap the value into a Tree to continue simplifications - return makeValueExpression(result); - } else { - const newAst = makeUnaryExpression(op, simplOperand); - return optimizer.applyRules(newAst); + if (isValue(simplOperand)) { + const valueOperand = extractValue(simplOperand as AstValue); + const result = evalUnaryOp( + op, + valueOperand, + simplOperand.loc, + source, + ); + // Wrap the value into a Tree to continue simplifications + return util.makeValueExpression(result); + } else { + const newAst = util.makeUnaryExpression(op, simplOperand); + return optimizer.applyRules(newAst); + } } -} -function partiallyEvalBinaryOp( - op: AstBinaryOperation, - left: AstExpression, - right: AstExpression, - source: SrcInfo, - ctx: CompilerContext, -): AstExpression { - const leftOperand = partiallyEvalExpression(left, ctx); + function partiallyEvalBinaryOp( + op: AstBinaryOperation, + left: AstExpression, + right: AstExpression, + source: SrcInfo, + ctx: CompilerContext, + ): AstExpression { + const leftOperand = partiallyEvalExpression(left, ctx); - if (isValue(leftOperand)) { - // Because of short-circuiting, we must delay evaluation of the right operand - const valueLeftOperand = extractValue(leftOperand as AstValue); + if (isValue(leftOperand)) { + // Because of short-circuiting, we must delay evaluation of the right operand + const valueLeftOperand = extractValue(leftOperand as AstValue); - try { - const result = evalBinaryOp( + try { + const result = evalBinaryOp( + op, + valueLeftOperand, + // We delay the evaluation of the right operand inside a continuation + () => { + const rightOperand = partiallyEvalExpression( + right, + ctx, + ); + if (isValue(rightOperand)) { + // If the right operand reduces to a value, then we can let the function + // evalBinaryOp finish its normal execution by returning the value + // in the right operand. + return extractValue(rightOperand as AstValue); + } else { + // If the right operand does not reduce to a value, + // we interrupt the execution of the evalBinaryOp function + // by returning an exception with the partially evaluated right operand. + // The simplification rules will handle the partially evaluated tree in the catch + // of the try surrounding the evalBinaryOp function. + throw new PartiallyEvaluatedTree(rightOperand); + } + }, + leftOperand.loc, + right.loc, + source, + ); + + return util.makeValueExpression(result); + } catch (e) { + if (e instanceof PartiallyEvaluatedTree) { + // The right operand did not evaluate to a value. Hence, + // time to symbolically simplify the full tree. + const newAst = util.makeBinaryExpression( + op, + leftOperand, + e.tree, + ); + return optimizer.applyRules(newAst); + } else { + throw e; + } + } + } else { + // Since the left operand does not reduce to a value, no immediate short-circuiting will occur. + // Hence, we can partially evaluate the right operand and let the rules + // simplify the tree. + const rightOperand = partiallyEvalExpression(right, ctx); + const newAst = util.makeBinaryExpression( op, - valueLeftOperand, - // We delay the evaluation of the right operand inside a continuation - () => { - const rightOperand = partiallyEvalExpression(right, ctx); - if (isValue(rightOperand)) { - // If the right operand reduces to a value, then we can let the function - // evalBinaryOp finish its normal execution by returning the value - // in the right operand. - return extractValue(rightOperand as AstValue); - } else { - // If the right operand does not reduce to a value, - // we interrupt the execution of the evalBinaryOp function - // by returning an exception with the partially evaluated right operand. - // The simplification rules will handle the partially evaluated tree in the catch - // of the try surrounding the evalBinaryOp function. - throw new PartiallyEvaluatedTree(rightOperand); - } - }, - leftOperand.loc, - right.loc, - source, + leftOperand, + rightOperand, ); + return optimizer.applyRules(newAst); + } + } - return makeValueExpression(result); - } catch (e) { - if (e instanceof PartiallyEvaluatedTree) { - // The right operand did not evaluate to a value. Hence, - // time to symbolically simplify the full tree. - const newAst = makeBinaryExpression(op, leftOperand, e.tree); - return optimizer.applyRules(newAst); - } else { - throw e; - } + function partiallyEvalExpression( + ast: AstExpression, + ctx: CompilerContext, + interpreterConfig?: InterpreterConfig, + ): AstExpression { + const interpreter = new Interpreter(ctx, interpreterConfig); + switch (ast.kind) { + case "id": + try { + return util.makeValueExpression( + interpreter.interpretName(ast), + ); + } catch (e) { + if (e instanceof TactConstEvalError) { + if (!e.fatal) { + // If a non-fatal error occurs during lookup, just return the symbol + return ast; + } + } + throw e; + } + case "method_call": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return util.makeValueExpression( + interpreter.interpretMethodCall(ast), + ); + case "init_of": + throwNonFatalErrorConstEval( + "initOf is not supported at this moment", + ast.loc, + ); + break; + case "null": + return ast; + case "boolean": + return ast; + case "number": + return util.makeValueExpression( + interpreter.interpretNumber(ast), + ); + case "string": + return util.makeValueExpression( + interpreter.interpretString(ast), + ); + case "op_unary": + return partiallyEvalUnaryOp(ast.op, ast.operand, ast.loc, ctx); + case "op_binary": + return partiallyEvalBinaryOp( + ast.op, + ast.left, + ast.right, + ast.loc, + ctx, + ); + case "conditional": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return util.makeValueExpression( + interpreter.interpretConditional(ast), + ); + case "struct_instance": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return util.makeValueExpression( + interpreter.interpretStructInstance(ast), + ); + case "field_access": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return util.makeValueExpression( + interpreter.interpretFieldAccess(ast), + ); + case "static_call": + // Does not partially evaluate at the moment. Will attempt to fully evaluate + return util.makeValueExpression( + interpreter.interpretStaticCall(ast), + ); } - } else { - // Since the left operand does not reduce to a value, no immediate short-circuiting will occur. - // Hence, we can partially evaluate the right operand and let the rules - // simplify the tree. - const rightOperand = partiallyEvalExpression(right, ctx); - const newAst = makeBinaryExpression(op, leftOperand, rightOperand); - return optimizer.applyRules(newAst); } -} + + return { + partiallyEvalUnaryOp, + partiallyEvalBinaryOp, + partiallyEvalExpression, + }; +}; export function evalConstantExpression( ast: AstExpression, @@ -134,66 +230,3 @@ export function evalConstantExpression( const result = interpreter.interpretExpression(ast); return result; } - -export function partiallyEvalExpression( - ast: AstExpression, - ctx: CompilerContext, - interpreterConfig?: InterpreterConfig, -): AstExpression { - const interpreter = new Interpreter(ctx, interpreterConfig); - switch (ast.kind) { - case "id": - try { - return makeValueExpression(interpreter.interpretName(ast)); - } catch (e) { - if (e instanceof TactConstEvalError) { - if (!e.fatal) { - // If a non-fatal error occurs during lookup, just return the symbol - return ast; - } - } - throw e; - } - case "method_call": - // Does not partially evaluate at the moment. Will attempt to fully evaluate - return makeValueExpression(interpreter.interpretMethodCall(ast)); - case "init_of": - throwNonFatalErrorConstEval( - "initOf is not supported at this moment", - ast.loc, - ); - break; - case "null": - return ast; - case "boolean": - return ast; - case "number": - return makeValueExpression(interpreter.interpretNumber(ast)); - case "string": - return makeValueExpression(interpreter.interpretString(ast)); - case "op_unary": - return partiallyEvalUnaryOp(ast.op, ast.operand, ast.loc, ctx); - case "op_binary": - return partiallyEvalBinaryOp( - ast.op, - ast.left, - ast.right, - ast.loc, - ctx, - ); - case "conditional": - // Does not partially evaluate at the moment. Will attempt to fully evaluate - return makeValueExpression(interpreter.interpretConditional(ast)); - case "struct_instance": - // Does not partially evaluate at the moment. Will attempt to fully evaluate - return makeValueExpression( - interpreter.interpretStructInstance(ast), - ); - case "field_access": - // Does not partially evaluate at the moment. Will attempt to fully evaluate - return makeValueExpression(interpreter.interpretFieldAccess(ast)); - case "static_call": - // Does not partially evaluate at the moment. Will attempt to fully evaluate - return makeValueExpression(interpreter.interpretStaticCall(ast)); - } -} diff --git a/src/errors.ts b/src/errors.ts index 0dd9ce313..eee766790 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -1,8 +1,8 @@ import { MatchResult } from "ohm-js"; import path from "path"; import { cwd } from "process"; -import { AstFuncId, AstId, AstTypeId, SrcInfo } from "./grammar/ast"; -import { ItemOrigin } from "./grammar/grammar"; +import { AstFuncId, AstId, AstTypeId } from "./grammar/ast"; +import { ItemOrigin, SrcInfo } from "./grammar"; export class TactError extends Error { readonly loc?: SrcInfo; diff --git a/src/generator/writers/resolveFuncType.spec.ts b/src/generator/writers/resolveFuncType.spec.ts index 362f00667..e30743422 100644 --- a/src/generator/writers/resolveFuncType.spec.ts +++ b/src/generator/writers/resolveFuncType.spec.ts @@ -1,9 +1,10 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; +import { getAstFactory } from "../../grammar/ast"; import { resolveDescriptors } from "../../types/resolveDescriptors"; import { WriterContext } from "../Writer"; import { resolveFuncType } from "./resolveFuncType"; import { openContext } from "../../grammar/store"; import { CompilerContext } from "../../context"; +import { getParser } from "../../grammar"; const primitiveCode = ` primitive Int; @@ -45,17 +46,15 @@ contract Contract2 { `; describe("resolveFuncType", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - it("should process primitive types", () => { + const ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code: primitiveCode, path: "", origin: "user" }], [], + getParser(ast), ); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, ast); const wCtx = new WriterContext(ctx, "Contract1"); expect( resolveFuncType( @@ -117,12 +116,14 @@ describe("resolveFuncType", () => { }); it("should process contract and struct types", () => { + const ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code: primitiveCode, path: "", origin: "user" }], [], + getParser(ast), ); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, ast); const wCtx = new WriterContext(ctx, "Contract1"); expect( resolveFuncType( diff --git a/src/generator/writers/writeAccessors.ts b/src/generator/writers/writeAccessors.ts index df210313c..191235f59 100644 --- a/src/generator/writers/writeAccessors.ts +++ b/src/generator/writers/writeAccessors.ts @@ -1,6 +1,6 @@ import { contractErrors } from "../../abi/errors"; import { maxTupleSize } from "../../bindings/typescript/writeStruct"; -import { ItemOrigin } from "../../grammar/grammar"; +import { ItemOrigin } from "../../grammar"; import { getType } from "../../types/resolveDescriptors"; import { TypeDescription } from "../../types/types"; import { WriterContext } from "../Writer"; diff --git a/src/generator/writers/writeContract.ts b/src/generator/writers/writeContract.ts index 642682052..4cf60f77f 100644 --- a/src/generator/writers/writeContract.ts +++ b/src/generator/writers/writeContract.ts @@ -5,7 +5,7 @@ import { enabledIpfsAbiGetter, enabledMasterchain, } from "../../config/features"; -import { ItemOrigin } from "../../grammar/grammar"; +import { ItemOrigin } from "../../grammar"; import { InitDescription, TypeDescription } from "../../types/types"; import { WriterContext } from "../Writer"; import { funcIdOf, funcInitIdOf } from "./id"; diff --git a/src/generator/writers/writeExpression.spec.ts b/src/generator/writers/writeExpression.spec.ts index 19b9bd17f..4409dc8db 100644 --- a/src/generator/writers/writeExpression.spec.ts +++ b/src/generator/writers/writeExpression.spec.ts @@ -1,4 +1,3 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { getStaticFunction, resolveDescriptors, @@ -8,6 +7,8 @@ import { writeExpression } from "./writeExpression"; import { openContext } from "../../grammar/store"; import { resolveStatements } from "../../types/resolveStatements"; import { CompilerContext } from "../../context"; +import { getParser } from "../../grammar"; +import { getAstFactory } from "../../grammar/ast"; const code = ` @@ -68,16 +69,15 @@ const golden: string[] = [ ]; describe("writeExpression", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); it("should write expression", () => { + const ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code: code, path: "", origin: "user" }], [], + getParser(ast), ); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, ast); ctx = resolveStatements(ctx); const main = getStaticFunction(ctx, "main"); if (main.ast.kind !== "function_def") { diff --git a/src/generator/writers/writeSerialization.spec.ts b/src/generator/writers/writeSerialization.spec.ts index cafbd3307..316c88160 100644 --- a/src/generator/writers/writeSerialization.spec.ts +++ b/src/generator/writers/writeSerialization.spec.ts @@ -1,4 +1,3 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { CompilerContext } from "../../context"; import { getAllocation, @@ -14,6 +13,8 @@ import { writeParser, writeSerializer } from "./writeSerialization"; import { writeStdlib } from "./writeStdlib"; import { openContext } from "../../grammar/store"; import { writeAccessors } from "./writeAccessors"; +import { getParser } from "../../grammar"; +import { getAstFactory } from "../../grammar/ast"; const code = ` primitive Int; @@ -56,17 +57,16 @@ struct C { `; describe("writeSerialization", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); for (const s of ["A", "B", "C"]) { it("should write serializer for " + s, () => { + const ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code, path: "", origin: "user" }], [], + getParser(ast), ); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, ast); ctx = resolveAllocations(ctx); const wCtx = new WriterContext(ctx, s); writeStdlib(wCtx); diff --git a/src/generator/writers/writeSerialization.ts b/src/generator/writers/writeSerialization.ts index aca186622..a51bfba09 100644 --- a/src/generator/writers/writeSerialization.ts +++ b/src/generator/writers/writeSerialization.ts @@ -1,6 +1,6 @@ import { contractErrors } from "../../abi/errors"; import { throwInternalCompilerError } from "../../errors"; -import { dummySrcInfo, ItemOrigin } from "../../grammar/grammar"; +import { dummySrcInfo, ItemOrigin } from "../../grammar"; import { AllocationCell, AllocationOperation } from "../../storage/operation"; import { StorageAllocation } from "../../storage/StorageAllocation"; import { getType } from "../../types/resolveDescriptors"; diff --git a/src/grammar/ast.ts b/src/grammar/ast.ts index 7d95641e8..f9d54810b 100644 --- a/src/grammar/ast.ts +++ b/src/grammar/ast.ts @@ -1,4 +1,5 @@ -import { dummySrcInfo, SrcInfo } from "./grammar"; +import { dummySrcInfo } from "./grammar"; +import { SrcInfo } from "./src-info"; export type AstModule = { kind: "module"; @@ -758,17 +759,22 @@ export function tryExtractPath(path: AstExpression): AstId[] | null { type DistributiveOmit = T extends any ? Omit : never; -let nextId = 1; -export function createAstNode(src: DistributiveOmit): AstNode { - return Object.freeze(Object.assign({ id: nextId++ }, src)); -} -export function cloneAstNode(src: T): T { - return { ...src, id: nextId++ }; -} -export function __DANGER_resetNodeId() { - nextId = 1; -} +export const getAstFactory = () => { + let nextId = 1; + function createNode(src: DistributiveOmit): AstNode { + return Object.freeze(Object.assign({ id: nextId++ }, src)); + } + function cloneNode(src: T): T { + return { ...src, id: nextId++ }; + } + return { + createNode, + cloneNode, + }; +}; + +export type FactoryAst = ReturnType; // Test equality of AstExpressions. export function eqExpressions( @@ -914,5 +920,3 @@ export function isValue(ast: AstExpression): boolean { return false; } } - -export { SrcInfo }; diff --git a/src/grammar/checkConstAttributes.ts b/src/grammar/checkConstAttributes.ts index 3346a9289..b7fdee28b 100644 --- a/src/grammar/checkConstAttributes.ts +++ b/src/grammar/checkConstAttributes.ts @@ -1,5 +1,6 @@ -import { AstConstantAttribute, SrcInfo } from "./ast"; +import { AstConstantAttribute } from "./ast"; import { throwSyntaxError } from "../errors"; +import { SrcInfo } from "./src-info"; export function checkConstAttributes( isAbstract: boolean, diff --git a/src/grammar/checkFunctionAttributes.ts b/src/grammar/checkFunctionAttributes.ts index dce102448..3b163fdfa 100644 --- a/src/grammar/checkFunctionAttributes.ts +++ b/src/grammar/checkFunctionAttributes.ts @@ -1,5 +1,6 @@ -import { AstFunctionAttribute, SrcInfo } from "./ast"; +import { AstFunctionAttribute } from "./ast"; import { throwCompilationError } from "../errors"; +import { SrcInfo } from "./src-info"; export function checkFunctionAttributes( isAbstract: boolean, diff --git a/src/grammar/checkVariableName.ts b/src/grammar/checkVariableName.ts index 3ba98ac10..faef19733 100644 --- a/src/grammar/checkVariableName.ts +++ b/src/grammar/checkVariableName.ts @@ -1,5 +1,5 @@ -import { SrcInfo } from "./ast"; import { throwCompilationError } from "../errors"; +import { SrcInfo } from "./src-info"; export function checkVariableName(name: string, loc: SrcInfo) { if (name.startsWith("__gen")) { diff --git a/src/grammar/clone.ts b/src/grammar/clone.ts index a63a8f25d..48e8a8a79 100644 --- a/src/grammar/clone.ts +++ b/src/grammar/clone.ts @@ -1,180 +1,187 @@ -import { AstNode, cloneAstNode } from "./ast"; +import { AstNode, FactoryAst } from "./ast"; import { throwInternalCompilerError } from "../errors"; -export function cloneNode(src: T): T { - if (src.kind === "boolean") { - return cloneAstNode(src); - } else if (src.kind === "id") { - return cloneAstNode(src); - } else if (src.kind === "null") { - return cloneAstNode(src); - } else if (src.kind === "number") { - return cloneAstNode(src); - } else if (src.kind === "string") { - return cloneAstNode(src); - } else if (src.kind === "statement_assign") { - return cloneAstNode({ - ...src, - path: cloneNode(src.path), - expression: cloneNode(src.expression), - }); - } else if (src.kind === "statement_augmentedassign") { - return cloneAstNode({ - ...src, - path: cloneNode(src.path), - expression: cloneNode(src.expression), - }); - } else if (src.kind === "statement_let") { - return cloneAstNode({ - ...src, - type: src.type ? cloneAstNode(src.type) : null, - expression: cloneNode(src.expression), - }); - } else if (src.kind === "statement_condition") { - return cloneAstNode({ - ...src, - condition: cloneNode(src.condition), - trueStatements: src.trueStatements.map(cloneNode), - falseStatements: src.falseStatements - ? src.falseStatements.map(cloneNode) - : null, - elseif: src.elseif ? cloneNode(src.elseif) : null, - }); - } else if (src.kind === "struct_field_initializer") { - return cloneAstNode({ - ...src, - initializer: cloneNode(src.initializer), - }); - } else if (src.kind === "statement_expression") { - return cloneAstNode({ - ...src, - expression: cloneNode(src.expression), - }); - } else if (src.kind === "op_binary") { - return cloneAstNode({ - ...src, - left: cloneNode(src.left), - right: cloneNode(src.right), - }); - } else if (src.kind === "op_unary") { - return cloneAstNode({ - ...src, - operand: cloneNode(src.operand), - }); - } else if (src.kind === "struct_instance") { - return cloneAstNode({ - ...src, - args: src.args.map(cloneNode), - }); - } else if (src.kind === "method_call") { - return cloneAstNode({ - ...src, - self: cloneNode(src.self), - args: src.args.map(cloneNode), - }); - } else if (src.kind === "field_access") { - return cloneAstNode({ - ...src, - aggregate: cloneNode(src.aggregate), - }); - } else if (src.kind === "static_call") { - return cloneAstNode({ - ...src, - args: src.args.map(cloneNode), - }); - } else if (src.kind === "conditional") { - return cloneAstNode({ - ...src, - condition: cloneNode(src.condition), - thenBranch: cloneNode(src.thenBranch), - elseBranch: cloneNode(src.elseBranch), - }); - } else if (src.kind === "statement_return") { - return cloneAstNode({ - ...src, - expression: src.expression ? cloneNode(src.expression) : null, - }); - } else if (src.kind === "statement_repeat") { - return cloneAstNode({ - ...src, - iterations: cloneNode(src.iterations), - statements: src.statements.map(cloneNode), - }); - } else if (src.kind === "statement_until") { - return cloneAstNode({ - ...src, - condition: cloneNode(src.condition), - statements: src.statements.map(cloneNode), - }); - } else if (src.kind === "statement_while") { - return cloneAstNode({ - ...src, - condition: cloneNode(src.condition), - statements: src.statements.map(cloneNode), - }); - } else if (src.kind === "statement_try") { - return cloneAstNode({ - ...src, - statements: src.statements.map(cloneNode), - }); - } else if (src.kind === "statement_try_catch") { - return cloneAstNode({ - ...src, - statements: src.statements.map(cloneNode), - catchStatements: src.catchStatements.map(cloneNode), - }); - } else if (src.kind === "statement_foreach") { - return cloneAstNode({ - ...src, - map: cloneNode(src.map), - statements: src.statements.map(cloneNode), - }); - } else if (src.kind === "function_def") { - return cloneAstNode({ - ...src, - return: src.return ? cloneAstNode(src.return) : null, - statements: src.statements.map(cloneNode), - params: src.params.map(cloneNode), - }); - } else if (src.kind === "function_decl") { - return cloneAstNode({ - ...src, - return: src.return ? cloneAstNode(src.return) : null, - params: src.params.map(cloneNode), - }); - } else if (src.kind === "native_function_decl") { - return cloneAstNode({ - ...src, - return: src.return ? cloneAstNode(src.return) : null, - params: src.params.map(cloneNode), - }); - } else if (src.kind === "receiver") { - return cloneAstNode({ - ...src, - statements: src.statements.map(cloneNode), - }); - } else if (src.kind === "typed_parameter") { - return cloneAstNode({ - ...src, - type: cloneAstNode(src.type), - }); - } else if (src.kind === "init_of") { - return cloneAstNode({ - ...src, - args: src.args.map(cloneNode), - }); - } else if (src.kind === "constant_def") { - return cloneAstNode({ - ...src, - type: cloneAstNode(src.type), - initializer: cloneNode(src.initializer), - }); - } else if (src.kind === "constant_decl") { - return cloneAstNode({ - ...src, - type: cloneAstNode(src.type), - }); - } +export function cloneNode( + src: T, + { cloneNode }: FactoryAst, +): T { + const recurse = (src: T): T => { + if (src.kind === "boolean") { + return cloneNode(src); + } else if (src.kind === "id") { + return cloneNode(src); + } else if (src.kind === "null") { + return cloneNode(src); + } else if (src.kind === "number") { + return cloneNode(src); + } else if (src.kind === "string") { + return cloneNode(src); + } else if (src.kind === "statement_assign") { + return cloneNode({ + ...src, + path: recurse(src.path), + expression: recurse(src.expression), + }); + } else if (src.kind === "statement_augmentedassign") { + return cloneNode({ + ...src, + path: recurse(src.path), + expression: recurse(src.expression), + }); + } else if (src.kind === "statement_let") { + return cloneNode({ + ...src, + type: src.type ? cloneNode(src.type) : null, + expression: recurse(src.expression), + }); + } else if (src.kind === "statement_condition") { + return cloneNode({ + ...src, + condition: recurse(src.condition), + trueStatements: src.trueStatements.map(recurse), + falseStatements: src.falseStatements + ? src.falseStatements.map(recurse) + : null, + elseif: src.elseif ? recurse(src.elseif) : null, + }); + } else if (src.kind === "struct_field_initializer") { + return cloneNode({ + ...src, + initializer: recurse(src.initializer), + }); + } else if (src.kind === "statement_expression") { + return cloneNode({ + ...src, + expression: recurse(src.expression), + }); + } else if (src.kind === "op_binary") { + return cloneNode({ + ...src, + left: recurse(src.left), + right: recurse(src.right), + }); + } else if (src.kind === "op_unary") { + return cloneNode({ + ...src, + operand: recurse(src.operand), + }); + } else if (src.kind === "struct_instance") { + return cloneNode({ + ...src, + args: src.args.map(recurse), + }); + } else if (src.kind === "method_call") { + return cloneNode({ + ...src, + self: recurse(src.self), + args: src.args.map(recurse), + }); + } else if (src.kind === "field_access") { + return cloneNode({ + ...src, + aggregate: recurse(src.aggregate), + }); + } else if (src.kind === "static_call") { + return cloneNode({ + ...src, + args: src.args.map(recurse), + }); + } else if (src.kind === "conditional") { + return cloneNode({ + ...src, + condition: recurse(src.condition), + thenBranch: recurse(src.thenBranch), + elseBranch: recurse(src.elseBranch), + }); + } else if (src.kind === "statement_return") { + return cloneNode({ + ...src, + expression: src.expression ? recurse(src.expression) : null, + }); + } else if (src.kind === "statement_repeat") { + return cloneNode({ + ...src, + iterations: recurse(src.iterations), + statements: src.statements.map(recurse), + }); + } else if (src.kind === "statement_until") { + return cloneNode({ + ...src, + condition: recurse(src.condition), + statements: src.statements.map(recurse), + }); + } else if (src.kind === "statement_while") { + return cloneNode({ + ...src, + condition: recurse(src.condition), + statements: src.statements.map(recurse), + }); + } else if (src.kind === "statement_try") { + return cloneNode({ + ...src, + statements: src.statements.map(recurse), + }); + } else if (src.kind === "statement_try_catch") { + return cloneNode({ + ...src, + statements: src.statements.map(recurse), + catchStatements: src.catchStatements.map(recurse), + }); + } else if (src.kind === "statement_foreach") { + return cloneNode({ + ...src, + map: recurse(src.map), + statements: src.statements.map(recurse), + }); + } else if (src.kind === "function_def") { + return cloneNode({ + ...src, + return: src.return ? cloneNode(src.return) : null, + statements: src.statements.map(recurse), + params: src.params.map(recurse), + }); + } else if (src.kind === "function_decl") { + return cloneNode({ + ...src, + return: src.return ? cloneNode(src.return) : null, + params: src.params.map(recurse), + }); + } else if (src.kind === "native_function_decl") { + return cloneNode({ + ...src, + return: src.return ? cloneNode(src.return) : null, + params: src.params.map(recurse), + }); + } else if (src.kind === "receiver") { + return cloneNode({ + ...src, + statements: src.statements.map(recurse), + }); + } else if (src.kind === "typed_parameter") { + return cloneNode({ + ...src, + type: cloneNode(src.type), + }); + } else if (src.kind === "init_of") { + return cloneNode({ + ...src, + args: src.args.map(recurse), + }); + } else if (src.kind === "constant_def") { + return cloneNode({ + ...src, + type: cloneNode(src.type), + initializer: recurse(src.initializer), + }); + } else if (src.kind === "constant_decl") { + return cloneNode({ + ...src, + type: cloneNode(src.type), + }); + } - throwInternalCompilerError(`Not implemented for ${src.kind}`); + throwInternalCompilerError(`Not implemented for ${src.kind}`); + }; + + return recurse(src); } diff --git a/src/grammar/grammar.spec.ts b/src/grammar/grammar.spec.ts index 31a5bdadf..7af9c64dc 100644 --- a/src/grammar/grammar.spec.ts +++ b/src/grammar/grammar.spec.ts @@ -1,6 +1,7 @@ -import { parse } from "./grammar"; -import { AstModule, SrcInfo, __DANGER_resetNodeId } from "./ast"; +import { AstModule, getAstFactory } from "./ast"; import { loadCases } from "../utils/loadCases"; +import { getParser } from "./grammar"; +import { SrcInfo } from "./src-info"; expect.addSnapshotSerializer({ test: (src) => src instanceof SrcInfo, @@ -8,13 +9,11 @@ expect.addSnapshotSerializer({ }); describe("grammar", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - // Test parsing of known Fift projects, wrapped in asm functions of Tact for (const r of loadCases(__dirname + "/test-asm/")) { it("should parse " + r.name, () => { + const ast = getAstFactory(); + const { parse } = getParser(ast); const parsed: AstModule | undefined = parse( r.code, "", @@ -28,12 +27,16 @@ describe("grammar", () => { for (const r of loadCases(__dirname + "/test/")) { it("should parse " + r.name, () => { + const ast = getAstFactory(); + const { parse } = getParser(ast); expect(parse(r.code, "", "user")).toMatchSnapshot(); }); } for (const r of loadCases(__dirname + "/test-failed/")) { it("should fail " + r.name, () => { + const ast = getAstFactory(); + const { parse } = getParser(ast); expect(() => parse(r.code, "", "user"), ).toThrowErrorMatchingSnapshot(); diff --git a/src/grammar/grammar.ts b/src/grammar/grammar.ts index 894047aef..069333659 100644 --- a/src/grammar/grammar.ts +++ b/src/grammar/grammar.ts @@ -1,11 +1,4 @@ -import { - Interval as RawInterval, - Node, - IterationNode, - NonterminalNode, - grammar, - Grammar, -} from "ohm-js"; +import { Node, IterationNode, NonterminalNode, grammar, Grammar } from "ohm-js"; import tactGrammar from "./grammar.ohm-bundle"; import { throwInternalCompilerError } from "../errors"; import { @@ -19,73 +12,60 @@ import { AstReceiverKind, AstString, AstType, - createAstNode, AstImport, AstConstantDef, AstNumberBase, AstId, + FactoryAst, } from "./ast"; import { throwParseError, throwSyntaxError } from "../errors"; import { checkVariableName } from "./checkVariableName"; import { checkFunctionAttributes } from "./checkFunctionAttributes"; import { checkConstAttributes } from "./checkConstAttributes"; +import { ItemOrigin, SrcInfo } from "./src-info"; -export type ItemOrigin = "stdlib" | "user"; - -let ctx: { origin: ItemOrigin } | null; +const DummyGrammar: Grammar = grammar("Dummy { DummyRule = any }"); +const DUMMY_INTERVAL = DummyGrammar.match("").getInterval(); +export const dummySrcInfo: SrcInfo = new SrcInfo(DUMMY_INTERVAL, null, "user"); -/** - * Information about source code location (file and interval within it) - * and the source code contents. - */ -export class SrcInfo { - readonly #interval: RawInterval; - readonly #file: string | null; - readonly #origin: ItemOrigin; +type Context = { + origin: ItemOrigin | null; + currentFile: string | null; + createNode: FactoryAst["createNode"] | null; +}; - constructor( - interval: RawInterval, - file: string | null, - origin: ItemOrigin, - ) { - this.#interval = interval; - this.#file = file; - this.#origin = origin; - } +const defaultContext: Context = Object.freeze({ + createNode: null, + currentFile: null, + origin: null, +}); - get file() { - return this.#file; - } +let context: Context = defaultContext; - get contents() { - return this.#interval.contents; +const withContext = (ctx: Context, callback: () => T): T => { + try { + context = ctx; + return callback(); + } finally { + context = defaultContext; } +}; - get interval() { - return this.#interval; +function createRef(s: Node): SrcInfo { + if (context.origin === null) { + throwInternalCompilerError("Parser context was not initialized"); } - get origin() { - return this.#origin; - } + return new SrcInfo(s.source, context.currentFile, context.origin); } -const DummyGrammar: Grammar = grammar("Dummy { DummyRule = any }"); -const DUMMY_INTERVAL = DummyGrammar.match("").getInterval(); -export const dummySrcInfo: SrcInfo = new SrcInfo(DUMMY_INTERVAL, null, "user"); - -let currentFile: string | null = null; - -function inFile(path: string, callback: () => T) { - currentFile = path; - const r = callback(); - currentFile = null; - return r; -} +const createNode: FactoryAst["createNode"] = (...args) => { + if (context.createNode === null) { + throwInternalCompilerError("Parser context was not initialized"); + } -function createRef(s: Node): SrcInfo { - return new SrcInfo(s.source, currentFile, ctx!.origin); -} + return context.createNode(...args); +}; // helper to unwrap optional grammar elements (marked with "?") // ohm-js represents those essentially as lists (IterationNodes) @@ -101,7 +81,7 @@ const semantics = tactGrammar.createSemantics(); semantics.addOperation("astOfModule", { Module(imports, items) { - return createAstNode({ + return createNode({ kind: "module", imports: imports.children.map((item) => item.astOfImport()), items: items.children.map((item) => item.astOfModuleItem()), @@ -118,7 +98,7 @@ semantics.addOperation("astOfImport", { createRef(path), ); } - return createAstNode({ + return createNode({ kind: "import", path: pathAST, loc: createRef(this), @@ -135,7 +115,7 @@ semantics.addOperation("astOfJustImports", { semantics.addOperation("astOfModuleItem", { PrimitiveTypeDecl(_primitive_kwd, typeId, _semicolon) { checkVariableName(typeId.sourceString, createRef(typeId)); - return createAstNode({ + return createNode({ kind: "primitive_type_decl", name: typeId.astOfType(), loc: createRef(this), @@ -155,7 +135,7 @@ semantics.addOperation("astOfModuleItem", { _semicolon, ) { checkVariableName(tactId.sourceString, createRef(tactId)); - return createAstNode({ + return createNode({ kind: "native_function_decl", attributes: funAttributes.children.map((a) => a.astOfFunctionAttributes(), @@ -169,7 +149,7 @@ semantics.addOperation("astOfModuleItem", { }, StructDecl_regular(_structKwd, typeId, _lbrace, fields, _rbrace) { checkVariableName(typeId.sourceString, createRef(typeId)); - return createAstNode({ + return createNode({ kind: "struct_decl", name: typeId.astOfType(), fields: fields.astsOfList(), @@ -187,7 +167,7 @@ semantics.addOperation("astOfModuleItem", { _rbrace, ) { checkVariableName(typeId.sourceString, createRef(typeId)); - return createAstNode({ + return createNode({ kind: "message_decl", name: typeId.astOfType(), fields: fields.astsOfList(), @@ -208,7 +188,7 @@ semantics.addOperation("astOfModuleItem", { _rbrace, ) { checkVariableName(contractId.sourceString, createRef(contractId)); - return createAstNode({ + return createNode({ kind: "contract", name: contractId.astOfExpression(), attributes: attributes.children.map((ca) => @@ -232,7 +212,7 @@ semantics.addOperation("astOfModuleItem", { _rbrace, ) { checkVariableName(traitId.sourceString, createRef(traitId)); - return createAstNode({ + return createNode({ kind: "trait", name: traitId.astOfExpression(), attributes: attributes.children.map((ca) => @@ -276,7 +256,7 @@ semantics.addOperation("astOfItem", { a.astOfConstAttribute(), ) as AstConstantAttribute[]; checkConstAttributes(false, attributes, createRef(this)); - return createAstNode({ + return createNode({ kind: "constant_def", name: constId.astOfExpression(), type: constType.astOfType(), @@ -297,7 +277,7 @@ semantics.addOperation("astOfItem", { a.astOfConstAttribute(), ) as AstConstantAttribute[]; checkConstAttributes(true, attributes, createRef(this)); - return createAstNode({ + return createNode({ kind: "constant_decl", name: constId.astOfExpression(), type: constType.astOfType(), @@ -324,7 +304,7 @@ semantics.addOperation("astOfItem", { ) as AstFunctionAttribute[]; checkVariableName(funId.sourceString, createRef(funId)); checkFunctionAttributes(false, attributes, createRef(this)); - return createAstNode({ + return createNode({ kind: "function_def", attributes, name: funId.astOfExpression(), @@ -356,7 +336,7 @@ semantics.addOperation("astOfItem", { ) as AstFunctionAttribute[]; checkVariableName(funId.sourceString, createRef(funId)); checkFunctionAttributes(false, attributes, createRef(this)); - return createAstNode({ + return createNode({ kind: "asm_function_def", shuffle, attributes, @@ -383,7 +363,7 @@ semantics.addOperation("astOfItem", { ) as AstFunctionAttribute[]; checkVariableName(funId.sourceString, createRef(funId)); checkFunctionAttributes(true, attributes, createRef(this)); - return createAstNode({ + return createNode({ kind: "function_decl", attributes, name: funId.astOfExpression(), @@ -393,7 +373,7 @@ semantics.addOperation("astOfItem", { }); }, ContractInit(_initKwd, initParameters, _lbrace, initBody, _rbrace) { - return createAstNode({ + return createNode({ kind: "contract_init", params: initParameters.astsOfList(), statements: initBody.children.map((s) => s.astOfStatement()), @@ -416,7 +396,7 @@ semantics.addOperation("astOfItem", { param: optParam.astOfDeclaration(), } : { kind: "internal-fallback" }; - return createAstNode({ + return createNode({ kind: "receiver", selector, statements: receiverBody.children.map((s) => s.astOfStatement()), @@ -432,7 +412,7 @@ semantics.addOperation("astOfItem", { receiverBody, _rbrace, ) { - return createAstNode({ + return createNode({ kind: "receiver", selector: { kind: "internal-comment", @@ -451,7 +431,7 @@ semantics.addOperation("astOfItem", { receiverBody, _rbrace, ) { - return createAstNode({ + return createNode({ kind: "receiver", selector: { kind: "bounce", param: parameter.astOfDeclaration() }, statements: receiverBody.children.map((s) => s.astOfStatement()), @@ -474,7 +454,7 @@ semantics.addOperation("astOfItem", { param: optParam.astOfDeclaration(), } : { kind: "external-fallback" }; - return createAstNode({ + return createNode({ kind: "receiver", selector, statements: receiverBody.children.map((s) => s.astOfStatement()), @@ -490,7 +470,7 @@ semantics.addOperation("astOfItem", { receiverBody, _rbrace, ) { - return createAstNode({ + return createNode({ kind: "receiver", selector: { kind: "external-comment", @@ -720,7 +700,7 @@ semantics.addOperation("astOfDeclaration", { _optEq, optInitializer, ) { - return createAstNode({ + return createNode({ kind: "field_decl", name: id.astOfExpression(), type: type.astOfType() as AstType, @@ -733,7 +713,7 @@ semantics.addOperation("astOfDeclaration", { }, Parameter(id, _colon, type) { checkVariableName(id.sourceString, createRef(id)); - return createAstNode({ + return createNode({ kind: "typed_parameter", name: id.astOfExpression(), type: type.astOfType(), @@ -741,7 +721,7 @@ semantics.addOperation("astOfDeclaration", { }); }, StructFieldInitializer_full(fieldId, _colon, initializer) { - return createAstNode({ + return createNode({ kind: "struct_field_initializer", field: fieldId.astOfExpression(), initializer: initializer.astOfExpression(), @@ -749,7 +729,7 @@ semantics.addOperation("astOfDeclaration", { }); }, StructFieldInitializer_punned(fieldId) { - return createAstNode({ + return createNode({ kind: "struct_field_initializer", field: fieldId.astOfExpression(), initializer: fieldId.astOfExpression(), @@ -773,7 +753,7 @@ semantics.addOperation("astOfStatement", { ) { checkVariableName(id.sourceString, createRef(id)); - return createAstNode({ + return createNode({ kind: "statement_let", name: id.astOfExpression(), type: unwrapOptNode(optType, (t) => t.astOfType()), @@ -782,7 +762,7 @@ semantics.addOperation("astOfStatement", { }); }, StatementReturn(_returnKwd, optExpression, _optSemicolonIfLastStmtInBlock) { - return createAstNode({ + return createNode({ kind: "statement_return", expression: unwrapOptNode(optExpression, (e) => e.astOfExpression(), @@ -791,7 +771,7 @@ semantics.addOperation("astOfStatement", { }); }, StatementExpression(expression, _optSemicolonIfLastStmtInBlock) { - return createAstNode({ + return createNode({ kind: "statement_expression", expression: expression.astOfExpression(), loc: createRef(this), @@ -804,7 +784,7 @@ semantics.addOperation("astOfStatement", { _optSemicolonIfLastStmtInBlock, ) { if (operator.sourceString === "=") { - return createAstNode({ + return createNode({ kind: "statement_assign", path: lvalue.astOfExpression(), expression: expression.astOfExpression(), @@ -854,7 +834,7 @@ semantics.addOperation("astOfStatement", { "Unreachable augmented assignment operator.", ); } - return createAstNode({ + return createNode({ kind: "statement_augmentedassign", path: lvalue.astOfExpression(), op, @@ -864,7 +844,7 @@ semantics.addOperation("astOfStatement", { } }, StatementCondition_noElse(_ifKwd, condition, _lbrace, thenBlock, _rbrace) { - return createAstNode({ + return createNode({ kind: "statement_condition", condition: condition.astOfExpression(), trueStatements: thenBlock.children.map((s) => s.astOfStatement()), @@ -884,7 +864,7 @@ semantics.addOperation("astOfStatement", { elseBlock, _rbraceElse, ) { - return createAstNode({ + return createNode({ kind: "statement_condition", condition: condition.astOfExpression(), trueStatements: thenBlock.children.map((s) => s.astOfStatement()), @@ -902,7 +882,7 @@ semantics.addOperation("astOfStatement", { _elseKwd, elseifClause, ) { - return createAstNode({ + return createNode({ kind: "statement_condition", condition: condition.astOfExpression(), trueStatements: thenBlock.children.map((s) => s.astOfStatement()), @@ -920,7 +900,7 @@ semantics.addOperation("astOfStatement", { loopBody, _rbrace, ) { - return createAstNode({ + return createNode({ kind: "statement_while", condition: condition.astOfExpression(), statements: loopBody.children.map((s) => s.astOfStatement()), @@ -936,7 +916,7 @@ semantics.addOperation("astOfStatement", { loopBody, _rbrace, ) { - return createAstNode({ + return createNode({ kind: "statement_repeat", iterations: iterations.astOfExpression(), statements: loopBody.children.map((s) => s.astOfStatement()), @@ -954,7 +934,7 @@ semantics.addOperation("astOfStatement", { _rparen, _optSemicolonIfLastStmtInBlock, ) { - return createAstNode({ + return createNode({ kind: "statement_until", condition: condition.astOfExpression(), statements: loopBody.children.map((s) => s.astOfStatement()), @@ -962,7 +942,7 @@ semantics.addOperation("astOfStatement", { }); }, StatementTry_noCatch(_tryKwd, _lbraceTry, tryBlock, _rbraceTry) { - return createAstNode({ + return createNode({ kind: "statement_try", statements: tryBlock.children.map((s) => s.astOfStatement()), loc: createRef(this), @@ -981,7 +961,7 @@ semantics.addOperation("astOfStatement", { catchBlock, _rbraceCatch, ) { - return createAstNode({ + return createNode({ kind: "statement_try_catch", statements: tryBlock.children.map((s) => s.astOfStatement()), catchName: exitCodeId.astOfExpression(), @@ -1004,7 +984,7 @@ semantics.addOperation("astOfStatement", { ) { checkVariableName(keyId.sourceString, createRef(keyId)); checkVariableName(valueId.sourceString, createRef(valueId)); - return createAstNode({ + return createNode({ kind: "statement_foreach", keyName: keyId.astOfExpression(), valueName: valueId.astOfExpression(), @@ -1024,7 +1004,7 @@ semantics.addOperation("astOfStatement", { expression, _semicolon, ) { - return createAstNode({ + return createNode({ kind: "statement_destruct", type: typeId.astOfType(), identifiers: identifiers @@ -1053,7 +1033,7 @@ semantics.addOperation("astOfStatement", { semantics.addOperation("astOfType", { typeId(firstTactTypeIdCharacter, restOfTactTypeId) { - return createAstNode({ + return createNode({ kind: "type_id", text: firstTactTypeIdCharacter.sourceString + @@ -1062,7 +1042,7 @@ semantics.addOperation("astOfType", { }); }, Type_optional(typeId, _questionMark) { - return createAstNode({ + return createNode({ kind: "optional_type", typeArg: typeId.astOfType(), loc: createRef(this), @@ -1083,7 +1063,7 @@ semantics.addOperation("astOfType", { optValueStorageType, _rangle, ) { - return createAstNode({ + return createNode({ kind: "map_type", keyType: keyTypeId.astOfType(), keyStorageType: unwrapOptNode(optKeyStorageType, (t) => @@ -1097,7 +1077,7 @@ semantics.addOperation("astOfType", { }); }, Type_bounced(_bouncedKwd, _langle, typeId, _rangle) { - return createAstNode({ + return createNode({ kind: "bounced_message_type", messageType: typeId.astOfType(), loc: createRef(this), @@ -1125,7 +1105,7 @@ function baseOfIntLiteral(node: NonterminalNode): AstNumberBase { } function astOfNumber(node: Node): AstNode { - return createAstNode({ + return createNode({ kind: "number", base: baseOfIntLiteral(node), value: bigintOfIntLiteral(node), @@ -1147,38 +1127,38 @@ semantics.addOperation("astOfExpression", { return astOfNumber(this); }, boolLiteral(boolValue) { - return createAstNode({ + return createNode({ kind: "boolean", value: boolValue.sourceString === "true", loc: createRef(this), }); }, id(firstTactIdCharacter, restOfTactId) { - return createAstNode({ + return createNode({ kind: "id", text: firstTactIdCharacter.sourceString + restOfTactId.sourceString, loc: createRef(this), }); }, funcId(firstFuncIdCharacter, restOfFuncId) { - return createAstNode({ + return createNode({ kind: "func_id", text: firstFuncIdCharacter.sourceString + restOfFuncId.sourceString, loc: createRef(this), }); }, null(_nullKwd) { - return createAstNode({ kind: "null", loc: createRef(this) }); + return createNode({ kind: "null", loc: createRef(this) }); }, stringLiteral(_startQuotationMark, string, _endQuotationMark) { - return createAstNode({ + return createNode({ kind: "string", value: string.sourceString, loc: createRef(this), }); }, DestructItem_punned(id) { - return createAstNode({ + return createNode({ kind: "destruct_mapping", field: id.astOfExpression(), name: id.astOfExpression(), @@ -1186,7 +1166,7 @@ semantics.addOperation("astOfExpression", { }); }, DestructItem_regular(idFrom, _colon, id) { - return createAstNode({ + return createNode({ kind: "destruct_mapping", field: idFrom.astOfExpression(), name: id.astOfExpression(), @@ -1194,21 +1174,21 @@ semantics.addOperation("astOfExpression", { }); }, EndOfIdentifiers_regular(_comma) { - return createAstNode({ + return createNode({ kind: "destruct_end", ignoreUnspecifiedFields: false, loc: createRef(this), }); }, EndOfIdentifiers_ignoreUnspecifiedFields(_comma, _dotDot) { - return createAstNode({ + return createNode({ kind: "destruct_end", ignoreUnspecifiedFields: true, loc: createRef(this), }); }, ExpressionAdd_add(left, _plus, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "+", left: left.astOfExpression(), @@ -1217,7 +1197,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionAdd_sub(left, _minus, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "-", left: left.astOfExpression(), @@ -1226,7 +1206,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionMul_div(left, _slash, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "/", left: left.astOfExpression(), @@ -1235,7 +1215,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionMul_mul(left, _star, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "*", left: left.astOfExpression(), @@ -1244,7 +1224,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionMul_rem(left, _percent, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "%", left: left.astOfExpression(), @@ -1253,7 +1233,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionEquality_eq(left, _equalsEquals, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "==", left: left.astOfExpression(), @@ -1262,7 +1242,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionEquality_not(left, _bangEquals, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "!=", left: left.astOfExpression(), @@ -1271,7 +1251,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionCompare_gt(left, _rangle, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: ">", left: left.astOfExpression(), @@ -1280,7 +1260,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionCompare_gte(left, _rangleEquals, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: ">=", left: left.astOfExpression(), @@ -1289,7 +1269,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionCompare_lt(left, _langle, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "<", left: left.astOfExpression(), @@ -1298,7 +1278,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionCompare_lte(left, _langleEquals, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "<=", left: left.astOfExpression(), @@ -1307,7 +1287,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionOr_or(left, _pipePipe, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "||", left: left.astOfExpression(), @@ -1316,7 +1296,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionAnd_and(left, _ampersandAmpersand, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "&&", left: left.astOfExpression(), @@ -1325,7 +1305,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionBitwiseShift_shr(left, _rangleRangle, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: ">>", left: left.astOfExpression(), @@ -1334,7 +1314,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionBitwiseShift_shl(left, _langleLangle, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "<<", left: left.astOfExpression(), @@ -1343,7 +1323,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionBitwiseAnd_bitwiseAnd(left, _ampersand, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "&", left: left.astOfExpression(), @@ -1352,7 +1332,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionBitwiseOr_bitwiseOr(left, _pipe, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "|", left: left.astOfExpression(), @@ -1361,7 +1341,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionBitwiseXor_bitwiseXor(left, _caret, right) { - return createAstNode({ + return createNode({ kind: "op_binary", op: "^", left: left.astOfExpression(), @@ -1372,7 +1352,7 @@ semantics.addOperation("astOfExpression", { // Unary ExpressionUnary_plus(_plus, operand) { - return createAstNode({ + return createNode({ kind: "op_unary", op: "+", operand: operand.astOfExpression(), @@ -1380,7 +1360,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionUnary_minus(_minus, operand) { - return createAstNode({ + return createNode({ kind: "op_unary", op: "-", operand: operand.astOfExpression(), @@ -1388,7 +1368,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionUnary_not(_bang, operand) { - return createAstNode({ + return createNode({ kind: "op_unary", op: "!", operand: operand.astOfExpression(), @@ -1396,7 +1376,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionUnary_bitwiseNot(_tilde, operand) { - return createAstNode({ + return createNode({ kind: "op_unary", op: "~", operand: operand.astOfExpression(), @@ -1407,7 +1387,7 @@ semantics.addOperation("astOfExpression", { return expression.astOfExpression(); }, ExpressionUnboxNotNull(operand, _bangBang) { - return createAstNode({ + return createNode({ kind: "op_unary", op: "!!", operand: operand.astOfExpression(), @@ -1416,7 +1396,7 @@ semantics.addOperation("astOfExpression", { }, ExpressionFieldAccess(source, _dot, fieldId) { - return createAstNode({ + return createNode({ kind: "field_access", aggregate: source.astOfExpression(), field: fieldId.astOfExpression(), @@ -1424,7 +1404,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionMethodCall(source, _dot, methodId, methodArguments) { - return createAstNode({ + return createNode({ kind: "method_call", self: source.astOfExpression(), method: methodId.astOfExpression(), @@ -1433,7 +1413,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionStaticCall(functionId, functionArguments) { - return createAstNode({ + return createNode({ kind: "static_call", function: functionId.astOfExpression(), args: functionArguments.astsOfList(), @@ -1457,7 +1437,7 @@ semantics.addOperation("astOfExpression", { ); } - return createAstNode({ + return createNode({ kind: "struct_instance", type: typeId.astOfType(), args: structFields @@ -1467,7 +1447,7 @@ semantics.addOperation("astOfExpression", { }); }, ExpressionInitOf(_initOfKwd, contractId, initArguments) { - return createAstNode({ + return createNode({ kind: "init_of", contract: contractId.astOfExpression(), args: initArguments.astsOfList(), @@ -1483,7 +1463,7 @@ semantics.addOperation("astOfExpression", { _colon, elseExpression, ) { - return createAstNode({ + return createNode({ kind: "conditional", condition: condition.astOfExpression(), thenBranch: thenExpression.astOfExpression(), @@ -1493,49 +1473,67 @@ semantics.addOperation("astOfExpression", { }, }); -export function parse( - src: string, - path: string, - origin: ItemOrigin, -): AstModule { - return inFile(path, () => { - const matchResult = tactGrammar.match(src); - if (matchResult.failed()) { - throwParseError(matchResult, path, origin); - } - ctx = { origin }; - try { - return semantics(matchResult).astOfModule(); - } finally { - ctx = null; - } - }); -} +export const getParser = (ast: FactoryAst) => { + function parse(src: string, path: string, origin: ItemOrigin): AstModule { + return withContext( + { + currentFile: path, + origin, + createNode: ast.createNode, + }, + () => { + const matchResult = tactGrammar.match(src); + if (matchResult.failed()) { + throwParseError(matchResult, path, origin); + } + return semantics(matchResult).astOfModule(); + }, + ); + } -export function parseExpression(sourceCode: string): AstExpression { - const matchResult = tactGrammar.match(sourceCode, "Expression"); - if (matchResult.failed()) { - throwParseError(matchResult, "", "user"); + function parseExpression(sourceCode: string): AstExpression { + return withContext( + { + currentFile: null, + origin: "user", + createNode: ast.createNode, + }, + () => { + const matchResult = tactGrammar.match(sourceCode, "Expression"); + if (matchResult.failed()) { + throwParseError(matchResult, "", "user"); + } + return semantics(matchResult).astOfExpression(); + }, + ); } - ctx = { origin: "user" }; - return semantics(matchResult).astOfExpression(); -} -export function parseImports( - src: string, - path: string, - origin: ItemOrigin, -): AstImport[] { - return inFile(path, () => { - const matchResult = tactGrammar.match(src, "JustImports"); - if (matchResult.failed()) { - throwParseError(matchResult, path, origin); - } - ctx = { origin }; - try { - return semantics(matchResult).astOfJustImports(); - } finally { - ctx = null; - } - }); -} + function parseImports( + src: string, + path: string, + origin: ItemOrigin, + ): AstImport[] { + return withContext( + { + currentFile: path, + origin, + createNode: ast.createNode, + }, + () => { + const matchResult = tactGrammar.match(src, "JustImports"); + if (matchResult.failed()) { + throwParseError(matchResult, path, origin); + } + return semantics(matchResult).astOfJustImports(); + }, + ); + } + + return { + parse, + parseExpression, + parseImports, + }; +}; + +export type Parser = ReturnType; diff --git a/src/grammar/index.ts b/src/grammar/index.ts new file mode 100644 index 000000000..749fc50c3 --- /dev/null +++ b/src/grammar/index.ts @@ -0,0 +1,3 @@ +export { dummySrcInfo, getParser, Parser } from "./grammar"; + +export { ItemOrigin, SrcInfo } from "./src-info"; diff --git a/src/grammar/rename.ts b/src/grammar/rename.ts index c77210210..391ab1051 100644 --- a/src/grammar/rename.ts +++ b/src/grammar/rename.ts @@ -17,9 +17,9 @@ import { AstNode, AstFunctionAttribute, } from "./ast"; -import { dummySrcInfo } from "./grammar"; import { AstSorter } from "./sort"; import { AstHasher, AstHash } from "./hash"; +import { dummySrcInfo } from "./grammar"; type GivenName = string; diff --git a/src/grammar/src-info.ts b/src/grammar/src-info.ts new file mode 100644 index 000000000..54ddf7248 --- /dev/null +++ b/src/grammar/src-info.ts @@ -0,0 +1,39 @@ +import { Interval as RawInterval } from "ohm-js"; + +export type ItemOrigin = "stdlib" | "user"; + +/** + * Information about source code location (file and interval within it) + * and the source code contents. + */ +export class SrcInfo { + readonly #interval: RawInterval; + readonly #file: string | null; + readonly #origin: ItemOrigin; + + constructor( + interval: RawInterval, + file: string | null, + origin: ItemOrigin, + ) { + this.#interval = interval; + this.#file = file; + this.#origin = origin; + } + + get file() { + return this.#file; + } + + get contents() { + return this.#interval.contents; + } + + get interval() { + return this.#interval; + } + + get origin() { + return this.#origin; + } +} diff --git a/src/grammar/store.ts b/src/grammar/store.ts index db6403278..651a13e4c 100644 --- a/src/grammar/store.ts +++ b/src/grammar/store.ts @@ -8,7 +8,8 @@ import { } from "./ast"; import { throwInternalCompilerError } from "../errors"; import { CompilerContext, createContextStore } from "../context"; -import { ItemOrigin, parse } from "./grammar"; +import { ItemOrigin } from "./src-info"; +import { Parser } from "./grammar"; /** * @public @@ -51,9 +52,12 @@ export function getRawAST(ctx: CompilerContext): AstStore { * Parses multiple Tact source files into AST modules. * @public */ -export function parseModules(sources: TactSource[]): AstModule[] { +export function parseModules( + sources: TactSource[], + parser: Parser, +): AstModule[] { return sources.map((source) => - parse(source.code, source.path, source.origin), + parser.parse(source.code, source.path, source.origin), ); } @@ -68,9 +72,12 @@ export function openContext( ctx: CompilerContext, sources: TactSource[], funcSources: { code: string; path: string }[], + parser: Parser, parsedModules?: AstModule[], ): CompilerContext { - const modules = parsedModules ? parsedModules : parseModules(sources); + const modules = parsedModules + ? parsedModules + : parseModules(sources, parser); const types: AstTypeDecl[] = []; const functions: ( | AstNativeFunctionDecl diff --git a/src/grammar/test/expr-equality.spec.ts b/src/grammar/test/expr-equality.spec.ts index ff10859c6..77fe1614e 100644 --- a/src/grammar/test/expr-equality.spec.ts +++ b/src/grammar/test/expr-equality.spec.ts @@ -1,5 +1,5 @@ -import { __DANGER_resetNodeId, eqExpressions } from "../ast"; -import { parseExpression } from "../grammar"; +import { eqExpressions, getAstFactory } from "../ast"; +import { getParser } from "../"; type Test = { expr1: string; expr2: string; equality: boolean }; @@ -366,15 +366,14 @@ const initOfExpressions: Test[] = [ ]; function testEquality(expr1: string, expr2: string, equal: boolean) { + const ast = getAstFactory(); + const { parseExpression } = getParser(ast); expect(eqExpressions(parseExpression(expr1), parseExpression(expr2))).toBe( equal, ); } describe("expression-equality", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); it("should correctly determine if two expressions involving values are equal or not.", () => { valueExpressions.forEach((test) => { testEquality(test.expr1, test.expr2, test.equality); diff --git a/src/grammar/test/expr-is-value.spec.ts b/src/grammar/test/expr-is-value.spec.ts index f26a23535..0ebee15cb 100644 --- a/src/grammar/test/expr-is-value.spec.ts +++ b/src/grammar/test/expr-is-value.spec.ts @@ -1,7 +1,7 @@ //type Test = { expr: string; isValue: boolean }; -import { __DANGER_resetNodeId, isValue } from "../ast"; -import { parseExpression } from "../grammar"; +import { getAstFactory, isValue } from "../ast"; +import { getParser } from "../"; const valueExpressions: string[] = [ "1", @@ -52,13 +52,12 @@ const notValueExpressions: string[] = [ ]; function testIsValue(expr: string, testResult: boolean) { + const ast = getAstFactory(); + const { parseExpression } = getParser(ast); expect(isValue(parseExpression(expr))).toBe(testResult); } describe("expression-is-value", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); valueExpressions.forEach((test) => { it(`should correctly determine that '${test}' is a value expression.`, () => { testIsValue(test, true); diff --git a/src/imports/resolveImports.spec.ts b/src/imports/resolveImports.spec.ts index 25b805cbf..8b65f14da 100644 --- a/src/imports/resolveImports.spec.ts +++ b/src/imports/resolveImports.spec.ts @@ -1,6 +1,8 @@ import { resolveImports } from "./resolveImports"; import { createNodeFileSystem } from "../vfs/createNodeFileSystem"; import path from "path"; +import { getParser } from "../grammar"; +import { getAstFactory } from "../grammar/ast"; describe("resolveImports", () => { it("should resolve imports", () => { @@ -10,10 +12,12 @@ describe("resolveImports", () => { const stdlib = createNodeFileSystem( path.resolve(__dirname, "__testdata", "stdlib"), ); + const ast = getAstFactory(); const resolved = resolveImports({ project, stdlib, entrypoint: "./main.tact", + parser: getParser(ast), }); expect(resolved).toMatchObject({ func: [ diff --git a/src/imports/resolveImports.ts b/src/imports/resolveImports.ts index 0dd8915af..b88c066cf 100644 --- a/src/imports/resolveImports.ts +++ b/src/imports/resolveImports.ts @@ -1,4 +1,4 @@ -import { ItemOrigin, parseImports } from "../grammar/grammar"; +import { ItemOrigin, Parser } from "../grammar"; import { VirtualFileSystem } from "../vfs/VirtualFileSystem"; import { throwCompilationError } from "../errors"; import { resolveLibrary } from "./resolveLibrary"; @@ -7,6 +7,7 @@ export function resolveImports(args: { entrypoint: string; project: VirtualFileSystem; stdlib: VirtualFileSystem; + parser: Parser; }) { // // Load stdlib and entrypoint @@ -40,7 +41,7 @@ export function resolveImports(args: { const processed: Set = new Set(); const pending: { code: string; path: string; origin: ItemOrigin }[] = []; function processImports(source: string, path: string, origin: ItemOrigin) { - const imp = parseImports(source, path, origin); + const imp = args.parser.parseImports(source, path, origin); for (const i of imp) { const importPath = i.path.value; // Resolve library diff --git a/src/interpreter.ts b/src/interpreter.ts index 8a40737c3..66424fe2a 100644 --- a/src/interpreter.ts +++ b/src/interpreter.ts @@ -31,6 +31,7 @@ import { AstOpBinary, AstOpUnary, AstPrimitiveTypeDecl, + FactoryAst, AstStatement, AstStatementAssign, AstStatementAugmentedAssign, @@ -51,10 +52,11 @@ import { AstTrait, AstUnaryOperation, eqNames, + getAstFactory, idText, isSelfId, } from "./grammar/ast"; -import { SrcInfo, dummySrcInfo, parseExpression } from "./grammar/grammar"; +import { SrcInfo, dummySrcInfo, Parser, getParser } from "./grammar"; import { divFloor, modFloor } from "./optimizer/util"; import { getStaticConstant, @@ -598,9 +600,13 @@ class EnvironmentStack { } } -export function parseAndEvalExpression(sourceCode: string): EvalResult { +export function parseAndEvalExpression( + sourceCode: string, + ast: FactoryAst = getAstFactory(), + parser: Parser = getParser(ast), +): EvalResult { try { - const ast = parseExpression(sourceCode); + const ast = parser.parseExpression(sourceCode); const constEvalResult = evalConstantExpression( ast, new CompilerContext(), diff --git a/src/optimizer/algebraic.ts b/src/optimizer/algebraic.ts index 27c1bdb86..bfde705b7 100644 --- a/src/optimizer/algebraic.ts +++ b/src/optimizer/algebraic.ts @@ -13,9 +13,6 @@ import { checkIsName, checkIsNumber, checkIsUnaryOpNode, - makeBinaryExpression, - makeUnaryExpression, - makeValueExpression, } from "./util"; export class AddZero extends Rule { @@ -23,7 +20,7 @@ export class AddZero extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -49,7 +46,7 @@ export class AddZero extends Rule { const op = topLevelNode.op; if (op === "-") { - return makeUnaryExpression("-", x); + return util.makeUnaryExpression("-", x); } else { return x; } @@ -66,7 +63,7 @@ export class AddZero extends Rule { export class MultiplyZero extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -78,7 +75,7 @@ export class MultiplyZero extends Rule { // The tree has this form: // x * 0, where x is an identifier - return makeValueExpression(0n); + return util.makeValueExpression(0n); } else if ( checkIsNumber(topLevelNode.left, 0n) && checkIsName(topLevelNode.right) @@ -86,7 +83,7 @@ export class MultiplyZero extends Rule { // The tree has this form: // 0 * x, where x is an identifier - return makeValueExpression(0n); + return util.makeValueExpression(0n); } } } @@ -138,7 +135,7 @@ export class MultiplyOne extends Rule { export class SubtractSelf extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -155,7 +152,7 @@ export class SubtractSelf extends Rule { const y = topLevelNode.right; if (eqExpressions(x, y)) { - return makeValueExpression(0n); + return util.makeValueExpression(0n); } } } @@ -170,7 +167,7 @@ export class SubtractSelf extends Rule { export class AddSelf extends Rule { public applyRule( ast: AstExpression, - optimizer: ExpressionTransformer, + { applyRules, util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -187,14 +184,14 @@ export class AddSelf extends Rule { const y = topLevelNode.right; if (eqExpressions(x, y)) { - const res = makeBinaryExpression( + const res = util.makeBinaryExpression( "*", x, - makeValueExpression(2n), + util.makeValueExpression(2n), ); // Since we joined the tree, there is further opportunity // for simplification - return optimizer.applyRules(res); + return applyRules(res); } } } @@ -209,7 +206,7 @@ export class AddSelf extends Rule { export class OrTrue extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -222,12 +219,12 @@ export class OrTrue extends Rule { // The tree has this form: // x || true, where x is an identifier or a value - return makeValueExpression(true); + return util.makeValueExpression(true); } else if (checkIsBoolean(topLevelNode.left, true)) { // The tree has this form: // true || x - return makeValueExpression(true); + return util.makeValueExpression(true); } } } @@ -241,7 +238,7 @@ export class OrTrue extends Rule { export class AndFalse extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -254,12 +251,12 @@ export class AndFalse extends Rule { // The tree has this form: // x && false, where x is an identifier or a value - return makeValueExpression(false); + return util.makeValueExpression(false); } else if (checkIsBoolean(topLevelNode.left, false)) { // The tree has this form: // false && x - return makeValueExpression(false); + return util.makeValueExpression(false); } } } @@ -391,7 +388,7 @@ export class AndSelf extends Rule { export class ExcludedMiddle extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -411,7 +408,7 @@ export class ExcludedMiddle extends Rule { (checkIsName(x) || isValue(x)) && eqExpressions(x, y) ) { - return makeValueExpression(true); + return util.makeValueExpression(true); } } } else if (checkIsUnaryOpNode(topLevelNode.left)) { @@ -429,7 +426,7 @@ export class ExcludedMiddle extends Rule { (checkIsName(x) || isValue(x)) && eqExpressions(x, y) ) { - return makeValueExpression(true); + return util.makeValueExpression(true); } } } @@ -445,7 +442,7 @@ export class ExcludedMiddle extends Rule { export class Contradiction extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -465,7 +462,7 @@ export class Contradiction extends Rule { (checkIsName(x) || isValue(x)) && eqExpressions(x, y) ) { - return makeValueExpression(false); + return util.makeValueExpression(false); } } } else if (checkIsUnaryOpNode(topLevelNode.left)) { @@ -483,7 +480,7 @@ export class Contradiction extends Rule { (checkIsName(x) || isValue(x)) && eqExpressions(x, y) ) { - return makeValueExpression(false); + return util.makeValueExpression(false); } } } @@ -527,7 +524,7 @@ export class DoubleNegation extends Rule { export class NegateTrue extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsUnaryOpNode(ast)) { const topLevelNode = ast as AstOpUnary; @@ -536,7 +533,7 @@ export class NegateTrue extends Rule { // The tree has this form // !true - return makeValueExpression(false); + return util.makeValueExpression(false); } } } @@ -550,7 +547,7 @@ export class NegateTrue extends Rule { export class NegateFalse extends Rule { public applyRule( ast: AstExpression, - _optimizer: ExpressionTransformer, + { util }: ExpressionTransformer, ): AstExpression { if (checkIsUnaryOpNode(ast)) { const topLevelNode = ast as AstOpUnary; @@ -559,7 +556,7 @@ export class NegateFalse extends Rule { // The tree has this form // !false - return makeValueExpression(true); + return util.makeValueExpression(true); } } } diff --git a/src/optimizer/associative.ts b/src/optimizer/associative.ts index 865872588..bdb26e4d9 100644 --- a/src/optimizer/associative.ts +++ b/src/optimizer/associative.ts @@ -16,9 +16,8 @@ import { checkIsBinaryOp_With_RightValue, checkIsBinaryOp_With_LeftValue, extractValue, - makeBinaryExpression, - makeValueExpression, sign, + AstUtil, } from "./util"; type TransformData = { @@ -26,7 +25,12 @@ type TransformData = { safetyCondition: boolean; }; -type Transform = (x1: AstExpression, c1: Value, c2: Value) => TransformData; +type Transform = ( + x1: AstExpression, + c1: Value, + c2: Value, + util: AstUtil, +) => TransformData; /* A simple wrapper function to transform the right value in a binary operator to a continuation so that we can call the evaluation function in the interpreter module @@ -122,7 +126,7 @@ abstract class AllowableOpRule extends AssociativeRewriteRule { export class AssociativeRule1 extends AllowableOpRule { public applyRule( ast: AstExpression, - optimizer: ExpressionTransformer, + { applyRules, util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -171,11 +175,11 @@ export class AssociativeRule1 extends AllowableOpRule { // Because we are joining x1 and x2, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newLeft = optimizer.applyRules( - makeBinaryExpression(op1, x1, x2), + const newLeft = applyRules( + util.makeBinaryExpression(op1, x1, x2), ); - const newRight = makeValueExpression(val); - return makeBinaryExpression(op, newLeft, newRight); + const newRight = util.makeValueExpression(val); + return util.makeBinaryExpression(op, newLeft, newRight); } catch (e) { // Do nothing: will exit rule without modifying tree } @@ -225,11 +229,11 @@ export class AssociativeRule1 extends AllowableOpRule { // Because we are joining x1 and val, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newValNode = makeValueExpression(val); - const newLeft = optimizer.applyRules( - makeBinaryExpression(op1, x1, newValNode), + const newValNode = util.makeValueExpression(val); + const newLeft = applyRules( + util.makeBinaryExpression(op1, x1, newValNode), ); - return makeBinaryExpression(op2, newLeft, x2); + return util.makeBinaryExpression(op2, newLeft, x2); } catch (e) { // Do nothing: will exit rule without modifying tree } @@ -281,11 +285,11 @@ export class AssociativeRule1 extends AllowableOpRule { // Because we are joining x2 and val, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newValNode = makeValueExpression(val); - const newLeft = optimizer.applyRules( - makeBinaryExpression(op2, x2, newValNode), + const newValNode = util.makeValueExpression(val); + const newLeft = applyRules( + util.makeBinaryExpression(op2, x2, newValNode), ); - return makeBinaryExpression(op1, newLeft, x1); + return util.makeBinaryExpression(op1, newLeft, x1); } catch (e) { // Do nothing: will exit rule without modifying tree } @@ -335,11 +339,11 @@ export class AssociativeRule1 extends AllowableOpRule { // Because we are joining x1 and x2, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newRight = optimizer.applyRules( - makeBinaryExpression(op2, x1, x2), + const newRight = applyRules( + util.makeBinaryExpression(op2, x1, x2), ); - const newLeft = makeValueExpression(val); - return makeBinaryExpression(op, newLeft, newRight); + const newLeft = util.makeValueExpression(val); + return util.makeBinaryExpression(op, newLeft, newRight); } catch (e) { // Do nothing: will exit rule without modifying tree } @@ -360,7 +364,7 @@ export class AssociativeRule1 extends AllowableOpRule { export class AssociativeRule2 extends AllowableOpRule { public applyRule( ast: AstExpression, - optimizer: ExpressionTransformer, + { applyRules, util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -396,10 +400,10 @@ export class AssociativeRule2 extends AllowableOpRule { // Because we are joining x1 and x2, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newLeft = optimizer.applyRules( - makeBinaryExpression(op1, x1, x2), + const newLeft = applyRules( + util.makeBinaryExpression(op1, x1, x2), ); - return makeBinaryExpression(op, newLeft, c1); + return util.makeBinaryExpression(op, newLeft, c1); } } else if ( checkIsBinaryOp_With_LeftValue(topLevelNode.left) && @@ -431,10 +435,10 @@ export class AssociativeRule2 extends AllowableOpRule { // Because we are joining x1 and x2, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newRight = optimizer.applyRules( - makeBinaryExpression(op, x1, x2), + const newRight = applyRules( + util.makeBinaryExpression(op, x1, x2), ); - return makeBinaryExpression(op1, c1, newRight); + return util.makeBinaryExpression(op1, c1, newRight); } } else if ( !isValue(topLevelNode.left) && @@ -466,10 +470,10 @@ export class AssociativeRule2 extends AllowableOpRule { // Because we are joining x1 and x2, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newLeft = optimizer.applyRules( - makeBinaryExpression(op, x2, x1), + const newLeft = applyRules( + util.makeBinaryExpression(op, x2, x1), ); - return makeBinaryExpression(op1, newLeft, c1); + return util.makeBinaryExpression(op1, newLeft, c1); } } else if ( !isValue(topLevelNode.left) && @@ -503,10 +507,10 @@ export class AssociativeRule2 extends AllowableOpRule { // Because we are joining x1 and x2, // there is further opportunity of simplification, // So, we ask the evaluator to apply all the rules in the subtree. - const newRight = optimizer.applyRules( - makeBinaryExpression(op1, x2, x1), + const newRight = applyRules( + util.makeBinaryExpression(op1, x2, x1), ); - return makeBinaryExpression(op, c1, newRight); + return util.makeBinaryExpression(op, c1, newRight); } } } @@ -622,12 +626,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: (x1 + c1) + c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression: x1 + (c1 + c2) const val_ = evalBinaryOp("+", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "+", x1, val_node, @@ -644,12 +648,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: (x1 + c1) - c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression: x1 + (c1 - c2) const val_ = evalBinaryOp("-", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "+", x1, val_node, @@ -671,12 +675,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: (x1 - c1) + c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression x1 - (c1 - c2) const val_ = evalBinaryOp("-", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", x1, val_node, @@ -693,12 +697,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: (x1 - c1) - c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression x1 - (c1 + c2) const val_ = evalBinaryOp("+", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", x1, val_node, @@ -720,12 +724,12 @@ export class AssociativeRule3 extends Rule { [ "*", // original expression: (x1 * c1) * c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression x1 * (c1 * c2) const val_ = evalBinaryOp("*", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "*", x1, val_node, @@ -747,12 +751,12 @@ export class AssociativeRule3 extends Rule { [ "&&", // original expression: (x1 && c1) && c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression x1 && (c1 && c2) const val_ = evalBinaryOp("&&", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "&&", x1, val_node, @@ -770,12 +774,12 @@ export class AssociativeRule3 extends Rule { [ "||", // original expression: (x1 || c1) || c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression x1 || (c1 || c2) const val_ = evalBinaryOp("||", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "||", x1, val_node, @@ -807,12 +811,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: c2 + (c1 + x1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 + c1) + x1 const val_ = evalBinaryOp("+", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "+", val_node, x1, @@ -829,12 +833,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: c2 + (c1 - x1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 + c1) - x1 const val_ = evalBinaryOp("+", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", val_node, x1, @@ -856,12 +860,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: c2 - (c1 + x1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 - c1) - x1 const val_ = evalBinaryOp("-", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", val_node, x1, @@ -878,12 +882,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: c2 - (c1 - x1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 - c1) + x1 const val_ = evalBinaryOp("-", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "+", val_node, x1, @@ -906,12 +910,12 @@ export class AssociativeRule3 extends Rule { "*", // original expression: c2 * (c1 * x1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 * c1) * x1 const val_ = evalBinaryOp("*", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "*", val_node, x1, @@ -934,12 +938,12 @@ export class AssociativeRule3 extends Rule { "&&", // original expression: c2 && (c1 && x1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 && c1) && x1 const val_ = evalBinaryOp("&&", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "&&", val_node, x1, @@ -958,12 +962,12 @@ export class AssociativeRule3 extends Rule { "||", // original expression: c2 || (c1 || x1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 || c1) || x1 const val_ = evalBinaryOp("||", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "||", val_node, x1, @@ -995,12 +999,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: c2 + (x1 + c1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression x1 + (c2 + c1) const val_ = evalBinaryOp("+", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "+", x1, val_node, @@ -1017,12 +1021,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: c2 + (x1 - c1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression x1 - (c1 - c2) const val_ = evalBinaryOp("-", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", x1, val_node, @@ -1044,12 +1048,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: c2 - (x1 + c1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 - c1) - x1 const val_ = evalBinaryOp("-", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", val_node, x1, @@ -1066,12 +1070,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: c2 - (x1 - c1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // final expression (c2 + c1) - x1 const val_ = evalBinaryOp("+", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", val_node, x1, @@ -1094,12 +1098,12 @@ export class AssociativeRule3 extends Rule { [ "*", // original expression: c2 * (x1 * c1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { // Final expression x1 * (c2 * c1) const val_ = evalBinaryOp("*", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "*", x1, val_node, @@ -1120,13 +1124,13 @@ export class AssociativeRule3 extends Rule { [ "&&", // original expression: c2 && (x1 && c1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { const val_ = evalBinaryOp("&&", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); let final_expr; if (c2 === true) { // Final expression x1 && (c2 && c1) - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "&&", x1, val_node, @@ -1136,7 +1140,7 @@ export class AssociativeRule3 extends Rule { // Note that by the safety condition, // at this point c1 = true. - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "&&", val_node, x1, @@ -1157,13 +1161,13 @@ export class AssociativeRule3 extends Rule { [ "||", // original expression: c2 || (x1 || c1) - (x1, c1, c2) => { + (x1, c1, c2, util) => { const val_ = evalBinaryOp("||", c2, c1); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); let final_expr; if (c2 === false) { // Final expression x1 || (c2 || c1) - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "||", x1, val_node, @@ -1173,7 +1177,7 @@ export class AssociativeRule3 extends Rule { // Note that by the safety condition, // at this point c1 = false. - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "||", val_node, x1, @@ -1207,12 +1211,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: (c1 + x1) + c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // Final expression (c1 + c2) + x1 const val_ = evalBinaryOp("+", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "+", val_node, x1, @@ -1229,12 +1233,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: (c1 + x1) - c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // Final expression (c1 - c2) + x1 const val_ = evalBinaryOp("-", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "+", val_node, x1, @@ -1256,12 +1260,12 @@ export class AssociativeRule3 extends Rule { [ "+", // original expression: (c1 - x1) + c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // Final expression (c1 + c2) - x1 const val_ = evalBinaryOp("+", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", val_node, x1, @@ -1278,12 +1282,12 @@ export class AssociativeRule3 extends Rule { [ "-", // original expression: (c1 - x1) - c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // Final expression (c1 - c2) - x1 const val_ = evalBinaryOp("-", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "-", val_node, x1, @@ -1305,12 +1309,12 @@ export class AssociativeRule3 extends Rule { [ "*", // original expression: (c1 * x1) * c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { // Final expression (c1 * c2) * x1 const val_ = evalBinaryOp("*", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); return { - simplifiedExpression: makeBinaryExpression( + simplifiedExpression: util.makeBinaryExpression( "*", val_node, x1, @@ -1332,13 +1336,13 @@ export class AssociativeRule3 extends Rule { [ "&&", // original expression: (c1 && x1) && c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { const val_ = evalBinaryOp("&&", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); let final_expr; if (c2 === true) { // Final expression (c1 && c2) && x1 - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "&&", val_node, x1, @@ -1348,7 +1352,7 @@ export class AssociativeRule3 extends Rule { // Note that by the safety condition, // at this point c1 = true. - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "&&", x1, val_node, @@ -1369,13 +1373,13 @@ export class AssociativeRule3 extends Rule { [ "||", // original expression: (c1 || x1) || c2 - (x1, c1, c2) => { + (x1, c1, c2, util) => { const val_ = evalBinaryOp("||", c1, c2); - const val_node = makeValueExpression(val_); + const val_node = util.makeValueExpression(val_); let final_expr; if (c2 === false) { // Final expression (c1 || c2) || x1 - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "||", val_node, x1, @@ -1385,7 +1389,7 @@ export class AssociativeRule3 extends Rule { // Note that by the safety condition, // at this point c1 = false. - final_expr = makeBinaryExpression( + final_expr = util.makeBinaryExpression( "||", x1, val_node, @@ -1456,7 +1460,7 @@ export class AssociativeRule3 extends Rule { public applyRule( ast: AstExpression, - optimizer: ExpressionTransformer, + { applyRules, util }: ExpressionTransformer, ): AstExpression { if (checkIsBinaryOpNode(ast)) { const topLevelNode = ast as AstOpBinary; @@ -1483,12 +1487,13 @@ export class AssociativeRule3 extends Rule { x1, c1, c2, + util, ); if (data.safetyCondition) { // Since the tree is simpler now, there is further // opportunity for simplification that was missed // previously - return optimizer.applyRules(data.simplifiedExpression); + return applyRules(data.simplifiedExpression); } } catch (e) { // Do nothing: will exit rule without modifying tree @@ -1516,12 +1521,13 @@ export class AssociativeRule3 extends Rule { x1, c1, c2, + util, ); if (data.safetyCondition) { // Since the tree is simpler now, there is further // opportunity for simplification that was missed // previously - return optimizer.applyRules(data.simplifiedExpression); + return applyRules(data.simplifiedExpression); } } catch (e) { // Do nothing: will exit rule without modifying tree @@ -1549,12 +1555,13 @@ export class AssociativeRule3 extends Rule { x1, c1, c2, + util, ); if (data.safetyCondition) { // Since the tree is simpler now, there is further // opportunity for simplification that was missed // previously - return optimizer.applyRules(data.simplifiedExpression); + return applyRules(data.simplifiedExpression); } } catch (e) { // Do nothing: will exit rule without modifying tree @@ -1582,12 +1589,13 @@ export class AssociativeRule3 extends Rule { x1, c1, c2, + util, ); if (data.safetyCondition) { // Since the tree is simpler now, there is further // opportunity for simplification that was missed // previously - return optimizer.applyRules(data.simplifiedExpression); + return applyRules(data.simplifiedExpression); } } catch (e) { // Do nothing: will exit rule without modifying tree diff --git a/src/optimizer/standardOptimizer.ts b/src/optimizer/standardOptimizer.ts index 1558be4b2..7df2b4ed9 100644 --- a/src/optimizer/standardOptimizer.ts +++ b/src/optimizer/standardOptimizer.ts @@ -23,16 +23,15 @@ import { AssociativeRule3, } from "./associative"; import { Rule, ExpressionTransformer } from "./types"; +import { AstUtil } from "./util"; type PrioritizedRule = { priority: number; rule: Rule }; // This optimizer uses rules that preserve overflows in integer expressions. -export class StandardOptimizer extends ExpressionTransformer { +export class StandardOptimizer implements ExpressionTransformer { private rules: PrioritizedRule[]; - constructor() { - super(); - + constructor(public util: AstUtil) { this.rules = [ { priority: 0, rule: new AssociativeRule1() }, { priority: 1, rule: new AssociativeRule2() }, @@ -60,11 +59,11 @@ export class StandardOptimizer extends ExpressionTransformer { this.rules.sort((r1, r2) => r1.priority - r2.priority); } - public applyRules(ast: AstExpression): AstExpression { + public applyRules = (ast: AstExpression): AstExpression => { return this.rules.reduce( (prev, prioritizedRule) => prioritizedRule.rule.applyRule(prev, this), ast, ); - } + }; } diff --git a/src/optimizer/test/partial-eval.spec.ts b/src/optimizer/test/partial-eval.spec.ts index db01b6c09..77c1e4ac0 100644 --- a/src/optimizer/test/partial-eval.spec.ts +++ b/src/optimizer/test/partial-eval.spec.ts @@ -1,18 +1,18 @@ import { AstExpression, + FactoryAst, AstValue, - __DANGER_resetNodeId, - cloneAstNode, eqExpressions, + getAstFactory, isValue, } from "../../grammar/ast"; -import { parseExpression } from "../../grammar/grammar"; -import { extractValue, makeValueExpression } from "../util"; -import { partiallyEvalExpression } from "../../constEval"; +import { AstUtil, extractValue, getAstUtil } from "../util"; +import { getOptimizer } from "../../constEval"; import { CompilerContext } from "../../context"; import { ExpressionTransformer, Rule } from "../types"; import { AssociativeRule3 } from "../associative"; import { evalBinaryOp, evalUnaryOp } from "../../interpreter"; +import { getParser } from "../../grammar"; const MAX: string = "115792089237316195423570985008687907853269984665640564039457584007913129639935"; @@ -315,15 +315,21 @@ const booleanExpressions = [ function testExpression(original: string, simplified: string) { it(`should simplify ${original} to ${simplified}`, () => { - expect( - eqExpressions( - partiallyEvalExpression( - parseExpression(original), - new CompilerContext(), - ), - dummyEval(parseExpression(simplified)), - ), - ).toBe(true); + const ast = getAstFactory(); + const { parseExpression } = getParser(ast); + const util = getAstUtil(ast); + const { partiallyEvalExpression } = getOptimizer(util); + const originalValue = partiallyEvalExpression( + parseExpression(original), + new CompilerContext(), + ); + const simplifiedValue = dummyEval( + parseExpression(simplified), + ast, + util, + ); + const areMatching = eqExpressions(originalValue, simplifiedValue); + expect(areMatching).toBe(true); }); } @@ -333,12 +339,19 @@ function testExpressionWithOptimizer( optimizer: ExpressionTransformer, ) { it(`should simplify ${original} to ${simplified}`, () => { - expect( - eqExpressions( - optimizer.applyRules(dummyEval(parseExpression(original))), - dummyEval(parseExpression(simplified)), - ), - ).toBe(true); + const ast = getAstFactory(); + const { parseExpression } = getParser(ast); + const util = getAstUtil(ast); + const originalValue = optimizer.applyRules( + dummyEval(parseExpression(original), ast, util), + ); + const simplifiedValue = dummyEval( + parseExpression(simplified), + ast, + util, + ); + const areMatching = eqExpressions(originalValue, simplifiedValue); + expect(areMatching).toBe(true); }); } @@ -347,101 +360,115 @@ function testExpressionWithOptimizer( // The reason for doing this is that the partial evaluator will actually simplify constant // expressions. So, when comparing for equality of expressions, we also need to simplify // constant expressions. -function dummyEval(ast: AstExpression): AstExpression { - let newNode: AstExpression; - switch (ast.kind) { - case "null": - return ast; - case "boolean": - return ast; - case "number": - return ast; - case "string": - return ast; - case "id": - return ast; - case "method_call": - newNode = cloneAstNode(ast); - newNode.args = ast.args.map(dummyEval); - newNode.self = dummyEval(ast.self); - return newNode; - case "init_of": - newNode = cloneAstNode(ast); - newNode.args = ast.args.map(dummyEval); - return newNode; - case "op_unary": - newNode = cloneAstNode(ast); - newNode.operand = dummyEval(ast.operand); - if (isValue(newNode.operand)) { - return makeValueExpression( - evalUnaryOp( - ast.op, - extractValue(newNode.operand as AstValue), - ), - ); +function dummyEval( + ast: AstExpression, + { cloneNode }: FactoryAst, + { makeValueExpression }: AstUtil, +): AstExpression { + const recurse = (ast: AstExpression): AstExpression => { + switch (ast.kind) { + case "null": + return ast; + case "boolean": + return ast; + case "number": + return ast; + case "string": + return ast; + case "id": + return ast; + case "method_call": { + const newNode = cloneNode(ast); + newNode.args = ast.args.map(recurse); + newNode.self = recurse(ast.self); + return newNode; } - return newNode; - case "op_binary": - newNode = cloneAstNode(ast); - newNode.left = dummyEval(ast.left); - newNode.right = dummyEval(ast.right); - if (isValue(newNode.left) && isValue(newNode.right)) { - const valR = extractValue(newNode.right as AstValue); - return makeValueExpression( - evalBinaryOp( - ast.op, - extractValue(newNode.left as AstValue), - () => valR, - ), - ); + case "init_of": { + const newNode = cloneNode(ast); + newNode.args = ast.args.map(recurse); + return newNode; } - return newNode; - case "conditional": - newNode = cloneAstNode(ast); - newNode.thenBranch = dummyEval(ast.thenBranch); - newNode.elseBranch = dummyEval(ast.elseBranch); - return newNode; - case "struct_instance": - newNode = cloneAstNode(ast); - newNode.args = ast.args.map((param) => { - const newParam = cloneAstNode(param); - newParam.initializer = dummyEval(param.initializer); - return newParam; - }); - return newNode; - case "field_access": - newNode = cloneAstNode(ast); - newNode.aggregate = dummyEval(ast.aggregate); - return newNode; - case "static_call": - newNode = cloneAstNode(ast); - newNode.args = ast.args.map(dummyEval); - return newNode; - } + case "op_unary": { + const newNode = cloneNode(ast); + newNode.operand = recurse(ast.operand); + if (isValue(newNode.operand)) { + return makeValueExpression( + evalUnaryOp( + ast.op, + extractValue(newNode.operand as AstValue), + ), + ); + } + return newNode; + } + case "op_binary": { + const newNode = cloneNode(ast); + newNode.left = recurse(ast.left); + newNode.right = recurse(ast.right); + if (isValue(newNode.left) && isValue(newNode.right)) { + const valR = extractValue(newNode.right as AstValue); + return makeValueExpression( + evalBinaryOp( + ast.op, + extractValue(newNode.left as AstValue), + () => valR, + ), + ); + } + return newNode; + } + case "conditional": { + const newNode = cloneNode(ast); + newNode.thenBranch = recurse(ast.thenBranch); + newNode.elseBranch = recurse(ast.elseBranch); + return newNode; + } + case "struct_instance": { + const newNode = cloneNode(ast); + newNode.args = ast.args.map((param) => { + const newParam = cloneNode(param); + newParam.initializer = recurse(param.initializer); + return newParam; + }); + return newNode; + } + case "field_access": { + const newNode = cloneNode(ast); + newNode.aggregate = recurse(ast.aggregate); + return newNode; + } + case "static_call": { + const newNode = cloneNode(ast); + newNode.args = ast.args.map(recurse); + return newNode; + } + } + }; + + return recurse(ast); } // A dummy optimizer to test specific rules -class ParameterizableDummyOptimizer extends ExpressionTransformer { +class ParameterizableDummyOptimizer implements ExpressionTransformer { private rules: Rule[]; - constructor(rules: Rule[]) { - super(); + public util: AstUtil; + + constructor(rules: Rule[], Ast: FactoryAst) { + this.util = getAstUtil(Ast); this.rules = rules; } - public applyRules(ast: AstExpression): AstExpression { + public applyRules = (ast: AstExpression): AstExpression => { return this.rules.reduce( (prev, rule) => rule.applyRule(prev, this), ast, ); - } + }; } describe("partial-evaluator", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); additiveExpressions.forEach((test) => { testExpression(test.original, test.simplified); }); @@ -452,13 +479,14 @@ describe("partial-evaluator", () => { testExpression(test.original, test.simplified); }); - // For the following cases, we need an optimizer that only - // uses the associative rule 3. - const optimizer = new ParameterizableDummyOptimizer([ - new AssociativeRule3(), - ]); - associativeRuleExpressions.forEach((test) => { + // For the following cases, we need an optimizer that only + // uses the associative rule 3. + const optimizer = new ParameterizableDummyOptimizer( + [new AssociativeRule3()], + getAstFactory(), + ); + testExpressionWithOptimizer(test.original, test.simplified, optimizer); }); diff --git a/src/optimizer/types.ts b/src/optimizer/types.ts index bbd871efb..d4011c138 100644 --- a/src/optimizer/types.ts +++ b/src/optimizer/types.ts @@ -1,7 +1,9 @@ import { AstExpression } from "../grammar/ast"; +import { AstUtil } from "./util"; -export abstract class ExpressionTransformer { - public abstract applyRules(ast: AstExpression): AstExpression; +export interface ExpressionTransformer { + util: AstUtil; + applyRules(ast: AstExpression): AstExpression; } export abstract class Rule { diff --git a/src/optimizer/util.ts b/src/optimizer/util.ts index 4b98ab779..dc6ff85bd 100644 --- a/src/optimizer/util.ts +++ b/src/optimizer/util.ts @@ -2,11 +2,11 @@ import { AstExpression, AstUnaryOperation, AstBinaryOperation, - createAstNode, AstValue, isValue, + FactoryAst, } from "../grammar/ast"; -import { dummySrcInfo } from "../grammar/grammar"; +import { dummySrcInfo } from "../grammar"; import { throwInternalCompilerError } from "../errors"; import { Value } from "../types/types"; @@ -25,71 +25,81 @@ export function extractValue(ast: AstValue): Value { } } -export function makeValueExpression(value: Value): AstValue { - if (value === null) { - const result = createAstNode({ - kind: "null", - loc: dummySrcInfo, - }); - return result as AstValue; - } - if (typeof value === "string") { - const result = createAstNode({ - kind: "string", - value: value, - loc: dummySrcInfo, - }); - return result as AstValue; +export const getAstUtil = ({ createNode }: FactoryAst) => { + function makeValueExpression(value: Value): AstValue { + if (value === null) { + const result = createNode({ + kind: "null", + loc: dummySrcInfo, + }); + return result as AstValue; + } + if (typeof value === "string") { + const result = createNode({ + kind: "string", + value: value, + loc: dummySrcInfo, + }); + return result as AstValue; + } + if (typeof value === "bigint") { + const result = createNode({ + kind: "number", + base: 10, + value: value, + loc: dummySrcInfo, + }); + return result as AstValue; + } + if (typeof value === "boolean") { + const result = createNode({ + kind: "boolean", + value: value, + loc: dummySrcInfo, + }); + return result as AstValue; + } + throwInternalCompilerError( + `structs, addresses, cells, and comment values are not supported at the moment.`, + ); } - if (typeof value === "bigint") { - const result = createAstNode({ - kind: "number", - base: 10, - value: value, + + function makeUnaryExpression( + op: AstUnaryOperation, + operand: AstExpression, + ): AstExpression { + const result = createNode({ + kind: "op_unary", + op: op, + operand: operand, loc: dummySrcInfo, }); - return result as AstValue; + return result as AstExpression; } - if (typeof value === "boolean") { - const result = createAstNode({ - kind: "boolean", - value: value, + + function makeBinaryExpression( + op: AstBinaryOperation, + left: AstExpression, + right: AstExpression, + ): AstExpression { + const result = createNode({ + kind: "op_binary", + op: op, + left: left, + right: right, loc: dummySrcInfo, }); - return result as AstValue; + return result as AstExpression; } - throwInternalCompilerError( - `structs, addresses, cells, and comment values are not supported at the moment.`, - ); -} -export function makeUnaryExpression( - op: AstUnaryOperation, - operand: AstExpression, -): AstExpression { - const result = createAstNode({ - kind: "op_unary", - op: op, - operand: operand, - loc: dummySrcInfo, - }); - return result as AstExpression; -} + return { + makeValueExpression, + makeUnaryExpression, + makeBinaryExpression, + }; +}; -export function makeBinaryExpression( - op: AstBinaryOperation, - left: AstExpression, - right: AstExpression, -): AstExpression { - const result = createAstNode({ - kind: "op_binary", - op: op, - left: left, - right: right, - loc: dummySrcInfo, - }); - return result as AstExpression; -} +export type AstUtil = ReturnType; // Checks if the top level node is an unary op node export function checkIsUnaryOpNode(ast: AstExpression): boolean { diff --git a/src/pipeline/build.ts b/src/pipeline/build.ts index 91d93131a..f892003ca 100644 --- a/src/pipeline/build.ts +++ b/src/pipeline/build.ts @@ -19,8 +19,9 @@ import { VirtualFileSystem } from "../vfs/VirtualFileSystem"; import { compile } from "./compile"; import { precompile } from "./precompile"; import { getCompilerVersion } from "./version"; -import { idText } from "../grammar/ast"; +import { FactoryAst, getAstFactory, idText } from "../grammar/ast"; import { TactErrorCollection } from "../errors"; +import { getParser, Parser } from "../grammar"; export function enableFeatures( ctx: CompilerContext, @@ -52,12 +53,16 @@ export async function build(args: { project: VirtualFileSystem; stdlib: string | VirtualFileSystem; logger?: ILogger; + parser?: Parser; + ast?: FactoryAst; }): Promise<{ ok: boolean; error: TactErrorCollection[] }> { const { config, project } = args; const stdlib = typeof args.stdlib === "string" ? createVirtualFileSystem(args.stdlib, files) : args.stdlib; + const ast: FactoryAst = args.ast ?? getAstFactory(); + const parser: Parser = args.parser ?? getParser(ast); const logger: ILogger = args.logger ?? new Logger(); // Configure context @@ -70,7 +75,7 @@ export async function build(args: { // Precompile try { - ctx = precompile(ctx, project, stdlib, config.path); + ctx = precompile(ctx, project, stdlib, config.path, parser, ast); } catch (e) { logger.error( config.mode === "checkOnly" || config.mode === "funcOnly" diff --git a/src/pipeline/precompile.ts b/src/pipeline/precompile.ts index 8ada419da..8695f5c14 100644 --- a/src/pipeline/precompile.ts +++ b/src/pipeline/precompile.ts @@ -7,24 +7,27 @@ import { resolveErrors } from "../types/resolveErrors"; import { resolveSignatures } from "../types/resolveSignatures"; import { resolveImports } from "../imports/resolveImports"; import { VirtualFileSystem } from "../vfs/VirtualFileSystem"; -import { AstModule } from "../grammar/ast"; +import { AstModule, FactoryAst } from "../grammar/ast"; +import { Parser } from "../grammar"; export function precompile( ctx: CompilerContext, project: VirtualFileSystem, stdlib: VirtualFileSystem, entrypoint: string, + parser: Parser, + ast: FactoryAst, parsedModules?: AstModule[], ) { // Load all sources - const imported = resolveImports({ entrypoint, project, stdlib }); + const imported = resolveImports({ entrypoint, project, stdlib, parser }); // Add information about all the source code entries to the context - ctx = openContext(ctx, imported.tact, imported.func, parsedModules); + ctx = openContext(ctx, imported.tact, imported.func, parser, parsedModules); // First load type descriptors and check that // they all have valid signatures - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, ast); // This creates TLB-style type definitions ctx = resolveSignatures(ctx); diff --git a/src/storage/resolveAllocation.spec.ts b/src/storage/resolveAllocation.spec.ts index 4af8847df..f4f823556 100644 --- a/src/storage/resolveAllocation.spec.ts +++ b/src/storage/resolveAllocation.spec.ts @@ -1,5 +1,4 @@ import fs from "fs"; -import { __DANGER_resetNodeId } from "../grammar/ast"; import { resolveDescriptors } from "../types/resolveDescriptors"; import { getAllocations, resolveAllocations } from "./resolveAllocation"; import { openContext } from "../grammar/store"; @@ -7,6 +6,8 @@ import { resolveStatements } from "../types/resolveStatements"; import { CompilerContext } from "../context"; import { resolveSignatures } from "../types/resolveSignatures"; import path from "path"; +import { getParser } from "../grammar"; +import { getAstFactory } from "../grammar/ast"; const stdlibPath = path.resolve(__dirname, "../../stdlib/std/primitives.tact"); const stdlib = fs.readFileSync(stdlibPath, "utf-8"); @@ -61,10 +62,8 @@ contract Sample { `; describe("resolveAllocation", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); it("should write program", () => { + const ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [ @@ -72,8 +71,9 @@ describe("resolveAllocation", () => { { code: src, path: "", origin: "user" }, ], [], + getParser(ast), ); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, ast); ctx = resolveSignatures(ctx); ctx = resolveStatements(ctx); ctx = resolveAllocations(ctx); diff --git a/src/test/compare.spec.ts b/src/test/compare.spec.ts index d8dd63bcb..46701de5b 100644 --- a/src/test/compare.spec.ts +++ b/src/test/compare.spec.ts @@ -1,10 +1,10 @@ import fs from "fs"; -import { __DANGER_resetNodeId } from "../grammar/ast"; -import { parse } from "../grammar/grammar"; +import { getParser } from "../grammar"; import { join } from "path"; import { AstComparator } from "../grammar/compare"; import { CONTRACTS_DIR } from "./util"; import * as assert from "assert"; +import { getAstFactory } from "../grammar/ast"; describe("comparator", () => { it.each(fs.readdirSync(CONTRACTS_DIR, { withFileTypes: true }))( @@ -15,6 +15,8 @@ describe("comparator", () => { } const filePath = join(CONTRACTS_DIR, dentry.name); const src = fs.readFileSync(filePath, "utf-8"); + const Ast = getAstFactory(); + const { parse } = getParser(Ast); const ast1 = parse(src, filePath, "user"); const ast2 = parse(src, filePath, "user"); assert.strictEqual( diff --git a/src/test/compilation-failed/abi-global-errors.spec.ts b/src/test/compilation-failed/abi-global-errors.spec.ts index 3f9808dfc..f1169b512 100644 --- a/src/test/compilation-failed/abi-global-errors.spec.ts +++ b/src/test/compilation-failed/abi-global-errors.spec.ts @@ -1,11 +1,6 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { itShouldNotCompile } from "./util"; describe("abi/global.ts errors", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - itShouldNotCompile({ testName: "sha256-expects-string-or-slice", errorMessage: "sha256 expects string or slice argument", diff --git a/src/test/compilation-failed/const-eval-failed.spec.ts b/src/test/compilation-failed/const-eval-failed.spec.ts index f84d1c7e3..61f60ab21 100644 --- a/src/test/compilation-failed/const-eval-failed.spec.ts +++ b/src/test/compilation-failed/const-eval-failed.spec.ts @@ -1,11 +1,6 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { itShouldNotCompile } from "./util"; describe("fail-const-eval", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - itShouldNotCompile({ testName: "const-eval-div-by-zero", errorMessage: diff --git a/src/test/compilation-failed/contract-duplicate-opcodes.spec.ts b/src/test/compilation-failed/contract-duplicate-opcodes.spec.ts index a799e930c..96f69139f 100644 --- a/src/test/compilation-failed/contract-duplicate-opcodes.spec.ts +++ b/src/test/compilation-failed/contract-duplicate-opcodes.spec.ts @@ -1,11 +1,6 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { itShouldNotCompile } from "./util"; describe("contract-duplicate-opcodes", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - itShouldNotCompile({ testName: "contract-duplicate-bounced-opcode", errorMessage: diff --git a/src/test/compilation-failed/func-errors.spec.ts b/src/test/compilation-failed/func-errors.spec.ts index cacceea19..5c142432d 100644 --- a/src/test/compilation-failed/func-errors.spec.ts +++ b/src/test/compilation-failed/func-errors.spec.ts @@ -1,11 +1,6 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { itShouldNotCompile } from "./util"; describe("func-errors", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - itShouldNotCompile({ testName: "func-function-does-not-exist", errorMessage: diff --git a/src/test/compilation-failed/scope-errors.spec.ts b/src/test/compilation-failed/scope-errors.spec.ts index 2fcc16b65..c9384d442 100644 --- a/src/test/compilation-failed/scope-errors.spec.ts +++ b/src/test/compilation-failed/scope-errors.spec.ts @@ -1,11 +1,6 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { itShouldNotCompile } from "./util"; describe("scope-errors", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - itShouldNotCompile({ testName: "scope-const-shadows-stdlib-ident", errorMessage: diff --git a/src/test/compilation-failed/stdlib-bugs.spec.ts b/src/test/compilation-failed/stdlib-bugs.spec.ts index da8cc72e6..9a2ffc8b2 100644 --- a/src/test/compilation-failed/stdlib-bugs.spec.ts +++ b/src/test/compilation-failed/stdlib-bugs.spec.ts @@ -1,11 +1,6 @@ -import { __DANGER_resetNodeId } from "../../grammar/ast"; import { itShouldNotCompile } from "./util"; describe("stdlib-bugs", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); - itShouldNotCompile({ testName: "stdlib-skipBits", errorMessage: 'Type mismatch: "" is not assignable to "Slice"', diff --git a/src/test/prettyPrinter.spec.ts b/src/test/prettyPrinter.spec.ts index 79bdb13f6..be2e7067e 100644 --- a/src/test/prettyPrinter.spec.ts +++ b/src/test/prettyPrinter.spec.ts @@ -1,11 +1,11 @@ import fs from "fs"; -import { __DANGER_resetNodeId } from "../grammar/ast"; import { prettyPrint } from "../prettyPrinter"; -import { parse } from "../grammar/grammar"; +import { getParser } from "../grammar"; import { join } from "path"; import { trimTrailingCR, CONTRACTS_DIR } from "./util"; import * as assert from "assert"; import JSONBig from "json-bigint"; +import { getAstFactory } from "../grammar/ast"; describe("formatter", () => { it.each(fs.readdirSync(CONTRACTS_DIR, { withFileTypes: true }))( @@ -14,6 +14,8 @@ describe("formatter", () => { if (!dentry.isFile()) { return; } + const Ast = getAstFactory(); + const { parse } = getParser(Ast); const filePath = join(CONTRACTS_DIR, dentry.name); const src = trimTrailingCR(fs.readFileSync(filePath, "utf-8")); const ast = parse(src, filePath, "user"); @@ -34,6 +36,8 @@ describe("formatter", () => { if (!dentry.isFile()) { return; } + const Ast = getAstFactory(); + const { parse } = getParser(Ast); const filePath = join(CONTRACTS_DIR, dentry.name); const src = fs.readFileSync(filePath, "utf-8"); const ast = parse(src, filePath, "user"); diff --git a/src/test/rename.spec.ts b/src/test/rename.spec.ts index 297549376..31e3fb6f5 100644 --- a/src/test/rename.spec.ts +++ b/src/test/rename.spec.ts @@ -1,11 +1,11 @@ import fs from "fs"; -import { __DANGER_resetNodeId } from "../grammar/ast"; -import { parse } from "../grammar/grammar"; import { join } from "path"; import { AstRenamer } from "../grammar/rename"; import { prettyPrint } from "../prettyPrinter"; import { trimTrailingCR, CONTRACTS_DIR } from "./util"; import * as assert from "assert"; +import { getParser } from "../grammar"; +import { getAstFactory } from "../grammar/ast"; const EXPECTED_DIR = join(CONTRACTS_DIR, "renamer-expected"); @@ -16,6 +16,8 @@ describe("renamer", () => { if (!dentry.isFile()) { return; } + const ast = getAstFactory(); + const { parse } = getParser(ast); const expectedFilePath = join(EXPECTED_DIR, dentry.name); const expected = fs.readFileSync(expectedFilePath, "utf-8"); const filePath = join(CONTRACTS_DIR, dentry.name); diff --git a/src/types/resolveABITypeRef.ts b/src/types/resolveABITypeRef.ts index 79713ae47..098f3fdfa 100644 --- a/src/types/resolveABITypeRef.ts +++ b/src/types/resolveABITypeRef.ts @@ -12,7 +12,6 @@ import { isSlice, isString, isStringBuilder, - SrcInfo, } from "../grammar/ast"; import { idTextErr, @@ -22,6 +21,7 @@ import { import { TypeRef } from "./types"; import { CompilerContext } from "../context"; import { getType } from "./resolveDescriptors"; +import { SrcInfo } from "../grammar"; type FormatDef = Record< string, diff --git a/src/types/resolveDescriptors.spec.ts b/src/types/resolveDescriptors.spec.ts index 74a34e0c2..ac966baf9 100644 --- a/src/types/resolveDescriptors.spec.ts +++ b/src/types/resolveDescriptors.spec.ts @@ -5,10 +5,11 @@ import { resolveDescriptors, } from "./resolveDescriptors"; import { resolveSignatures } from "./resolveSignatures"; -import { SrcInfo, __DANGER_resetNodeId } from "../grammar/ast"; import { loadCases } from "../utils/loadCases"; import { openContext } from "../grammar/store"; import { featureEnable } from "../config/features"; +import { getParser, SrcInfo } from "../grammar"; +import { getAstFactory } from "../grammar/ast"; expect.addSnapshotSerializer({ test: (src) => src instanceof SrcInfo, @@ -16,18 +17,17 @@ expect.addSnapshotSerializer({ }); describe("resolveDescriptors", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); for (const r of loadCases(__dirname + "/test/")) { it("should resolve descriptors for " + r.name, () => { + const Ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code: r.code, path: "", origin: "user" }], [], + getParser(Ast), ); ctx = featureEnable(ctx, "external"); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, Ast); ctx = resolveSignatures(ctx); expect(getAllTypes(ctx)).toMatchSnapshot(); expect(getAllStaticFunctions(ctx)).toMatchSnapshot(); @@ -35,14 +35,16 @@ describe("resolveDescriptors", () => { } for (const r of loadCases(__dirname + "/test-failed/")) { it("should fail descriptors for " + r.name, () => { + const Ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code: r.code, path: "", origin: "user" }], [], + getParser(Ast), ); ctx = featureEnable(ctx, "external"); expect(() => { - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, Ast); ctx = resolveSignatures(ctx); }).toThrowErrorMatchingSnapshot(); }); diff --git a/src/types/resolveDescriptors.ts b/src/types/resolveDescriptors.ts index a334b5cd5..bd5976596 100644 --- a/src/types/resolveDescriptors.ts +++ b/src/types/resolveDescriptors.ts @@ -5,7 +5,6 @@ import { AstNativeFunctionDecl, AstNode, AstType, - createAstNode, idText, AstId, eqNames, @@ -18,6 +17,7 @@ import { AstMapType, AstTypeId, AstAsmFunctionDef, + FactoryAst, } from "../grammar/ast"; import { traverse } from "../grammar/iterators"; import { @@ -49,7 +49,7 @@ import { resolveABIType, intMapFormats } from "./resolveABITypeRef"; import { enabledExternals } from "../config/features"; import { isRuntimeType } from "./isRuntimeType"; import { GlobalFunctions } from "../abi/global"; -import { ItemOrigin } from "../grammar/grammar"; +import { ItemOrigin } from "../grammar"; import { getExpType, resolveExpression } from "./resolveExpression"; import { emptyContext } from "./resolveStatements"; import { isAssignable } from "./subtyping"; @@ -264,7 +264,7 @@ function uidForName(name: string, types: Map) { return uid; } -export function resolveDescriptors(ctx: CompilerContext) { +export function resolveDescriptors(ctx: CompilerContext, Ast: FactoryAst) { const types: Map = new Map(); const staticFunctions: Map = new Map(); const staticConstants: Map = new Map(); @@ -1422,7 +1422,7 @@ export function resolveDescriptors(ctx: CompilerContext) { if (!t.init) { t.init = { params: [], - ast: createAstNode({ + ast: Ast.createNode({ kind: "contract_init", params: [], statements: [], @@ -1626,7 +1626,7 @@ export function resolveDescriptors(ctx: CompilerContext) { name: contractOrTrait.name, optional: false, }, - ast: cloneNode(traitFunction.ast), + ast: cloneNode(traitFunction.ast, Ast), }); } @@ -1697,7 +1697,7 @@ export function resolveDescriptors(ctx: CompilerContext) { // Register constant contractOrTrait.constants.push({ ...traitConstant, - ast: cloneNode(traitConstant.ast), + ast: cloneNode(traitConstant.ast, Ast), }); } @@ -1764,7 +1764,7 @@ export function resolveDescriptors(ctx: CompilerContext) { } contractOrTrait.receivers.push({ selector: f.selector, - ast: cloneNode(f.ast), + ast: cloneNode(f.ast, Ast), }); } diff --git a/src/types/resolveSignatures.ts b/src/types/resolveSignatures.ts index aedd0f9b0..807c63300 100644 --- a/src/types/resolveSignatures.ts +++ b/src/types/resolveSignatures.ts @@ -17,7 +17,7 @@ import { throwCompilationError } from "../errors"; import { AstNumber, AstReceiver } from "../grammar/ast"; import { commentPseudoOpcode } from "../generator/writers/writeRouter"; import { sha256_sync } from "@ton/crypto"; -import { dummySrcInfo } from "../grammar/grammar"; +import { dummySrcInfo } from "../grammar"; import { ensureInt } from "../interpreter"; import { evalConstantExpression } from "../constEval"; diff --git a/src/types/resolveStatements.spec.ts b/src/types/resolveStatements.spec.ts index ab0d59d2a..fa8539bda 100644 --- a/src/types/resolveStatements.spec.ts +++ b/src/types/resolveStatements.spec.ts @@ -1,38 +1,40 @@ import { getAllExpressionTypes } from "./resolveExpression"; import { resolveDescriptors } from "./resolveDescriptors"; import { loadCases } from "../utils/loadCases"; -import { __DANGER_resetNodeId } from "../grammar/ast"; import { openContext } from "../grammar/store"; import { resolveStatements } from "./resolveStatements"; import { CompilerContext } from "../context"; import { featureEnable } from "../config/features"; +import { getParser } from "../grammar"; +import { getAstFactory } from "../grammar/ast"; describe("resolveStatements", () => { - beforeEach(() => { - __DANGER_resetNodeId(); - }); for (const r of loadCases(__dirname + "/stmts/")) { it("should resolve statements for " + r.name, () => { + const Ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code: r.code, path: "", origin: "user" }], [], + getParser(Ast), ); ctx = featureEnable(ctx, "external"); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, Ast); ctx = resolveStatements(ctx); expect(getAllExpressionTypes(ctx)).toMatchSnapshot(); }); } for (const r of loadCases(__dirname + "/stmts-failed/")) { it("should fail statements for " + r.name, () => { + const Ast = getAstFactory(); let ctx = openContext( new CompilerContext(), [{ code: r.code, path: "", origin: "user" }], [], + getParser(Ast), ); ctx = featureEnable(ctx, "external"); - ctx = resolveDescriptors(ctx); + ctx = resolveDescriptors(ctx, Ast); expect(() => resolveStatements(ctx)).toThrowErrorMatchingSnapshot(); }); } diff --git a/src/types/resolveStatements.ts b/src/types/resolveStatements.ts index 760c4b950..e5ca22bbf 100644 --- a/src/types/resolveStatements.ts +++ b/src/types/resolveStatements.ts @@ -1,7 +1,6 @@ import { CompilerContext } from "../context"; import { AstCondition, - SrcInfo, AstStatement, tryExtractPath, AstId, @@ -31,6 +30,7 @@ import { FunctionDescription, printTypeRef, TypeRef } from "./types"; import { evalConstantExpression } from "../constEval"; import { ensureInt } from "../interpreter"; import { crc16 } from "../utils/crc16"; +import { SrcInfo } from "../grammar"; export type StatementContext = { root: SrcInfo; diff --git a/src/types/types.ts b/src/types/types.ts index 8b34e74cf..4bcdd65fa 100644 --- a/src/types/types.ts +++ b/src/types/types.ts @@ -6,7 +6,6 @@ import { AstContractInit, AstNativeFunctionDecl, AstReceiver, - SrcInfo, AstTypeDecl, AstId, AstFunctionDecl, @@ -15,7 +14,7 @@ import { AstAsmFunctionDef, AstNumber, } from "../grammar/ast"; -import { dummySrcInfo, ItemOrigin } from "../grammar/grammar"; +import { dummySrcInfo, ItemOrigin, SrcInfo } from "../grammar"; export type TypeDescription = { kind: "struct" | "primitive_type_decl" | "contract" | "trait";