Skip to content

Commit

Permalink
Merge pull request #3332 from quantified-uncertainty/michael-type-che…
Browse files Browse the repository at this point in the history
…cking

Unit type checking - #3323 + cleanups
  • Loading branch information
berekuk authored Jul 23, 2024
2 parents b0bdab8 + 58bffb6 commit 008fc0e
Show file tree
Hide file tree
Showing 24 changed files with 2,699 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ function* getMarkerSubData(
ast: ASTNode,
path: SqValuePath
): Generator<MarkerDatum, void> {
switch (ast.type) {
switch (ast.kind) {
case "Dict":
// Assuming 'elements' is an array of { key: ASTNode, value: ASTNode | string }
// and we only want to include ASTNode values
for (const element of ast.elements) {
if (element.type === "Identifier") continue;
if (element.key.type !== "String") continue;
if (element.kind === "Identifier") continue;
if (element.key.kind !== "String") continue;
const subPath = path.extend(SqValuePathEdge.fromKey(element.key.value));
yield { ast: element.key, path: subPath };
yield* getMarkerSubData(element.value, subPath);
Expand All @@ -54,14 +54,14 @@ function* getMarkerSubData(
}

function* getMarkerData(ast: ASTNode): Generator<MarkerDatum, void> {
if (ast.type !== "Program") {
if (ast.kind !== "Program") {
return; // unexpected
}

nextStatement: for (const statement of ast.statements) {
if (
statement.type === "DefunStatement" ||
statement.type === "LetStatement"
statement.kind === "DefunStatement" ||
statement.kind === "LetStatement"
) {
for (const decorator of statement.decorators) {
if (decorator.name.value === "hide") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ export function tooltipsExtension() {

if (
// Note that `valueAst` can't be "DecoratedStatement", we skip those in `SqValueContext` and AST symbols
(valueAst.type === "LetStatement" ||
valueAst.type === "DefunStatement") &&
(valueAst.kind === "LetStatement" ||
valueAst.kind === "DefunStatement") &&
// If these don't match then variable was probably shadowed by a later statement and we can't show its value.
// Or it could be caused by code rot, if we change the logic of how `valueAst` is computed, or add another statement type in AST.
// TODO - if we can prove that the variable was shadowed, show the tooltip pointing to the latest assignment.
Expand Down
6 changes: 3 additions & 3 deletions packages/hub/src/graphql/types/ModelRevision.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ const ModelRevisionBuildStatus = builder.enumType("ModelRevisionBuildStatus", {
function astToVariableNames(ast: ASTNode): string[] {
const exportedVariableNames: string[] = [];

if (ast.type === "Program") {
if (ast.kind === "Program") {
ast.statements.forEach((statement) => {
if (
(statement.type === "LetStatement" ||
statement.type === "DefunStatement") &&
(statement.kind === "LetStatement" ||
statement.kind === "DefunStatement") &&
statement.exported
) {
exportedVariableNames.push(statement.variable.value);
Expand Down
61 changes: 46 additions & 15 deletions packages/prettier-plugin/src/printer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function getNodePrecedence(node: SquiggleNode): number {
"->": 10,
"|>": 10, // removed since 0.8.0
};
switch (node.type) {
switch (node.kind) {
case "Ternary":
return 1;
case "InfixCall": {
Expand Down Expand Up @@ -84,7 +84,7 @@ export function createSquigglePrinter(
return needParens ? ["(", doc, ")"] : doc;
};

switch (node.type) {
switch (node.kind) {
case "Program":
// TODO - preserve line breaks, break long lines
// TODO - comments will be moved to the end because imports is not a real AST, need to be fixed in squiggle-lang
Expand All @@ -111,7 +111,7 @@ export function createSquigglePrinter(
),
node.statements.length &&
["LetStatement", "DefunStatement"].includes(
node.statements[node.statements.length - 1].type
node.statements[node.statements.length - 1].kind
)
? hardline // new line if final expression is a statement
: "",
Expand Down Expand Up @@ -167,6 +167,9 @@ export function createSquigglePrinter(
]),
node.exported ? "export " : "",
node.variable.value,
node.unitTypeSignature
? typedPath(node).call(print, "unitTypeSignature")
: "",
" = ",
typedPath(node).call(print, "value"),
]);
Expand All @@ -187,6 +190,10 @@ export function createSquigglePrinter(
softline,
")",
]),
node.value.returnUnitType
? // @ts-ignore
typedPath(node).call(print, "value", "returnUnitType")
: "",
" = ",
typedPath(node).call(print, "value", "body"),
]);
Expand Down Expand Up @@ -280,7 +287,13 @@ export function createSquigglePrinter(
"]",
]);
case "Identifier":
return node.value;
return group([
node.value,
node.unitTypeSignature
? // @ts-ignore
typedPath(node).call(print, "unitTypeSignature")
: "",
]);
case "IdentifierWithAnnotation":
return [
node.variable,
Expand All @@ -289,7 +302,7 @@ export function createSquigglePrinter(
];
case "KeyValue": {
const key =
node.key.type === "String" &&
node.key.kind === "String" &&
node.key.value.match(/^[\$_a-z]+[\$_a-zA-Z0-9]*$/)
? node.key.value
: typedPath(node).call(print, "key");
Expand All @@ -302,9 +315,9 @@ export function createSquigglePrinter(
]);
}
case "Lambda":
if (node.body.type === "Block") {
if (node.body.kind === "Block") {
(
node.body as Extract<PatchedASTNode, { type: "Block" }>
node.body as Extract<PatchedASTNode, { kind: "Block" }>
).isLambdaBody = true;
}
return group([
Expand All @@ -325,11 +338,15 @@ export function createSquigglePrinter(
]),
softline,
"}",
node.returnUnitType
? // @ts-ignore
typedPath(node).call(print, "returnUnitType")
: "",
]);
case "Dict": {
const isSingleKeyWithoutValue =
node.elements.length === 1 &&
node.elements[0].type === "Identifier";
node.elements[0].kind === "Identifier";
return group([
"{",
node.elements.length
Expand All @@ -349,13 +366,27 @@ export function createSquigglePrinter(
return [JSON.stringify(node.value).replaceAll("\\n", "\n")];
case "Ternary":
return [
node.kind === "C" ? [] : "if ",
node.syntax === "C" ? [] : "if ",
path.call(print, "condition"),
node.kind === "C" ? " ? " : " then ",
node.syntax === "C" ? " ? " : " then ",
path.call(print, "trueExpression"),
node.kind === "C" ? " : " : " else ",
node.syntax === "C" ? " : " : " else ",
path.call(print, "falseExpression"),
];
case "UnitTypeSignature":
return group([" :: ", typedPath(node).call(print, "body")]);
case "InfixUnitType":
return group([
typedPath(node).call(print, "args", 0),
node.op,
typedPath(node).call(print, "args", 1),
]);
case "ExponentialUnitType":
return group([
typedPath(node).call(print, "base"),
"^",
typedPath(node).call(print, "exponent"),
]);
case "UnitValue":
return [typedPath(node).call(print, "value"), node.unit];
case "lineComment":
Expand All @@ -365,7 +396,7 @@ export function createSquigglePrinter(
},
printComment: (path: AstPath<ASTCommentNode>) => {
const commentNode = path.node;
switch (commentNode.type) {
switch (commentNode.kind) {
case "lineComment":
// I'm not sure why "hardline" at the end here is not necessary
return ["//", commentNode.value];
Expand All @@ -376,14 +407,14 @@ export function createSquigglePrinter(
}
},
isBlockComment: (node) => {
return node.type === "blockComment";
return node.kind === "blockComment";
},
...({
getCommentChildNodes: (node: ASTNode) => {
if (!node) {
return [];
}
switch (node.type) {
switch (node.kind) {
case "Program":
return node.statements;
case "Block":
Expand All @@ -407,7 +438,7 @@ export function createSquigglePrinter(
}
},
canAttachComment: (node: ASTNode) => {
return node && node.type;
return node && node.kind;
},
} as any),
};
Expand Down
4 changes: 2 additions & 2 deletions packages/prettier-plugin/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { type ASTCommentNode, type ASTNode } from "@quri/squiggle-lang";

// This doesn't patch children types (e.g. `node.statements[0]` is `ASTNode`, not `PatchedASTNode`)
export type PatchedASTNode = (
| Exclude<ASTNode, { type: "Block" }>
| (Extract<ASTNode, { type: "Block" }> & { isLambdaBody?: boolean })
| Exclude<ASTNode, { kind: "Block" }>
| (Extract<ASTNode, { kind: "Block" }> & { isLambdaBody?: boolean })
) & {
comments?: ASTCommentNode[];
};
Expand Down
2 changes: 1 addition & 1 deletion packages/prettier-plugin/test/let.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ describe("let", () => {
expect(await format("f(x,y)=x*y")).toBe("f(x, y) = x * y\n");
});

test("defun with long args args", async () => {
test("defun with long args", async () => {
expect(
await format(
"f(yaewrtawieyra,auweyrauweyrauwyer,wekuryakwueyruaweyr,wekuryakwueyruaweyr,wekuryakwueyruaweyr)=x*y"
Expand Down
58 changes: 58 additions & 0 deletions packages/prettier-plugin/test/unit_type_signature.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { format } from "./helpers.js";

describe("unit type signature", () => {
describe("let", () => {
test("simple", async () => {
expect(await format("x::meters = 1")).toBe("x :: meters = 1\n");
});

test("multiplication", async () => {
expect(await format("x :: meters*meters = 2")).toBe("x :: meters*meters = 2\n");
});

test("division", async () => {
expect(await format("x :: meters/seconds = 3")).toBe("x :: meters/seconds = 3\n");
});

test("exponentiation", async () => {
expect(await format("x :: meters^2 = 4")).toBe("x :: meters^2 = 4\n");
});

test("complex type", async () => {
expect(await format("x :: kg*meters/seconds/seconds = 5")).toBe("x :: kg*meters/seconds/seconds = 5\n");
});

test("complex type with exponents", async () => {
expect(await format("x :: kg*meters^2/seconds^3 = 6")).toBe("x :: kg*meters^2/seconds^3 = 6\n");
});
});

describe("defun", () => {
test("with parameter types", async () => {
expect(await format("f(x ::euros, y:: pesos) = 11")).toBe("f(x :: euros, y :: pesos) = 11\n");
});

test("with return type", async () => {
expect(await format("f(x) :: dollars = 12")).toBe("f(x) :: dollars = 12\n");
});

test("with parameter and return types", async () => {
expect(await format("f(x ::euros, y:: euros) :: euros = 13")).toBe("f(x :: euros, y :: euros) :: euros = 13\n");
});
});

describe("lambda", () => {
test("with parameter types", async () => {
expect(await format("f = { |x :: unitOne, y :: unitTwo| x }")).toBe("f = {|x :: unitOne, y :: unitTwo| x}\n");
});

test("with return type", async () => {
expect(await format("{ |x :: inUnit| x } :: outUnit")).toBe("{|x :: inUnit| x} :: outUnit");
});

test("with return type followed by call", async () => {
expect(await format("{ |x :: inUnit| x } :: outUnit(23)")).toBe("{|x :: inUnit| x} :: outUnit(23)");
});
});

});
Loading

0 comments on commit 008fc0e

Please sign in to comment.