diff --git a/Cargo.lock b/Cargo.lock index 80df0b2a08..25536d5658 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -161,6 +161,9 @@ dependencies = [ [[package]] name = "qsc_ast" version = "0.0.0" +dependencies = [ + "num-bigint", +] [[package]] name = "qsc_codegen" diff --git a/compiler/qsc_ast/Cargo.toml b/compiler/qsc_ast/Cargo.toml index 20dd040a34..ec9d05794f 100644 --- a/compiler/qsc_ast/Cargo.toml +++ b/compiler/qsc_ast/Cargo.toml @@ -4,3 +4,4 @@ version = "0.0.0" edition = "2021" [dependencies] +num-bigint = "0.4.3" diff --git a/compiler/qsc_ast/src/ast.rs b/compiler/qsc_ast/src/ast.rs new file mode 100644 index 0000000000..0c5a82eb80 --- /dev/null +++ b/compiler/qsc_ast/src/ast.rs @@ -0,0 +1,603 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! The abstract syntax tree (AST) for Q#. The AST directly corresponds to the surface syntax of Q#. + +#![warn(missing_docs)] + +use num_bigint::BigInt; + +/// The unique identifier for an AST node. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct NodeId(u32); + +impl NodeId { + /// The ID for the root node in an AST. + pub const ROOT: Self = Self(0); + + /// The next ID in the sequence. + #[must_use] + pub fn next(&self) -> Self { + Self(self.0 + 1) + } +} + +/// A region between two source code positions. The offsets are absolute within an AST given that +/// each file has its own offset. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct Span { + /// The first byte offset. + pub lo: u32, + /// The last byte offset. + pub hi: u32, +} + +/// The package currently being compiled and the root node of an AST. +#[derive(Clone, Debug, PartialEq)] +pub struct Package { + /// The node ID. + pub id: NodeId, + /// The namespaces in the package. + pub namespaces: Vec, +} + +/// A namespace. +#[derive(Clone, Debug, PartialEq)] +pub struct Namespace { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The namespace name. + pub name: Path, + /// The items in the namespace. + pub items: Vec, +} + +/// A namespace item. +#[derive(Clone, Debug, PartialEq)] +pub struct Item { + /// The ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The item kind. + pub kind: ItemKind, +} + +/// A namespace item kind. +#[derive(Clone, Debug, PartialEq)] +pub enum ItemKind { + /// An `open` statement for another namespace with an optional alias. + Open(Ident, Option), + /// A `newtype` declaration. + Type(DeclMeta, Ident, TyDef), + /// A `function` or `operation` declaration. + Callable(DeclMeta, CallableDecl), +} + +/// Metadata for a top-level declaration. +#[derive(Clone, Debug, PartialEq)] +pub struct DeclMeta { + /// The declaration attributes. + pub attrs: Vec, + /// The declaration visibility. + pub visibility: Visibility, +} + +/// A visibility modifier. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct Visibility { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The visibility kind. + pub kind: VisibilityKind, +} + +/// An attribute. +#[derive(Clone, Debug, PartialEq)] +pub struct Attr { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The name of the attribute. + pub name: Path, + /// The argument to the attribute. + pub arg: Expr, +} + +/// A type definition. +#[derive(Clone, Debug, PartialEq)] +pub enum TyDef { + /// A field definition with an optional name but required type. + Field(Option, Ty), + /// A tuple. + Tuple(Vec), +} + +/// A callable declaration header. +#[derive(Clone, Debug, PartialEq)] +pub struct CallableDecl { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The callable kind. + pub kind: CallableKind, + /// The name of the callable. + pub name: Ident, + /// The type parameters to the callable. + pub ty_params: Vec, + /// The input to the callable. + pub input: Pat, + /// The return type of the callable. + pub output: Ty, + /// The functors supported by the callable. + pub functors: FunctorExpr, + /// The body of the callable. + pub body: CallableBody, +} + +/// The body of a callable. +#[derive(Clone, Debug, PartialEq)] +pub enum CallableBody { + /// A block for the callable's body specialization. + Block(Block), + /// One or more explicit specializations. + Specs(Vec), +} + +/// A specialization declaration. +#[derive(Clone, Debug, PartialEq)] +pub struct SpecDecl { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// Which specialization is being declared. + pub spec: Spec, + /// The body of the specialization. + pub body: SpecBody, +} + +/// The body of a specialization. +#[derive(Clone, Debug, PartialEq)] +pub enum SpecBody { + /// The strategy to use to automatically generate the specialization. + Gen(SpecGen), + /// A manual implementation of the specialization. + Impl(Pat, Block), +} + +/// An expression that describes a set of functors. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum FunctorExpr { + /// A binary operation. + BinOp(SetOp, Box, Box), + /// A literal for a specific functor. + Lit(Functor), + /// The empty set. + Null, +} + +/// A type. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Ty { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The type kind. + pub kind: TyKind, +} + +/// A type kind. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum TyKind { + /// One or more type arguments applied to a type constructor. + App(Box, Vec), + /// An arrow type: `->` for a function or `=>` for an operation. + Arrow(CallableKind, Box, Box, FunctorExpr), + /// An unspecified type, `_`, which may be inferred. + Hole, + /// A named type. + Path(Path), + /// A primitive type. + Prim(TyPrim), + /// A tuple type. + Tuple(Vec), + /// A type variable. + Var(TyVar), +} + +/// A sequenced block of statements. +#[derive(Clone, Debug, PartialEq)] +pub struct Block { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The statements in the block. + pub stmts: Vec, +} + +/// A statement. +#[derive(Clone, Debug, PartialEq)] +pub struct Stmt { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The statement kind. + pub kind: StmtKind, +} + +/// A statement kind. +#[derive(Clone, Debug, PartialEq)] +pub enum StmtKind { + /// A borrowed qubit binding: `borrow a = b;`. + Borrow(Pat, QubitInit, Option), + /// An expression without a trailing semicolon. + Expr(Expr), + /// A let binding: `let a = b;`. + Let(Pat, Expr), + /// A mutable binding: `mutable a = b;`. + Mutable(Pat, Expr), + /// An expression with a trailing semicolon. + Semi(Expr), + /// A fresh qubit binding: `use a = b;`. + Use(Pat, QubitInit, Option), +} + +/// An expression. +#[derive(Clone, Debug, PartialEq)] +pub struct Expr { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The expression kind. + pub kind: ExprKind, +} + +/// An expression kind. +#[derive(Clone, Debug, PartialEq)] +pub enum ExprKind { + /// An array: `[a, b, c]`. + Array(Vec), + /// An array constructed by repeating a value: `[a, size = b]`. + ArrayRepeat(Box, Box), + /// An assignment: `set a = b`. + Assign(Box, Box), + /// An assignment with a compound operator. For example: `set a += b`. + AssignOp(BinOp, Box, Box), + /// An assignment with a compound update operator: `set a w/= b <- c`. + AssignUpdate(Box, Box, Box), + /// A binary operator. + BinOp(BinOp, Box, Box), + /// A block: `{ ... }`. + Block(Block), + /// A call: `a(b)`. + Call(Box, Box), + /// A conjugation: `within { ... } apply { ... }`. + Conjugate(Block, Block), + /// A failure: `fail "message"`. + Fail(Box), + /// A field accessor: `a::F`. + Field(Box, Ident), + /// A for loop: `for a in b { ... }`. + For(Pat, Box, Block), + /// An unspecified expression, _, which may indicate partial application or a typed hole. + Hole, + /// An if expression, with an arbitrary number of elifs and an optional else: + /// `if a { ... } elif b { ... } else { ... }`. + If(Vec<(Expr, Block)>, Option), + /// An index accessor: `a[b]`. + Index(Box, Box), + /// An interpolated string: `$"{a} {b} {c}"`. + Interp(String, Vec), + /// A lambda: `a -> b` for a function and `a => b` for an operation. + Lambda(CallableKind, Pat, Box), + /// A literal. + Lit(Lit), + /// Parentheses: `(a)`. + Paren(Box), + /// A path: `a` or `a.b`. + Path(Path), + /// A range: `a..b..c`. + Range(Box, Box, Box), + /// A repeat-until loop with an optional fixup: `repeat { ... } until a fixup { ... }`. + Repeat(Block, Box, Option), + /// A return: `return a`. + Return(Box), + /// A ternary operator. + TernOp(TernOp, Box, Box, Box), + /// A tuple: `(a, b, c)`. + Tuple(Vec), + /// A unary operator. + UnOp(UnOp, Box), + /// A while loop: `while a { ... }`. + While(Box, Block), +} + +/// A pattern. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Pat { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The pattern kind. + pub kind: PatKind, +} + +/// A pattern kind. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum PatKind { + /// A binding with an optional type annotation. + Bind(Ident, Option), + /// A discarded binding, `_`, with an optional type annotation. + Discard(Option), + /// An elided pattern, `...`, used by specializations. + Elided, + /// A tuple: `(a, b, c)`. + Tuple(Vec), +} + +/// A qubit initializer. +#[derive(Clone, Debug, PartialEq)] +pub struct QubitInit { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The qubit initializer kind. + pub kind: QubitInitKind, +} + +/// A qubit initializer kind. +#[derive(Clone, Debug, PartialEq)] +pub enum QubitInitKind { + /// A single qubit: `Qubit()`. + Single, + /// A tuple: `(a, b, c)`. + Tuple(Vec), + /// An array of qubits: `Qubit[a]`. + Array(Box), +} + +/// A path to a declaration. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Path { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The namespace. + pub namespace: Option, + /// The declaration name. + pub name: Ident, +} + +/// An identifier. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Ident { + /// The node ID. + pub id: NodeId, + /// The span. + pub span: Span, + /// The identifier name. + pub name: String, +} + +/// A declaration visibility kind. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum VisibilityKind { + /// Visible everywhere. + Public, + /// Visible within a package. + Internal, +} + +/// A callable kind. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum CallableKind { + /// A function. + Function, + /// An operation. + Operation, +} + +/// A primitive type. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum TyPrim { + /// The array type. + Array, + /// The big integer type. + BigInt, + /// The boolean type. + Bool, + /// The floating-point type. + Double, + /// The integer type. + Int, + /// The Pauli operator type. + Pauli, + /// The qubit type. + Qubit, + /// The range type. + Range, + /// The measurement result type. + Result, + /// The string type. + String, +} + +/// A type variable. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum TyVar { + /// A named variable. + Name(String), + /// A numeric variable. + Id(u32), +} + +/// A literal. +#[derive(Clone, Debug, PartialEq)] +pub enum Lit { + /// A big integer literal. + BigInt(BigInt), + /// A boolean literal. + Bool(bool), + /// A floating-point literal. + Double(f64), + /// An integer literal. + Int(u64), + /// A Pauli operator literal. + Pauli(Pauli), + /// A measurement result literal. + Result(Result), + /// A string literal. + String(String), +} + +/// A measurement result. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum Result { + /// The zero eigenvalue. + Zero, + /// The one eigenvalue. + One, +} + +/// A Pauli operator. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum Pauli { + /// The Pauli I operator. + I, + /// The Pauli X operator. + X, + /// The Pauli Y operator. + Y, + /// The Pauli Z operator. + Z, +} + +/// A functor that may be applied to an operation. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum Functor { + /// The adjoint functor. + Adj, + /// The controlled functor. + Ctl, +} + +/// A specialization that may be implemented for an operation. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum Spec { + /// The default specialization. + Body, + /// The adjoint specialization. + Adj, + /// The controlled specialization. + Ctl, + /// The controlled adjoint specialization. + CtlAdj, +} + +/// A strategy for generating a specialization. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum SpecGen { + /// Choose a strategy automatically. + Auto, + /// Distributes controlled qubits. + Distribute, + /// A specialization implementation is not generated, but is instead left as an opaque + /// declaration. + Intrinsic, + /// Inverts the order of operations. + Invert, + /// Uses the body specialization without modification. + Slf, +} + +/// A unary operator. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum UnOp { + /// A functor application. + Functor(Functor), + /// Negation: `-`. + Neg, + /// Bitwise NOT: `~~~`. + NotB, + /// Logical NOT: `not`. + NotL, + /// A leading `+`. + Pos, + /// Unwrap a user-defined type: `!`. + Unwrap, +} + +/// A binary operator. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum BinOp { + /// Addition: `+`. + Add, + /// Bitwise AND: `&&&`. + AndB, + /// Logical AND: `and`. + AndL, + /// Division: `/`. + Div, + /// Equality: `==`. + Eq, + /// Exponentiation: `^`. + Exp, + /// Greater than: `>`. + Gt, + /// Greater than or equal: `>=`. + Gte, + /// Less than: `<`. + Lt, + /// Less than or equal: `<=`. + Lte, + /// Modulus: `%`. + Mod, + /// Multiplication: `*`. + Mul, + /// Inequality: `!=`. + Neq, + /// Bitwise OR: `|||`. + OrB, + /// Logical OR: `or`. + OrL, + /// Shift left: `<<<`. + Shl, + /// Shift right: `>>>`. + Shr, + /// Subtraction: `-`. + Sub, + /// Bitwise XOR: `^^^`. + XorB, +} + +/// A ternary operator. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum TernOp { + /// Conditional: `a ? b | c`. + Cond, + /// Aggregate update: `a w/ b <- c`. + Update, +} + +/// A set operator. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum SetOp { + /// The set union. + Union, + /// The set intersection. + Intersect, +} diff --git a/compiler/qsc_ast/src/lib.rs b/compiler/qsc_ast/src/lib.rs index fc7aaf15df..12bc032cbb 100644 --- a/compiler/qsc_ast/src/lib.rs +++ b/compiler/qsc_ast/src/lib.rs @@ -3,18 +3,6 @@ #![warn(clippy::pedantic)] -#[must_use] -pub fn add(left: usize, right: usize) -> usize { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} +pub mod ast; +pub mod mut_visit; +pub mod visit; diff --git a/compiler/qsc_ast/src/mut_visit.rs b/compiler/qsc_ast/src/mut_visit.rs new file mode 100644 index 0000000000..4d8d5b7dd2 --- /dev/null +++ b/compiler/qsc_ast/src/mut_visit.rs @@ -0,0 +1,284 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::ast::{ + Attr, Block, CallableBody, CallableDecl, Expr, ExprKind, FunctorExpr, Ident, Item, ItemKind, + Namespace, Package, Pat, PatKind, Path, QubitInit, QubitInitKind, SpecBody, SpecDecl, Stmt, + StmtKind, Ty, TyDef, TyKind, +}; + +pub trait MutVisitor: Sized { + fn visit_package(&mut self, package: &mut Package) { + walk_package(self, package); + } + + fn visit_namespace(&mut self, namespace: &mut Namespace) { + walk_namespace(self, namespace); + } + + fn visit_item(&mut self, item: &mut Item) { + walk_item(self, item); + } + + fn visit_attr(&mut self, attr: &mut Attr) { + walk_attr(self, attr); + } + + fn visit_ty_def(&mut self, def: &mut TyDef) { + walk_ty_def(self, def); + } + + fn visit_callable_decl(&mut self, decl: &mut CallableDecl) { + walk_callable_decl(self, decl); + } + + fn visit_spec_decl(&mut self, decl: &mut SpecDecl) { + walk_spec_decl(self, decl); + } + + fn visit_functor_expr(&mut self, expr: &mut FunctorExpr) { + walk_functor_expr(self, expr); + } + + fn visit_ty(&mut self, ty: &mut Ty) { + walk_ty(self, ty); + } + + fn visit_block(&mut self, block: &mut Block) { + walk_block(self, block); + } + + fn visit_stmt(&mut self, stmt: &mut Stmt) { + walk_stmt(self, stmt); + } + + fn visit_expr(&mut self, expr: &mut Expr) { + walk_expr(self, expr); + } + + fn visit_pat(&mut self, pat: &mut Pat) { + walk_pat(self, pat); + } + + fn visit_qubit_init(&mut self, init: &mut QubitInit) { + walk_qubit_init(self, init); + } + + fn visit_path(&mut self, _: &mut Path) {} + + fn visit_ident(&mut self, _: &mut Ident) {} +} + +pub fn walk_package(vis: &mut impl MutVisitor, package: &mut Package) { + package + .namespaces + .iter_mut() + .for_each(|n| vis.visit_namespace(n)); +} + +pub fn walk_namespace(vis: &mut impl MutVisitor, namespace: &mut Namespace) { + vis.visit_path(&mut namespace.name); + namespace.items.iter_mut().for_each(|i| vis.visit_item(i)); +} + +pub fn walk_item(vis: &mut impl MutVisitor, item: &mut Item) { + match &mut item.kind { + ItemKind::Open(ns, alias) => { + vis.visit_ident(ns); + alias.iter_mut().for_each(|a| vis.visit_ident(a)); + } + ItemKind::Type(meta, ident, def) => { + meta.attrs.iter_mut().for_each(|a| vis.visit_attr(a)); + vis.visit_ident(ident); + vis.visit_ty_def(def); + } + ItemKind::Callable(meta, decl) => { + meta.attrs.iter_mut().for_each(|a| vis.visit_attr(a)); + vis.visit_callable_decl(decl); + } + } +} + +pub fn walk_attr(vis: &mut impl MutVisitor, attr: &mut Attr) { + vis.visit_path(&mut attr.name); + vis.visit_expr(&mut attr.arg); +} + +pub fn walk_ty_def(vis: &mut impl MutVisitor, def: &mut TyDef) { + match def { + TyDef::Field(name, ty) => { + name.iter_mut().for_each(|n| vis.visit_ident(n)); + vis.visit_ty(ty); + } + TyDef::Tuple(defs) => defs.iter_mut().for_each(|d| vis.visit_ty_def(d)), + } +} + +pub fn walk_callable_decl(vis: &mut impl MutVisitor, decl: &mut CallableDecl) { + vis.visit_ident(&mut decl.name); + decl.ty_params.iter_mut().for_each(|p| vis.visit_ident(p)); + vis.visit_pat(&mut decl.input); + vis.visit_ty(&mut decl.output); + vis.visit_functor_expr(&mut decl.functors); + match &mut decl.body { + CallableBody::Block(block) => vis.visit_block(block), + CallableBody::Specs(specs) => specs.iter_mut().for_each(|s| vis.visit_spec_decl(s)), + } +} + +pub fn walk_spec_decl(vis: &mut impl MutVisitor, decl: &mut SpecDecl) { + match &mut decl.body { + SpecBody::Gen(_) => {} + SpecBody::Impl(pat, block) => { + vis.visit_pat(pat); + vis.visit_block(block); + } + } +} + +pub fn walk_functor_expr(vis: &mut impl MutVisitor, expr: &mut FunctorExpr) { + match expr { + FunctorExpr::BinOp(_, lhs, rhs) => { + vis.visit_functor_expr(lhs); + vis.visit_functor_expr(rhs); + } + FunctorExpr::Lit(_) | FunctorExpr::Null => {} + } +} + +pub fn walk_ty(vis: &mut impl MutVisitor, ty: &mut Ty) { + match &mut ty.kind { + TyKind::App(ty, tys) => { + vis.visit_ty(ty); + tys.iter_mut().for_each(|t| vis.visit_ty(t)); + } + TyKind::Arrow(_, lhs, rhs, functors) => { + vis.visit_ty(lhs); + vis.visit_ty(rhs); + vis.visit_functor_expr(functors); + } + TyKind::Path(path) => vis.visit_path(path), + TyKind::Tuple(tys) => tys.iter_mut().for_each(|t| vis.visit_ty(t)), + TyKind::Hole | TyKind::Prim(_) | TyKind::Var(_) => {} + } +} + +pub fn walk_block(vis: &mut impl MutVisitor, block: &mut Block) { + block.stmts.iter_mut().for_each(|s| vis.visit_stmt(s)); +} + +pub fn walk_stmt(vis: &mut impl MutVisitor, stmt: &mut Stmt) { + match &mut stmt.kind { + StmtKind::Borrow(pat, init, block) | StmtKind::Use(pat, init, block) => { + vis.visit_pat(pat); + vis.visit_qubit_init(init); + block.iter_mut().for_each(|b| vis.visit_block(b)); + } + StmtKind::Expr(expr) | StmtKind::Semi(expr) => vis.visit_expr(expr), + StmtKind::Let(pat, value) | StmtKind::Mutable(pat, value) => { + vis.visit_pat(pat); + vis.visit_expr(value); + } + } +} + +pub fn walk_expr(vis: &mut impl MutVisitor, expr: &mut Expr) { + match &mut expr.kind { + ExprKind::Array(exprs) => exprs.iter_mut().for_each(|e| vis.visit_expr(e)), + ExprKind::ArrayRepeat(item, size) => { + vis.visit_expr(item); + vis.visit_expr(size); + } + ExprKind::Assign(lhs, rhs) + | ExprKind::AssignOp(_, lhs, rhs) + | ExprKind::BinOp(_, lhs, rhs) => { + vis.visit_expr(lhs); + vis.visit_expr(rhs); + } + ExprKind::AssignUpdate(record, index, value) => { + vis.visit_expr(record); + vis.visit_expr(index); + vis.visit_expr(value); + } + ExprKind::Block(block) => vis.visit_block(block), + ExprKind::Call(callee, arg) => { + vis.visit_expr(callee); + vis.visit_expr(arg); + } + ExprKind::Conjugate(within, apply) => { + vis.visit_block(within); + vis.visit_block(apply); + } + ExprKind::Fail(msg) => vis.visit_expr(msg), + ExprKind::Field(record, name) => { + vis.visit_expr(record); + vis.visit_ident(name); + } + ExprKind::For(pat, iter, block) => { + vis.visit_pat(pat); + vis.visit_expr(iter); + vis.visit_block(block); + } + ExprKind::If(branches, default) => { + for (cond, block) in branches { + vis.visit_expr(cond); + vis.visit_block(block); + } + default.iter_mut().for_each(|d| vis.visit_block(d)); + } + ExprKind::Index(array, index) => { + vis.visit_expr(array); + vis.visit_expr(index); + } + ExprKind::Interp(_, exprs) => exprs.iter_mut().for_each(|e| vis.visit_expr(e)), + ExprKind::Lambda(_, pat, expr) => { + vis.visit_pat(pat); + vis.visit_expr(expr); + } + ExprKind::Paren(expr) | ExprKind::Return(expr) | ExprKind::UnOp(_, expr) => { + vis.visit_expr(expr); + } + ExprKind::Path(path) => vis.visit_path(path), + ExprKind::Range(start, step, end) => { + vis.visit_expr(start); + vis.visit_expr(step); + vis.visit_expr(end); + } + ExprKind::Repeat(body, until, fixup) => { + vis.visit_block(body); + vis.visit_expr(until); + fixup.iter_mut().for_each(|f| vis.visit_block(f)); + } + ExprKind::TernOp(_, e1, e2, e3) => { + vis.visit_expr(e1); + vis.visit_expr(e2); + vis.visit_expr(e3); + } + ExprKind::Tuple(exprs) => exprs.iter_mut().for_each(|e| vis.visit_expr(e)), + ExprKind::While(cond, block) => { + vis.visit_expr(cond); + vis.visit_block(block); + } + ExprKind::Hole | ExprKind::Lit(_) => {} + } +} + +pub fn walk_pat(vis: &mut impl MutVisitor, pat: &mut Pat) { + match &mut pat.kind { + PatKind::Bind(name, ty) => { + vis.visit_ident(name); + ty.iter_mut().for_each(|t| vis.visit_ty(t)); + } + PatKind::Discard(ty) => ty.iter_mut().for_each(|t| vis.visit_ty(t)), + PatKind::Tuple(pats) => pats.iter_mut().for_each(|p| vis.visit_pat(p)), + PatKind::Elided => {} + } +} + +pub fn walk_qubit_init(vis: &mut impl MutVisitor, init: &mut QubitInit) { + match &mut init.kind { + QubitInitKind::Single => {} + QubitInitKind::Tuple(inits) => inits.iter_mut().for_each(|i| vis.visit_qubit_init(i)), + QubitInitKind::Array(len) => vis.visit_expr(len), + } +} diff --git a/compiler/qsc_ast/src/visit.rs b/compiler/qsc_ast/src/visit.rs new file mode 100644 index 0000000000..e9360f678e --- /dev/null +++ b/compiler/qsc_ast/src/visit.rs @@ -0,0 +1,284 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::ast::{ + Attr, Block, CallableBody, CallableDecl, Expr, ExprKind, FunctorExpr, Ident, Item, ItemKind, + Namespace, Package, Pat, PatKind, Path, QubitInit, QubitInitKind, SpecBody, SpecDecl, Stmt, + StmtKind, Ty, TyDef, TyKind, +}; + +pub trait Visitor: Sized { + fn visit_package(&mut self, package: &Package) { + walk_package(self, package); + } + + fn visit_namespace(&mut self, namespace: &Namespace) { + walk_namespace(self, namespace); + } + + fn visit_item(&mut self, item: &Item) { + walk_item(self, item); + } + + fn visit_attr(&mut self, attr: &Attr) { + walk_attr(self, attr); + } + + fn visit_ty_def(&mut self, def: &TyDef) { + walk_ty_def(self, def); + } + + fn visit_callable_decl(&mut self, decl: &CallableDecl) { + walk_callable_decl(self, decl); + } + + fn visit_spec_decl(&mut self, decl: &SpecDecl) { + walk_spec_decl(self, decl); + } + + fn visit_functor_expr(&mut self, expr: &FunctorExpr) { + walk_functor_expr(self, expr); + } + + fn visit_ty(&mut self, ty: &Ty) { + walk_ty(self, ty); + } + + fn visit_block(&mut self, block: &Block) { + walk_block(self, block); + } + + fn visit_stmt(&mut self, stmt: &Stmt) { + walk_stmt(self, stmt); + } + + fn visit_expr(&mut self, expr: &Expr) { + walk_expr(self, expr); + } + + fn visit_pat(&mut self, pat: &Pat) { + walk_pat(self, pat); + } + + fn visit_qubit_init(&mut self, init: &QubitInit) { + walk_qubit_init(self, init); + } + + fn visit_path(&mut self, _: &Path) {} + + fn visit_ident(&mut self, _: &Ident) {} +} + +pub fn walk_package(vis: &mut impl Visitor, package: &Package) { + package + .namespaces + .iter() + .for_each(|n| vis.visit_namespace(n)); +} + +pub fn walk_namespace(vis: &mut impl Visitor, namespace: &Namespace) { + vis.visit_path(&namespace.name); + namespace.items.iter().for_each(|i| vis.visit_item(i)); +} + +pub fn walk_item(vis: &mut impl Visitor, item: &Item) { + match &item.kind { + ItemKind::Open(ns, alias) => { + vis.visit_ident(ns); + alias.iter().for_each(|a| vis.visit_ident(a)); + } + ItemKind::Type(meta, ident, def) => { + meta.attrs.iter().for_each(|a| vis.visit_attr(a)); + vis.visit_ident(ident); + vis.visit_ty_def(def); + } + ItemKind::Callable(meta, decl) => { + meta.attrs.iter().for_each(|a| vis.visit_attr(a)); + vis.visit_callable_decl(decl); + } + } +} + +pub fn walk_attr(vis: &mut impl Visitor, attr: &Attr) { + vis.visit_path(&attr.name); + vis.visit_expr(&attr.arg); +} + +pub fn walk_ty_def(vis: &mut impl Visitor, def: &TyDef) { + match def { + TyDef::Field(name, ty) => { + name.iter().for_each(|n| vis.visit_ident(n)); + vis.visit_ty(ty); + } + TyDef::Tuple(defs) => defs.iter().for_each(|d| vis.visit_ty_def(d)), + } +} + +pub fn walk_callable_decl(vis: &mut impl Visitor, decl: &CallableDecl) { + vis.visit_ident(&decl.name); + decl.ty_params.iter().for_each(|p| vis.visit_ident(p)); + vis.visit_pat(&decl.input); + vis.visit_ty(&decl.output); + vis.visit_functor_expr(&decl.functors); + match &decl.body { + CallableBody::Block(block) => vis.visit_block(block), + CallableBody::Specs(specs) => specs.iter().for_each(|s| vis.visit_spec_decl(s)), + } +} + +pub fn walk_spec_decl(vis: &mut impl Visitor, decl: &SpecDecl) { + match &decl.body { + SpecBody::Gen(_) => {} + SpecBody::Impl(pat, block) => { + vis.visit_pat(pat); + vis.visit_block(block); + } + } +} + +pub fn walk_functor_expr(vis: &mut impl Visitor, expr: &FunctorExpr) { + match expr { + FunctorExpr::BinOp(_, lhs, rhs) => { + vis.visit_functor_expr(lhs); + vis.visit_functor_expr(rhs); + } + FunctorExpr::Lit(_) | FunctorExpr::Null => {} + } +} + +pub fn walk_ty(vis: &mut impl Visitor, ty: &Ty) { + match &ty.kind { + TyKind::App(ty, tys) => { + vis.visit_ty(ty); + tys.iter().for_each(|t| vis.visit_ty(t)); + } + TyKind::Arrow(_, lhs, rhs, functors) => { + vis.visit_ty(lhs); + vis.visit_ty(rhs); + vis.visit_functor_expr(functors); + } + TyKind::Path(path) => vis.visit_path(path), + TyKind::Tuple(tys) => tys.iter().for_each(|t| vis.visit_ty(t)), + TyKind::Hole | TyKind::Prim(_) | TyKind::Var(_) => {} + } +} + +pub fn walk_block(vis: &mut impl Visitor, block: &Block) { + block.stmts.iter().for_each(|s| vis.visit_stmt(s)); +} + +pub fn walk_stmt(vis: &mut impl Visitor, stmt: &Stmt) { + match &stmt.kind { + StmtKind::Borrow(pat, init, block) | StmtKind::Use(pat, init, block) => { + vis.visit_pat(pat); + vis.visit_qubit_init(init); + block.iter().for_each(|b| vis.visit_block(b)); + } + StmtKind::Expr(expr) | StmtKind::Semi(expr) => vis.visit_expr(expr), + StmtKind::Let(pat, value) | StmtKind::Mutable(pat, value) => { + vis.visit_pat(pat); + vis.visit_expr(value); + } + } +} + +pub fn walk_expr(vis: &mut impl Visitor, expr: &Expr) { + match &expr.kind { + ExprKind::Array(exprs) => exprs.iter().for_each(|e| vis.visit_expr(e)), + ExprKind::ArrayRepeat(item, size) => { + vis.visit_expr(item); + vis.visit_expr(size); + } + ExprKind::Assign(lhs, rhs) + | ExprKind::AssignOp(_, lhs, rhs) + | ExprKind::BinOp(_, lhs, rhs) => { + vis.visit_expr(lhs); + vis.visit_expr(rhs); + } + ExprKind::AssignUpdate(record, index, value) => { + vis.visit_expr(record); + vis.visit_expr(index); + vis.visit_expr(value); + } + ExprKind::Block(block) => vis.visit_block(block), + ExprKind::Call(callee, arg) => { + vis.visit_expr(callee); + vis.visit_expr(arg); + } + ExprKind::Conjugate(within, apply) => { + vis.visit_block(within); + vis.visit_block(apply); + } + ExprKind::Fail(msg) => vis.visit_expr(msg), + ExprKind::Field(record, name) => { + vis.visit_expr(record); + vis.visit_ident(name); + } + ExprKind::For(pat, iter, block) => { + vis.visit_pat(pat); + vis.visit_expr(iter); + vis.visit_block(block); + } + ExprKind::If(branches, default) => { + for (cond, block) in branches { + vis.visit_expr(cond); + vis.visit_block(block); + } + default.iter().for_each(|d| vis.visit_block(d)); + } + ExprKind::Index(array, index) => { + vis.visit_expr(array); + vis.visit_expr(index); + } + ExprKind::Interp(_, exprs) => exprs.iter().for_each(|e| vis.visit_expr(e)), + ExprKind::Lambda(_, pat, expr) => { + vis.visit_pat(pat); + vis.visit_expr(expr); + } + ExprKind::Paren(expr) | ExprKind::Return(expr) | ExprKind::UnOp(_, expr) => { + vis.visit_expr(expr); + } + ExprKind::Path(path) => vis.visit_path(path), + ExprKind::Range(start, step, end) => { + vis.visit_expr(start); + vis.visit_expr(step); + vis.visit_expr(end); + } + ExprKind::Repeat(body, until, fixup) => { + vis.visit_block(body); + vis.visit_expr(until); + fixup.iter().for_each(|f| vis.visit_block(f)); + } + ExprKind::TernOp(_, e1, e2, e3) => { + vis.visit_expr(e1); + vis.visit_expr(e2); + vis.visit_expr(e3); + } + ExprKind::Tuple(exprs) => exprs.iter().for_each(|e| vis.visit_expr(e)), + ExprKind::While(cond, block) => { + vis.visit_expr(cond); + vis.visit_block(block); + } + ExprKind::Hole | ExprKind::Lit(_) => {} + } +} + +pub fn walk_pat(vis: &mut impl Visitor, pat: &Pat) { + match &pat.kind { + PatKind::Bind(name, ty) => { + vis.visit_ident(name); + ty.iter().for_each(|t| vis.visit_ty(t)); + } + PatKind::Discard(ty) => ty.iter().for_each(|t| vis.visit_ty(t)), + PatKind::Tuple(pats) => pats.iter().for_each(|p| vis.visit_pat(p)), + PatKind::Elided => {} + } +} + +pub fn walk_qubit_init(vis: &mut impl Visitor, init: &QubitInit) { + match &init.kind { + QubitInitKind::Single => {} + QubitInitKind::Tuple(inits) => inits.iter().for_each(|i| vis.visit_qubit_init(i)), + QubitInitKind::Array(len) => vis.visit_expr(len), + } +}