From 40bcfc6d3807d62e666300f5a90776226f989384 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Tue, 28 Nov 2023 14:59:38 -0800 Subject: [PATCH] [DSLX:frontend][NFC] Factor out Module/Proc from ast.h monolithic file. These are some of the leaf-most AST node types so they're more easily split out than some of the more interior nodes without creating circular deps. PiperOrigin-RevId: 586113553 --- xls/dslx/BUILD | 14 + xls/dslx/fmt/BUILD | 1 + xls/dslx/fmt/ast_fmt.cc | 1 + xls/dslx/frontend/BUILD | 45 +- xls/dslx/frontend/ast.cc | 475 +----------------- xls/dslx/frontend/ast.h | 370 -------------- xls/dslx/frontend/ast_cloner.cc | 2 + xls/dslx/frontend/ast_cloner_test.cc | 1 + xls/dslx/frontend/ast_test.cc | 1 + xls/dslx/frontend/ast_test_utils.cc | 1 + xls/dslx/frontend/ast_utils.cc | 40 ++ xls/dslx/frontend/ast_utils.h | 8 + xls/dslx/frontend/module.cc | 370 ++++++++++++++ xls/dslx/frontend/module.h | 268 ++++++++++ xls/dslx/frontend/parser.h | 1 + xls/dslx/frontend/proc.cc | 141 ++++++ xls/dslx/frontend/proc.h | 165 ++++++ xls/dslx/import_data.h | 1 + xls/dslx/interp_value_helpers_test.cc | 1 + xls/dslx/ir_convert/BUILD | 4 + .../ir_convert/extract_conversion_order.cc | 3 + .../ir_convert/extract_conversion_order.h | 1 + xls/dslx/lsp/BUILD | 3 + xls/dslx/lsp/document_symbols.h | 1 + xls/dslx/lsp/find_definition.cc | 4 + xls/dslx/lsp/find_definition.h | 1 + xls/dslx/run_comparator.cc | 15 + xls/dslx/type_system/BUILD | 6 + xls/dslx/type_system/concrete_type_test.cc | 2 + xls/dslx/type_system/type_info.cc | 1 + xls/dslx/type_system/type_info_test.cc | 1 + xls/dslx/type_system/typecheck.cc | 2 + xls/fuzzer/BUILD | 3 + xls/fuzzer/ast_generator.h | 3 +- xls/fuzzer/value_generator.cc | 1 + xls/fuzzer/value_generator_test.cc | 1 + xls/tools/BUILD | 2 + xls/tools/proto_to_dslx_main.cc | 1 + xls/tools/proto_to_dslx_test.cc | 1 + 39 files changed, 1116 insertions(+), 846 deletions(-) create mode 100644 xls/dslx/frontend/module.cc create mode 100644 xls/dslx/frontend/module.h create mode 100644 xls/dslx/frontend/proc.cc create mode 100644 xls/dslx/frontend/proc.h diff --git a/xls/dslx/BUILD b/xls/dslx/BUILD index a8a1fcee54..efbef8aa72 100644 --- a/xls/dslx/BUILD +++ b/xls/dslx/BUILD @@ -156,6 +156,7 @@ cc_test( "//xls/common:xls_gunit_main", "//xls/common/status:matchers", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/type_system:concrete_type", "//xls/ir:bits", @@ -272,6 +273,7 @@ cc_library( "//xls/common/status:ret_check", "//xls/dslx/bytecode:bytecode_cache_interface", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/type_system:type_info", "@com_google_absl//absl/container:flat_hash_map", @@ -404,11 +406,23 @@ cc_library( srcs = ["run_comparator.cc"], hdrs = ["run_comparator.h"], deps = [ + ":interp_value", ":mangle", ":run_routines", "//xls/common:test_macros", + "//xls/common/logging", + "//xls/common/status:ret_check", + "//xls/common/status:status_macros", + "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", + "//xls/dslx/type_system:parametric_env", "//xls/interpreter:ir_interpreter", + "//xls/ir", + "//xls/ir:value", "//xls/jit:function_jit", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) diff --git a/xls/dslx/fmt/BUILD b/xls/dslx/fmt/BUILD index bcb87a5511..e966836191 100644 --- a/xls/dslx/fmt/BUILD +++ b/xls/dslx/fmt/BUILD @@ -64,6 +64,7 @@ cc_library( "//xls/dslx:channel_direction", "//xls/dslx/frontend:ast", "//xls/dslx/frontend:comment_data", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/frontend:token", "//xls/dslx/frontend:token_utils", diff --git a/xls/dslx/fmt/ast_fmt.cc b/xls/dslx/fmt/ast_fmt.cc index 5ed4987205..f9ee3fc22d 100644 --- a/xls/dslx/fmt/ast_fmt.cc +++ b/xls/dslx/fmt/ast_fmt.cc @@ -39,6 +39,7 @@ #include "xls/dslx/fmt/pretty_print.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/comment_data.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/frontend/token.h" #include "xls/dslx/frontend/token_utils.h" diff --git a/xls/dslx/frontend/BUILD b/xls/dslx/frontend/BUILD index 3cf97a8409..7958722202 100644 --- a/xls/dslx/frontend/BUILD +++ b/xls/dslx/frontend/BUILD @@ -27,6 +27,8 @@ cc_library( deps = [ ":ast", ":ast_utils", + ":module", + ":proc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -47,6 +49,7 @@ cc_test( deps = [ ":ast", ":ast_cloner", + ":module", "//xls/common:xls_gunit", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", @@ -84,6 +87,7 @@ cc_library( ":ast_utils", ":bindings", ":builtins_metadata", + ":module", ":pos", ":scanner", ":token", @@ -134,6 +138,7 @@ cc_test( ], ) +# Note: ast_utils layers on top of the AST implementation. cc_library( name = "ast_utils", srcs = ["ast_utils.cc"], @@ -386,6 +391,43 @@ cc_library( ], ) +cc_library( + name = "proc", + srcs = ["proc.cc"], + hdrs = ["proc.h"], + deps = [ + ":ast", + ":pos", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "//xls/common:indent", + ], +) + +cc_library( + name = "module", + srcs = ["module.cc"], + hdrs = ["module.h"], + deps = [ + ":ast", + ":pos", + ":proc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "//xls/common:casts", + "//xls/common:visitor", + "//xls/common/logging", + "//xls/common/status:ret_check", + ], +) + cc_library( name = "ast_test_utils", testonly = True, @@ -393,6 +435,7 @@ cc_library( hdrs = ["ast_test_utils.h"], deps = [ ":ast", + ":module", ":pos", "@com_google_absl//absl/strings:str_format", "//xls/common/logging", @@ -405,8 +448,8 @@ cc_test( deps = [ ":ast", ":ast_test_utils", + ":module", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "//xls/common:xls_gunit", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", diff --git a/xls/dslx/frontend/ast.cc b/xls/dslx/frontend/ast.cc index 1c0d68632a..f1129e1019 100644 --- a/xls/dslx/frontend/ast.cc +++ b/xls/dslx/frontend/ast.cc @@ -43,6 +43,7 @@ #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/common/visitor.h" +#include "xls/dslx/channel_direction.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/frontend/token_utils.h" #include "xls/ir/bits.h" @@ -781,17 +782,6 @@ std::optional ColonRef::ResolveImportSubject() const { return import; } -// -- class ProcMember - -ProcMember::ProcMember(Module* owner, NameDef* name_def, - TypeAnnotation* type_annotation) - : AstNode(owner), - name_def_(name_def), - type_annotation_(type_annotation), - span_(name_def_->span().start(), type_annotation_->span().limit()) {} - -ProcMember::~ProcMember() = default; - // -- class Param Param::Param(Module* owner, NameDef* name_def, TypeAnnotation* type_annotation) @@ -822,329 +812,6 @@ std::string ChannelDecl::ToStringInternal() const { absl::StrJoin(dims, "")); } -// -- class Module - -Module::~Module() { - XLS_VLOG(3) << "Destroying module \"" << name_ << "\" @ " << this; -} - -std::string Module::ToString() const { - // Don't print Proc functions, as they'll be printed as part of the procs - // themselves. - std::vector print_top; - for (const auto& member : top_) { - if (std::holds_alternative(member) && - std::get(member)->proc().has_value()) { - continue; - } - print_top.push_back(member); - } - return absl::StrJoin(print_top, "\n", - [](std::string* out, const ModuleMember& member) { - absl::StrAppend(out, ToAstNode(member)->ToString()); - }); -} - -const AstNode* Module::FindNode(AstNodeKind kind, const Span& target) const { - for (const auto& node : nodes_) { - if (node->kind() == kind && node->GetSpan().has_value() && - node->GetSpan().value() == target) { - return node.get(); - } - } - return nullptr; -} - -std::vector Module::FindIntercepting(const Pos& target) const { - std::vector found; - for (const auto& node : nodes_) { - if (node->GetSpan().has_value() && node->GetSpan()->Contains(target)) { - found.push_back(node.get()); - } - } - return found; -} - -std::optional Module::GetFunction(std::string_view target_name) { - for (ModuleMember& member : top_) { - if (std::holds_alternative(member)) { - Function* f = std::get(member); - if (f->identifier() == target_name) { - return f; - } - } - } - return std::nullopt; -} - -std::optional Module::GetProc(std::string_view target_name) { - for (ModuleMember& member : top_) { - if (std::holds_alternative(member)) { - Proc* p = std::get(member); - if (p->identifier() == target_name) { - return p; - } - } - } - return std::nullopt; -} - -absl::StatusOr Module::GetTest(std::string_view target_name) { - for (ModuleMember& member : top_) { - if (std::holds_alternative(member)) { - TestFunction* t = std::get(member); - if (t->identifier() == target_name) { - return t; - } - } - } - return absl::NotFoundError(absl::StrFormat( - "No test in module %s with name \"%s\"", name_, target_name)); -} - -absl::StatusOr Module::GetTestProc(std::string_view target_name) { - for (ModuleMember& member : top_) { - if (std::holds_alternative(member)) { - auto* t = std::get(member); - if (t->proc()->identifier() == target_name) { - return t; - } - } - } - return absl::NotFoundError(absl::StrFormat( - "No test proc in module %s with name \"%s\"", name_, target_name)); -} - -std::vector Module::GetTestNames() const { - std::vector result; - for (auto& member : top_) { - if (std::holds_alternative(member)) { - TestFunction* t = std::get(member); - result.push_back(t->identifier()); - } else if (std::holds_alternative(member)) { - TestProc* tp = std::get(member); - result.push_back(tp->proc()->identifier()); - } - } - return result; -} - -std::vector Module::GetFunctionNames() const { - std::vector result; - for (auto& member : top_) { - if (std::holds_alternative(member)) { - result.push_back(std::get(member)->identifier()); - } - } - return result; -} - -const StructDef* Module::FindStructDef(const Span& span) const { - return down_cast(FindNode(AstNodeKind::kStructDef, span)); -} - -const EnumDef* Module::FindEnumDef(const Span& span) const { - return down_cast(FindNode(AstNodeKind::kEnumDef, span)); -} - -std::optional Module::FindMemberWithName( - std::string_view target) { - for (ModuleMember& member : top_) { - if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->proc()->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - if (std::get(member)->identifier() == target) { - return &member; - } - } else if (std::holds_alternative(member)) { - continue; // These have no name / binding. - } else { - XLS_LOG(FATAL) << "Unhandled module member variant: " - << ToAstNode(member)->GetNodeTypeName(); - } - } - return std::nullopt; -} - -absl::StatusOr Module::GetConstantDef(std::string_view target) { - std::optional member = FindMemberWithName(target); - if (!member.has_value()) { - return absl::NotFoundError( - absl::StrFormat("Could not find member named '%s' in module.", target)); - } - if (!std::holds_alternative(*member.value())) { - return absl::NotFoundError(absl::StrFormat( - "Member named '%s' in module was not a constant.", target)); - } - return std::get(*member.value()); -} - -absl::flat_hash_map -Module::GetTypeDefinitionByName() const { - absl::flat_hash_map result; - for (auto& member : top_) { - if (std::holds_alternative(member)) { - TypeAlias* td = std::get(member); - result[td->identifier()] = td; - } else if (std::holds_alternative(member)) { - EnumDef* enum_ = std::get(member); - result[enum_->identifier()] = enum_; - } else if (std::holds_alternative(member)) { - StructDef* struct_ = std::get(member); - result[struct_->identifier()] = struct_; - } - } - return result; -} - -std::vector Module::GetTypeDefinitions() const { - std::vector results; - for (auto& member : top_) { - if (std::holds_alternative(member)) { - TypeAlias* td = std::get(member); - results.push_back(td); - } else if (std::holds_alternative(member)) { - EnumDef* enum_def = std::get(member); - results.push_back(enum_def); - } else if (std::holds_alternative(member)) { - StructDef* struct_def = std::get(member); - results.push_back(struct_def); - } - } - return results; -} - -std::vector Module::GetChildren(bool want_types) const { - std::vector results; - results.reserve(top_.size()); - for (ModuleMember member : top_) { - results.push_back(ToAstNode(member)); - } - return results; -} - -absl::StatusOr Module::GetTypeDefinition( - std::string_view name) const { - absl::flat_hash_map map = - GetTypeDefinitionByName(); - auto it = map.find(name); - if (it == map.end()) { - return absl::NotFoundError( - absl::StrCat("Could not find type definition for name: ", name)); - } - return it->second; -} - -absl::Status Module::AddTop(ModuleMember member, - const MakeCollisionError& make_collision_error) { - // Get name - std::optional member_name = absl::visit( - Visitor{ - [](Function* f) { return std::make_optional(f->identifier()); }, - [](Proc* p) { return std::make_optional(p->identifier()); }, - [](TestFunction* tf) { return std::make_optional(tf->identifier()); }, - [](TestProc* tp) { - return std::make_optional(tp->proc()->identifier()); - }, - [](QuickCheck* qc) { return std::make_optional(qc->identifier()); }, - [](TypeAlias* td) { return std::make_optional(td->identifier()); }, - [](StructDef* sd) { return std::make_optional(sd->identifier()); }, - [](ConstantDef* cd) { return std::make_optional(cd->identifier()); }, - [](EnumDef* ed) { return std::make_optional(ed->identifier()); }, - [](Import* i) { return std::make_optional(i->identifier()); }, - [](ConstAssert* n) -> std::optional { - return std::nullopt; - }, - }, - member); - - if (member_name.has_value() && top_by_name_.contains(member_name.value())) { - const AstNode* node = ToAstNode(top_by_name_.at(member_name.value())); - const Span existing_span = node->GetSpan().value(); - const AstNode* new_node = ToAstNode(member); - const Span new_span = new_node->GetSpan().value(); - if (make_collision_error != nullptr) { - return make_collision_error(name_, member_name.value(), existing_span, - node, new_span, new_node); - } - return absl::InvalidArgumentError(absl::StrFormat( - "Module %s already contains a member named %s @ %s: %s", name_, - member_name.value(), existing_span.ToString(), node->ToString())); - } - - top_.push_back(member); - if (member_name.has_value()) { - top_by_name_.insert({member_name.value(), member}); - } - return absl::OkStatus(); -} - -std::string_view GetModuleMemberTypeName(const ModuleMember& module_member) { - return absl::visit(Visitor{ - [](Function*) { return "function"; }, - [](Proc*) { return "proc"; }, - [](TestFunction*) { return "test-function"; }, - [](TestProc*) { return "test-proc"; }, - [](QuickCheck*) { return "quick-check"; }, - [](TypeAlias*) { return "type-alias"; }, - [](StructDef*) { return "struct-definition"; }, - [](ConstantDef*) { return "constant-definition"; }, - [](EnumDef*) { return "enum-definition"; }, - [](Import*) { return "import"; }, - [](ConstAssert*) { return "const-assert"; }, - }, - module_member); -} - -absl::StatusOr AsModuleMember(AstNode* node) { - // clang-format off - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } - // clang-format on - return absl::InvalidArgumentError("AST node is not a module-level member: " + - node->ToString()); -} - absl::StatusOr AstNodeToIndexRhs(AstNode* node) { // clang-format off if (auto* n = dynamic_cast(node)) { return IndexRhs(n); } @@ -1961,97 +1628,6 @@ std::vector Function::GetFreeParametricKeys() const { TestFunction::~TestFunction() = default; -// -- class Proc - -Proc::Proc(Module* owner, Span span, NameDef* name_def, - NameDef* config_name_def, NameDef* next_name_def, - const std::vector& parametric_bindings, - std::vector members, Function* config, Function* next, - Function* init, bool is_public) - : AstNode(owner), - span_(std::move(span)), - name_def_(name_def), - config_name_def_(config_name_def), - next_name_def_(next_name_def), - parametric_bindings_(parametric_bindings), - config_(config), - next_(next), - init_(init), - members_(std::move(members)), - is_public_(is_public) {} - -Proc::~Proc() = default; - -std::vector Proc::GetChildren(bool want_types) const { - std::vector results = {name_def()}; - for (ParametricBinding* pb : parametric_bindings_) { - results.push_back(pb); - } - for (ProcMember* p : members_) { - results.push_back(p); - } - results.push_back(config_); - results.push_back(next_); - results.push_back(init_); - return results; -} - -std::string Proc::ToString() const { - std::string pub_str = is_public() ? "pub " : ""; - std::string parametric_str; - if (!parametric_bindings().empty()) { - parametric_str = absl::StrFormat( - "<%s>", - absl::StrJoin( - parametric_bindings(), ", ", - [](std::string* out, ParametricBinding* parametric_binding) { - absl::StrAppend(out, parametric_binding->ToString()); - })); - } - auto param_append = [](std::string* out, const Param* p) { - out->append(absl::StrCat(p->ToString(), ";")); - }; - auto member_append = [](std::string* out, const ProcMember* member) { - out->append(absl::StrCat(member->ToString(), ";")); - }; - std::string config_params_str = - absl::StrJoin(config_->params(), ", ", param_append); - std::string state_params_str = - absl::StrJoin(next_->params(), ", ", param_append); - std::string members_str = absl::StrJoin(members_, "\n", member_append); - if (!members_str.empty()) { - members_str.append("\n"); - } - - // Init functions are special, since they shouldn't be printed with - // parentheses (since they can't take args). - std::string init_str = Indent( - absl::StrCat("init ", init_->body()->ToString()), kRustSpacesPerIndent); - - constexpr std::string_view kTemplate = R"(%sproc %s%s { -%s%s -%s -%s -})"; - return absl::StrFormat( - kTemplate, pub_str, name_def()->identifier(), parametric_str, - Indent(members_str, kRustSpacesPerIndent), - Indent(config_->ToUndecoratedString("config"), kRustSpacesPerIndent), - init_str, - Indent(next_->ToUndecoratedString("next"), kRustSpacesPerIndent)); -} - -std::vector Proc::GetFreeParametricKeys() const { - // TODO(rspringer): 2021-09-29: Mutants found holes in test coverage here. - std::vector results; - for (ParametricBinding* b : parametric_bindings_) { - if (b->expr() == nullptr) { - results.push_back(b->name_def()->identifier()); - } - } - return results; -} - // -- class MatchArm MatchArm::MatchArm(Module* owner, Span span, std::vector patterns, @@ -2108,14 +1684,6 @@ std::string Cast::ToStringInternal() const { return absl::StrFormat("%s as %s", lhs, type_annotation_->ToString()); } -// -- class TestProc - -TestProc::~TestProc() = default; - -std::string TestProc::ToString() const { - return absl::StrFormat("#[test_proc]\n%s", proc_->ToString()); -} - // -- class BuiltinTypeAnnotation BuiltinTypeAnnotation::BuiltinTypeAnnotation(Module* owner, Span span, @@ -2148,7 +1716,7 @@ ChannelTypeAnnotation::ChannelTypeAnnotation( : TypeAnnotation(owner, std::move(span)), direction_(direction), payload_(payload), - dims_(dims) {} + dims_(std::move(dims)) {} ChannelTypeAnnotation::~ChannelTypeAnnotation() = default; @@ -2521,43 +2089,4 @@ Span ExprOrTypeSpan(const ExprOrType &expr_or_type) { }, expr_or_type); } -absl::StatusOr> CollectUnder(AstNode* root, - bool want_types) { - std::vector nodes; - - class CollectVisitor : public AstNodeVisitor { - public: - explicit CollectVisitor(std::vector& nodes) : nodes_(nodes) {} - -#define DECLARE_HANDLER(__type) \ - absl::Status Handle##__type(const __type* n) override { \ - nodes_.push_back(const_cast<__type*>(n)); \ - return absl::OkStatus(); \ - } - XLS_DSLX_AST_NODE_EACH(DECLARE_HANDLER) -#undef DECLARE_HANDLER - - private: - std::vector& nodes_; - } collect_visitor(nodes); - - XLS_RETURN_IF_ERROR(WalkPostOrder(root, &collect_visitor, want_types)); - return nodes; -} - -absl::StatusOr> CollectUnder(const AstNode* root, - bool want_types) { - // Implementation note: delegate to non-const version and turn result values - // back to const. - XLS_ASSIGN_OR_RETURN(std::vector got, - CollectUnder(const_cast(root), want_types)); - - std::vector result; - result.reserve(got.size()); - for (AstNode* n : got) { - result.push_back(n); - } - return result; -} - } // namespace xls::dslx diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 7538764c9f..30c88f2b5a 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -1625,106 +1625,6 @@ class Function : public AstNode { std::optional<::xls::ForeignFunctionData> extern_verilog_module_; }; -// A member held in a proc, e.g. a channel declaration initialized by a -// configuration block. -// -// This is very similar to a `Param` at the moment, but we make them distinct -// types for structural clarity in the AST. Params are really "parameters to -// functions". -class ProcMember : public AstNode { - public: - ProcMember(Module* owner, NameDef* name_def, TypeAnnotation* type); - - ~ProcMember() override; - - AstNodeKind kind() const override { return AstNodeKind::kProcMember; } - - absl::Status Accept(AstNodeVisitor* v) const override { - return v->HandleProcMember(this); - } - - std::string_view GetNodeTypeName() const override { return "ProcMember"; } - std::string ToString() const override { - return absl::StrFormat("%s: %s", name_def_->ToString(), - type_annotation_->ToString()); - } - - std::vector GetChildren(bool want_types) const override { - return {name_def_, type_annotation_}; - } - - const Span& span() const { return span_; } - NameDef* name_def() const { return name_def_; } - TypeAnnotation* type_annotation() const { return type_annotation_; } - const std::string& identifier() const { return name_def_->identifier(); } - std::optional GetSpan() const override { return span_; } - - private: - NameDef* name_def_; - TypeAnnotation* type_annotation_; - Span span_; -}; - -// Represents a parsed 'process' specification in the DSL. -class Proc : public AstNode { - public: - static std::string_view GetDebugTypeName() { return "proc"; } - - Proc(Module* owner, Span span, NameDef* name_def, NameDef* config_name_def, - NameDef* next_name_def, - const std::vector& parametric_bindings, - std::vector members, Function* config, Function* next, - Function* init, bool is_public); - - ~Proc() override; - - AstNodeKind kind() const override { return AstNodeKind::kProc; } - - absl::Status Accept(AstNodeVisitor* v) const override { - return v->HandleProc(this); - } - std::string_view GetNodeTypeName() const override { return "Proc"; } - std::string ToString() const override; - std::vector GetChildren(bool want_types) const override; - - NameDef* name_def() const { return name_def_; } - NameDef* config_name_def() const { return config_name_def_; } - NameDef* next_name_def() const { return next_name_def_; } - const Span& span() const { return span_; } - std::optional GetSpan() const override { return span_; } - - const std::string& identifier() const { return name_def_->identifier(); } - const std::vector& parametric_bindings() const { - return parametric_bindings_; - } - bool IsParametric() const { return !parametric_bindings_.empty(); } - bool is_public() const { return is_public_; } - - std::vector GetFreeParametricKeys() const; - absl::btree_set GetFreeParametricKeySet() const { - std::vector keys = GetFreeParametricKeys(); - return absl::btree_set(keys.begin(), keys.end()); - } - - Function* config() const { return config_; } - Function* next() const { return next_; } - Function* init() const { return init_; } - const std::vector& members() const { return members_; } - - private: - Span span_; - NameDef* name_def_; - NameDef* config_name_def_; - NameDef* next_name_def_; - std::vector parametric_bindings_; - - Function* config_; - Function* next_; - Function* init_; - std::vector members_; - bool is_public_; -}; - // Represents a single arm in a match expression. // // Attributes: @@ -2501,38 +2401,6 @@ class TestFunction : public AstNode { Function* const fn_; }; -// Represents a construct to unit test a Proc. Analogous to TestFunction, but -// for Procs. -// -// These are specified with an annotation as follows: -// ```dslx -// #[test_proc()] -// proc test_proc { ... } -// ``` -class TestProc : public AstNode { - public: - TestProc(Module* owner, Proc* proc) : AstNode(owner), proc_(proc) {} - ~TestProc() override; - - AstNodeKind kind() const override { return AstNodeKind::kTestProc; } - absl::Status Accept(AstNodeVisitor* v) const override { - return v->HandleTestProc(this); - } - std::vector GetChildren(bool want_types) const override { - return {proc_}; - } - std::string_view GetNodeTypeName() const override { return "TestProc"; } - std::string ToString() const override; - - Proc* proc() const { return proc_; } - std::optional GetSpan() const override { return proc_->span(); } - - const std::string& identifier() const { return proc_->identifier(); } - - private: - Proc* proc_; -}; - // Represents a function to be quick-check'd. class QuickCheck : public AstNode { public: @@ -3074,248 +2942,10 @@ class ChannelDecl : public Expr { std::optional fifo_depth_; }; -using ModuleMember = - std::variant; - -std::string_view GetModuleMemberTypeName(const ModuleMember& module_member); - -absl::StatusOr AsModuleMember(AstNode* node); - -// Represents a syntactic module in the AST. -// -// Modules contain top-level definitions such as functions and tests. -// -// Attributes: -// name: Name of this module. -// top: Top-level module constructs; e.g. functions, tests. Given as a -// sequence instead of a mapping in case there are unnamed constructs at the -// module level (e.g. metadata, docstrings). -// fs_path: Name of the filesystem path that led to this module's AST -- if -// the AST was constructed in-memory this value will be nullopt. Generally -// this was relative to the main binary's $CWD (which is often a place like -// Bazel's execution root) -- this helps output be deterministic even when -// running distributed compilation. -class Module : public AstNode { - public: - Module(std::string name, std::optional fs_path) - : AstNode(this), name_(std::move(name)), fs_path_(std::move(fs_path)) { - XLS_VLOG(3) << "Created module \"" << name_ << "\" @ " << this; - } - - ~Module() override; - - Module(Module&& other) = default; - Module& operator=(Module&& other) = default; - - AstNodeKind kind() const override { return AstNodeKind::kModule; } - - absl::Status Accept(AstNodeVisitor* v) const override { - return v->HandleModule(this); - } - std::optional GetSpan() const override { return std::nullopt; } - - std::string_view GetNodeTypeName() const override { return "Module"; } - std::vector GetChildren(bool want_types) const override; - - std::string ToString() const override; - - template - T* Make(Args&&... args) { - static_assert(!std::is_same::value, - "Use Module::GetOrCreateBuiltinNameDef()"); - return MakeInternal(std::forward(args)...); - } - - BuiltinNameDef* GetOrCreateBuiltinNameDef(std::string_view name) { - auto it = builtin_name_defs_.find(name); - if (it == builtin_name_defs_.end()) { - BuiltinNameDef* bnd = MakeInternal(std::string(name)); - builtin_name_defs_.emplace_hint(it, std::string(name), bnd); - return bnd; - } - return it->second; - } - - using MakeCollisionError = std::function; - - // Adds a top level "member" to the module. Invokes make_collision_error if - // there is a naming collision at module scope -- this is done so that errors - // can be layered appropriately and injected in from outside code (e.g. the - // parser). If nullptr is given, then a non-positional InvalidArgumentError is - // raised. - absl::Status AddTop(ModuleMember member, - const MakeCollisionError& make_collision_error); - - // Gets the element in this module with the given target_name, or returns a - // NotFoundError. - template - absl::StatusOr GetMemberOrError(std::string_view target_name) { - for (ModuleMember& member : top_) { - if (std::holds_alternative(member)) { - T* t = std::get(member); - if (t->identifier() == target_name) { - return t; - } - } - } - - return absl::NotFoundError( - absl::StrFormat("No %s in module %s with name \"%s\"", typeid(T).name(), - name_, target_name)); - } - - std::optional GetFunction(std::string_view target_name); - std::optional GetProc(std::string_view target_name); - - // Gets a test construct in this module with the given "target_name", or - // returns a NotFoundError. - absl::StatusOr GetTest(std::string_view target_name); - absl::StatusOr GetTestProc(std::string_view target_name); - - absl::Span top() const { return top_; } - - // Finds the first top-level member in top() with the given "target" name as - // an identifier. - std::optional FindMemberWithName(std::string_view target); - - const StructDef* FindStructDef(const Span& span) const; - - const EnumDef* FindEnumDef(const Span& span) const; - - // Obtains all the type definition nodes in the module; e.g. TypeAlias, - // StructDef, EnumDef. - absl::flat_hash_map GetTypeDefinitionByName() - const; - - // Obtains all the type definition nodes in the module in module-member order. - std::vector GetTypeDefinitions() const; - - absl::StatusOr GetTypeDefinition( - std::string_view name) const; - - // Retrieves a constant node from this module with the target name as its - // identifier, or a NotFound error if none can be found. - absl::StatusOr GetConstantDef(std::string_view target); - - absl::flat_hash_map GetImportByName() const { - return GetTopWithTByName(); - } - absl::flat_hash_map GetFunctionByName() const { - return GetTopWithTByName(); - } - std::vector GetQuickChecks() const { - return GetTopWithT(); - } - std::vector GetStructDefs() const { - return GetTopWithT(); - } - std::vector GetProcs() const { return GetTopWithT(); } - - // Returns the identifiers for all functions within this module (in the order - // in which they are defined). - std::vector GetFunctionNames() const; - - // Returns the identifiers for all tests within this module (in the order in - // which they are defined). - std::vector GetTestNames() const; - - const std::string& name() const { return name_; } - const std::optional& fs_path() const { - return fs_path_; - } - - // Finds a node with the given kind and /exactly/ the same span as "target". - const AstNode* FindNode(AstNodeKind kind, const Span& target) const; - - // Finds all the AST nodes in the module with spans that intercept the given - // "target" position. - std::vector FindIntercepting(const Pos& target) const; - - private: - template - T* MakeInternal(Args&&... args) { - std::unique_ptr node = - std::make_unique(this, std::forward(args)...); - T* ptr = node.get(); - ptr->SetParentage(); - nodes_.push_back(std::move(node)); - return ptr; - } - - // Returns all of the elements of top_ that have the given variant type T. - template - std::vector GetTopWithT() const { - std::vector result; - for (auto& member : top_) { - if (std::holds_alternative(member)) { - result.push_back(std::get(member)); - } - } - return result; - } - - // Returns all the elements of top_ that have the given variant type T, using - // T's identifier as a key. (T must have a string identifier.) - template - absl::flat_hash_map GetTopWithTByName() const { - absl::flat_hash_map result; - for (auto& member : top_) { - if (std::holds_alternative(member)) { - auto* c = std::get(member); - result.insert({c->identifier(), c}); - } - } - return result; - } - - std::string name_; // Name of this module. - - // Optional filesystem path (may not be present for e.g. DSLX files created in - // memory). - std::optional fs_path_; - - std::vector top_; // Top-level members of this module. - std::vector> nodes_; // Lifetime-owned AST nodes. - - // Map of top-level module member name to the member itself. - absl::flat_hash_map top_by_name_; - - // Builtin name definitions, which we common out on a per-module basis. Not - // for any particular purpose at this time aside from cleanliness of not - // having many definition nodes of the same builtin thing floating around. - absl::flat_hash_map builtin_name_defs_; -}; - // Helper for determining whether an AST node is constant (e.g. can be // considered a constant value in a ConstantArray). bool IsConstant(AstNode* n); -// Helper for making a ternary expression conditional. This avoids the user -// needing to hand-craft the block nodes and such. -inline Conditional* MakeTernary(Module* module, const Span& span, Expr* test, - Expr* consequent, Expr* alternate) { - return module->Make( - span, test, - module->Make( - consequent->span(), - std::vector{module->Make(consequent)}, false), - module->Make( - alternate->span(), - std::vector{module->Make(alternate)}, false)); -} - -// Collects all nodes under the given root. -absl::StatusOr> CollectUnder(AstNode* root, - bool want_types); - -absl::StatusOr> CollectUnder(const AstNode* root, - bool want_types); - } // namespace xls::dslx #endif // XLS_DSLX_FRONTEND_AST_H_ diff --git a/xls/dslx/frontend/ast_cloner.cc b/xls/dslx/frontend/ast_cloner.cc index a46fa66c70..802e19c98b 100644 --- a/xls/dslx/frontend/ast_cloner.cc +++ b/xls/dslx/frontend/ast_cloner.cc @@ -33,6 +33,8 @@ #include "xls/common/visitor.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/ast_utils.h" +#include "xls/dslx/frontend/module.h" +#include "xls/dslx/frontend/proc.h" namespace xls::dslx { namespace { diff --git a/xls/dslx/frontend/ast_cloner_test.cc b/xls/dslx/frontend/ast_cloner_test.cc index 8fc7d66f61..8d0c31c198 100644 --- a/xls/dslx/frontend/ast_cloner_test.cc +++ b/xls/dslx/frontend/ast_cloner_test.cc @@ -22,6 +22,7 @@ #include "xls/common/status/matchers.h" #include "xls/dslx/command_line_utils.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/parse_and_typecheck.h" namespace xls::dslx { diff --git a/xls/dslx/frontend/ast_test.cc b/xls/dslx/frontend/ast_test.cc index 93f4da3827..5f30e89341 100644 --- a/xls/dslx/frontend/ast_test.cc +++ b/xls/dslx/frontend/ast_test.cc @@ -24,6 +24,7 @@ #include "absl/status/status.h" #include "xls/common/status/matchers.h" #include "xls/dslx/frontend/ast_test_utils.h" +#include "xls/dslx/frontend/module.h" namespace xls::dslx { namespace { diff --git a/xls/dslx/frontend/ast_test_utils.cc b/xls/dslx/frontend/ast_test_utils.cc index ddf7691b23..17f4bcfa04 100644 --- a/xls/dslx/frontend/ast_test_utils.cc +++ b/xls/dslx/frontend/ast_test_utils.cc @@ -24,6 +24,7 @@ #include "absl/strings/str_format.h" #include "xls/common/logging/logging.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" namespace xls::dslx { diff --git a/xls/dslx/frontend/ast_utils.cc b/xls/dslx/frontend/ast_utils.cc index 810acfac24..727e7575eb 100644 --- a/xls/dslx/frontend/ast_utils.cc +++ b/xls/dslx/frontend/ast_utils.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -576,4 +577,43 @@ std::optional ExtractBitVectorMetadata( return std::nullopt; } +absl::StatusOr> CollectUnder(AstNode* root, + bool want_types) { + std::vector nodes; + + class CollectVisitor : public AstNodeVisitor { + public: + explicit CollectVisitor(std::vector& nodes) : nodes_(nodes) {} + +#define DECLARE_HANDLER(__type) \ + absl::Status Handle##__type(const __type* n) override { \ + nodes_.push_back(const_cast<__type*>(n)); \ + return absl::OkStatus(); \ + } + XLS_DSLX_AST_NODE_EACH(DECLARE_HANDLER) +#undef DECLARE_HANDLER + + private: + std::vector& nodes_; + } collect_visitor(nodes); + + XLS_RETURN_IF_ERROR(WalkPostOrder(root, &collect_visitor, want_types)); + return nodes; +} + +absl::StatusOr> CollectUnder(const AstNode* root, + bool want_types) { + // Implementation note: delegate to non-const version and turn result values + // back to const. + XLS_ASSIGN_OR_RETURN(std::vector got, + CollectUnder(const_cast(root), want_types)); + + std::vector result; + result.reserve(got.size()); + for (AstNode* n : got) { + result.push_back(n); + } + return result; +} + } // namespace xls::dslx diff --git a/xls/dslx/frontend/ast_utils.h b/xls/dslx/frontend/ast_utils.h index a011029f26..b63e307a42 100644 --- a/xls/dslx/frontend/ast_utils.h +++ b/xls/dslx/frontend/ast_utils.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -209,6 +210,13 @@ struct BitVectorMetadata { std::optional ExtractBitVectorMetadata( const TypeAnnotation* type_annotation); +// Collects all nodes under the given root. +absl::StatusOr> CollectUnder(AstNode* root, + bool want_types); + +absl::StatusOr> CollectUnder(const AstNode* root, + bool want_types); + } // namespace xls::dslx #endif // XLS_DSLX_FRONTEND_AST_UTILS_H_ diff --git a/xls/dslx/frontend/module.cc b/xls/dslx/frontend/module.cc new file mode 100644 index 0000000000..4680009c72 --- /dev/null +++ b/xls/dslx/frontend/module.cc @@ -0,0 +1,370 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/frontend/module.h" + +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/variant.h" +#include "xls/common/casts.h" +#include "xls/common/logging/logging.h" +#include "xls/common/status/ret_check.h" +#include "xls/common/visitor.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/pos.h" + +namespace xls::dslx { + +// -- class Module + +Module::Module(std::string name, std::optional fs_path) + : AstNode(this), name_(std::move(name)), fs_path_(std::move(fs_path)) { + XLS_VLOG(3) << "Created module \"" << name_ << "\" @ " << this; +} + +Module::~Module() { + XLS_VLOG(3) << "Destroying module \"" << name_ << "\" @ " << this; +} + +std::string Module::ToString() const { + // Don't print Proc functions, as they'll be printed as part of the procs + // themselves. + std::vector print_top; + for (const auto& member : top_) { + if (std::holds_alternative(member) && + std::get(member)->proc().has_value()) { + continue; + } + print_top.push_back(member); + } + return absl::StrJoin(print_top, "\n", + [](std::string* out, const ModuleMember& member) { + absl::StrAppend(out, ToAstNode(member)->ToString()); + }); +} + +const AstNode* Module::FindNode(AstNodeKind kind, const Span& target) const { + for (const auto& node : nodes_) { + if (node->kind() == kind && node->GetSpan().has_value() && + node->GetSpan().value() == target) { + return node.get(); + } + } + return nullptr; +} + +std::vector Module::FindIntercepting(const Pos& target) const { + std::vector found; + for (const auto& node : nodes_) { + if (node->GetSpan().has_value() && node->GetSpan()->Contains(target)) { + found.push_back(node.get()); + } + } + return found; +} + +std::optional Module::GetFunction(std::string_view target_name) { + for (ModuleMember& member : top_) { + if (std::holds_alternative(member)) { + Function* f = std::get(member); + if (f->identifier() == target_name) { + return f; + } + } + } + return std::nullopt; +} + +std::optional Module::GetProc(std::string_view target_name) { + for (ModuleMember& member : top_) { + if (std::holds_alternative(member)) { + Proc* p = std::get(member); + if (p->identifier() == target_name) { + return p; + } + } + } + return std::nullopt; +} + +absl::StatusOr Module::GetTest(std::string_view target_name) { + for (ModuleMember& member : top_) { + if (std::holds_alternative(member)) { + TestFunction* t = std::get(member); + if (t->identifier() == target_name) { + return t; + } + } + } + return absl::NotFoundError(absl::StrFormat( + "No test in module %s with name \"%s\"", name_, target_name)); +} + +absl::StatusOr Module::GetTestProc(std::string_view target_name) { + for (ModuleMember& member : top_) { + if (std::holds_alternative(member)) { + auto* t = std::get(member); + if (t->proc()->identifier() == target_name) { + return t; + } + } + } + return absl::NotFoundError(absl::StrFormat( + "No test proc in module %s with name \"%s\"", name_, target_name)); +} + +std::vector Module::GetTestNames() const { + std::vector result; + for (auto& member : top_) { + if (std::holds_alternative(member)) { + TestFunction* t = std::get(member); + result.push_back(t->identifier()); + } else if (std::holds_alternative(member)) { + TestProc* tp = std::get(member); + result.push_back(tp->proc()->identifier()); + } + } + return result; +} + +std::vector Module::GetFunctionNames() const { + std::vector result; + for (auto& member : top_) { + if (std::holds_alternative(member)) { + result.push_back(std::get(member)->identifier()); + } + } + return result; +} + +const StructDef* Module::FindStructDef(const Span& span) const { + return down_cast(FindNode(AstNodeKind::kStructDef, span)); +} + +const EnumDef* Module::FindEnumDef(const Span& span) const { + return down_cast(FindNode(AstNodeKind::kEnumDef, span)); +} + +std::optional Module::FindMemberWithName( + std::string_view target) { + for (ModuleMember& member : top_) { + if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->proc()->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + if (std::get(member)->identifier() == target) { + return &member; + } + } else if (std::holds_alternative(member)) { + continue; // These have no name / binding. + } else { + XLS_LOG(FATAL) << "Unhandled module member variant: " + << ToAstNode(member)->GetNodeTypeName(); + } + } + return std::nullopt; +} + +absl::StatusOr Module::GetConstantDef(std::string_view target) { + std::optional member = FindMemberWithName(target); + if (!member.has_value()) { + return absl::NotFoundError( + absl::StrFormat("Could not find member named '%s' in module.", target)); + } + if (!std::holds_alternative(*member.value())) { + return absl::NotFoundError(absl::StrFormat( + "Member named '%s' in module was not a constant.", target)); + } + return std::get(*member.value()); +} + +absl::flat_hash_map +Module::GetTypeDefinitionByName() const { + absl::flat_hash_map result; + for (auto& member : top_) { + if (std::holds_alternative(member)) { + TypeAlias* td = std::get(member); + result[td->identifier()] = td; + } else if (std::holds_alternative(member)) { + EnumDef* enum_ = std::get(member); + result[enum_->identifier()] = enum_; + } else if (std::holds_alternative(member)) { + StructDef* struct_ = std::get(member); + result[struct_->identifier()] = struct_; + } + } + return result; +} + +std::vector Module::GetTypeDefinitions() const { + std::vector results; + for (auto& member : top_) { + if (std::holds_alternative(member)) { + TypeAlias* td = std::get(member); + results.push_back(td); + } else if (std::holds_alternative(member)) { + EnumDef* enum_def = std::get(member); + results.push_back(enum_def); + } else if (std::holds_alternative(member)) { + StructDef* struct_def = std::get(member); + results.push_back(struct_def); + } + } + return results; +} + +std::vector Module::GetChildren(bool want_types) const { + std::vector results; + results.reserve(top_.size()); + for (ModuleMember member : top_) { + results.push_back(ToAstNode(member)); + } + return results; +} + +absl::StatusOr Module::GetTypeDefinition( + std::string_view name) const { + absl::flat_hash_map map = + GetTypeDefinitionByName(); + auto it = map.find(name); + if (it == map.end()) { + return absl::NotFoundError( + absl::StrCat("Could not find type definition for name: ", name)); + } + return it->second; +} + +absl::Status Module::AddTop(ModuleMember member, + const MakeCollisionError& make_collision_error) { + // Get name + std::optional member_name = absl::visit( + Visitor{ + [](Function* f) { return std::make_optional(f->identifier()); }, + [](Proc* p) { return std::make_optional(p->identifier()); }, + [](TestFunction* tf) { return std::make_optional(tf->identifier()); }, + [](TestProc* tp) { + return std::make_optional(tp->proc()->identifier()); + }, + [](QuickCheck* qc) { return std::make_optional(qc->identifier()); }, + [](TypeAlias* td) { return std::make_optional(td->identifier()); }, + [](StructDef* sd) { return std::make_optional(sd->identifier()); }, + [](ConstantDef* cd) { return std::make_optional(cd->identifier()); }, + [](EnumDef* ed) { return std::make_optional(ed->identifier()); }, + [](Import* i) { return std::make_optional(i->identifier()); }, + [](ConstAssert* n) -> std::optional { + return std::nullopt; + }, + }, + member); + + if (member_name.has_value() && top_by_name_.contains(member_name.value())) { + const AstNode* node = ToAstNode(top_by_name_.at(member_name.value())); + const Span existing_span = node->GetSpan().value(); + const AstNode* new_node = ToAstNode(member); + const Span new_span = new_node->GetSpan().value(); + if (make_collision_error != nullptr) { + return make_collision_error(name_, member_name.value(), existing_span, + node, new_span, new_node); + } + return absl::InvalidArgumentError(absl::StrFormat( + "Module %s already contains a member named %s @ %s: %s", name_, + member_name.value(), existing_span.ToString(), node->ToString())); + } + + top_.push_back(member); + if (member_name.has_value()) { + top_by_name_.insert({member_name.value(), member}); + } + return absl::OkStatus(); +} + +std::string_view GetModuleMemberTypeName(const ModuleMember& module_member) { + return absl::visit(Visitor{ + [](Function*) { return "function"; }, + [](Proc*) { return "proc"; }, + [](TestFunction*) { return "test-function"; }, + [](TestProc*) { return "test-proc"; }, + [](QuickCheck*) { return "quick-check"; }, + [](TypeAlias*) { return "type-alias"; }, + [](StructDef*) { return "struct-definition"; }, + [](ConstantDef*) { return "constant-definition"; }, + [](EnumDef*) { return "enum-definition"; }, + [](Import*) { return "import"; }, + [](ConstAssert*) { return "const-assert"; }, + }, + module_member); +} + +absl::StatusOr AsModuleMember(AstNode* node) { + XLS_RET_CHECK(node != nullptr); + // clang-format off + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + if (auto* n = dynamic_cast(node)) { return ModuleMember(n); } + // clang-format on + return absl::InvalidArgumentError("AST node is not a module-level member: " + + node->ToString()); +} + +} // namespace xls::dslx diff --git a/xls/dslx/frontend/module.h b/xls/dslx/frontend/module.h new file mode 100644 index 0000000000..08442a4598 --- /dev/null +++ b/xls/dslx/frontend/module.h @@ -0,0 +1,268 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_DSLX_FRONTEND_MODULE_H_ +#define XLS_DSLX_FRONTEND_MODULE_H_ + +#include // NOLINT +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/pos.h" +#include "xls/dslx/frontend/proc.h" + +namespace xls::dslx { + +using ModuleMember = + std::variant; + +std::string_view GetModuleMemberTypeName(const ModuleMember& module_member); + +absl::StatusOr AsModuleMember(AstNode* node); + +// Represents a syntactic module in the AST. +// +// Modules contain top-level definitions such as functions and tests. +// +// Attributes: +// name: Name of this module. +// top: Top-level module constructs; e.g. functions, tests. Given as a +// sequence instead of a mapping in case there are unnamed constructs at the +// module level (e.g. metadata, docstrings). +// fs_path: Name of the filesystem path that led to this module's AST -- if +// the AST was constructed in-memory this value will be nullopt. Generally +// this was relative to the main binary's $CWD (which is often a place like +// Bazel's execution root) -- this helps output be deterministic even when +// running distributed compilation. +class Module : public AstNode { + public: + Module(std::string name, std::optional fs_path); + + ~Module() override; + + Module(Module&& other) = default; + Module& operator=(Module&& other) = default; + + AstNodeKind kind() const override { return AstNodeKind::kModule; } + + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleModule(this); + } + std::optional GetSpan() const override { return std::nullopt; } + + std::string_view GetNodeTypeName() const override { return "Module"; } + std::vector GetChildren(bool want_types) const override; + + std::string ToString() const override; + + template + T* Make(Args&&... args) { + static_assert(!std::is_same::value, + "Use Module::GetOrCreateBuiltinNameDef()"); + return MakeInternal(std::forward(args)...); + } + + BuiltinNameDef* GetOrCreateBuiltinNameDef(std::string_view name) { + auto it = builtin_name_defs_.find(name); + if (it == builtin_name_defs_.end()) { + BuiltinNameDef* bnd = MakeInternal(std::string(name)); + builtin_name_defs_.emplace_hint(it, std::string(name), bnd); + return bnd; + } + return it->second; + } + + using MakeCollisionError = std::function; + + // Adds a top level "member" to the module. Invokes make_collision_error if + // there is a naming collision at module scope -- this is done so that errors + // can be layered appropriately and injected in from outside code (e.g. the + // parser). If nullptr is given, then a non-positional InvalidArgumentError is + // raised. + absl::Status AddTop(ModuleMember member, + const MakeCollisionError& make_collision_error); + + // Gets the element in this module with the given target_name, or returns a + // NotFoundError. + template + absl::StatusOr GetMemberOrError(std::string_view target_name) { + for (ModuleMember& member : top_) { + if (std::holds_alternative(member)) { + T* t = std::get(member); + if (t->identifier() == target_name) { + return t; + } + } + } + + return absl::NotFoundError( + absl::StrFormat("No %s in module %s with name \"%s\"", typeid(T).name(), + name_, target_name)); + } + + std::optional GetFunction(std::string_view target_name); + std::optional GetProc(std::string_view target_name); + + // Gets a test construct in this module with the given "target_name", or + // returns a NotFoundError. + absl::StatusOr GetTest(std::string_view target_name); + absl::StatusOr GetTestProc(std::string_view target_name); + + absl::Span top() const { return top_; } + + // Finds the first top-level member in top() with the given "target" name as + // an identifier. + std::optional FindMemberWithName(std::string_view target); + + const StructDef* FindStructDef(const Span& span) const; + + const EnumDef* FindEnumDef(const Span& span) const; + + // Obtains all the type definition nodes in the module; e.g. TypeAlias, + // StructDef, EnumDef. + absl::flat_hash_map GetTypeDefinitionByName() + const; + + // Obtains all the type definition nodes in the module in module-member order. + std::vector GetTypeDefinitions() const; + + absl::StatusOr GetTypeDefinition(std::string_view name) const; + + // Retrieves a constant node from this module with the target name as its + // identifier, or a NotFound error if none can be found. + absl::StatusOr GetConstantDef(std::string_view target); + + absl::flat_hash_map GetImportByName() const { + return GetTopWithTByName(); + } + absl::flat_hash_map GetFunctionByName() const { + return GetTopWithTByName(); + } + std::vector GetQuickChecks() const { + return GetTopWithT(); + } + std::vector GetStructDefs() const { + return GetTopWithT(); + } + std::vector GetProcs() const { return GetTopWithT(); } + + // Returns the identifiers for all functions within this module (in the order + // in which they are defined). + std::vector GetFunctionNames() const; + + // Returns the identifiers for all tests within this module (in the order in + // which they are defined). + std::vector GetTestNames() const; + + const std::string& name() const { return name_; } + const std::optional& fs_path() const { + return fs_path_; + } + + // Finds a node with the given kind and /exactly/ the same span as "target". + const AstNode* FindNode(AstNodeKind kind, const Span& target) const; + + // Finds all the AST nodes in the module with spans that intercept the given + // "target" position. + std::vector FindIntercepting(const Pos& target) const; + + private: + template + T* MakeInternal(Args&&... args) { + std::unique_ptr node = + std::make_unique(this, std::forward(args)...); + T* ptr = node.get(); + ptr->SetParentage(); + nodes_.push_back(std::move(node)); + return ptr; + } + + // Returns all of the elements of top_ that have the given variant type T. + template + std::vector GetTopWithT() const { + std::vector result; + for (auto& member : top_) { + if (std::holds_alternative(member)) { + result.push_back(std::get(member)); + } + } + return result; + } + + // Returns all the elements of top_ that have the given variant type T, using + // T's identifier as a key. (T must have a string identifier.) + template + absl::flat_hash_map GetTopWithTByName() const { + absl::flat_hash_map result; + for (auto& member : top_) { + if (std::holds_alternative(member)) { + auto* c = std::get(member); + result.insert({c->identifier(), c}); + } + } + return result; + } + + std::string name_; // Name of this module. + + // Optional filesystem path (may not be present e.g. for DSLX files created in + // memory). + std::optional fs_path_; + + std::vector top_; // Top-level members of this module. + std::vector> nodes_; // Lifetime-owned AST nodes. + + // Map of top-level module member name to the member itself. + absl::flat_hash_map top_by_name_; + + // Builtin name definitions, which we common out on a per-module basis. Not + // for any particular purpose at this time aside from cleanliness of not + // having many definition nodes of the same builtin thing floating around. + absl::flat_hash_map builtin_name_defs_; +}; + +// Helper for making a ternary expression conditional. This avoids the user +// needing to hand-craft the block nodes and such. +inline Conditional* MakeTernary(Module* module, const Span& span, Expr* test, + Expr* consequent, Expr* alternate) { + return module->Make( + span, test, + module->Make( + consequent->span(), + std::vector{module->Make(consequent)}, false), + module->Make( + alternate->span(), + std::vector{module->Make(alternate)}, false)); +} + +} // namespace xls::dslx + +#endif // XLS_DSLX_FRONTEND_MODULE_H_ diff --git a/xls/dslx/frontend/parser.h b/xls/dslx/frontend/parser.h index 541ea83d30..66ce605196 100644 --- a/xls/dslx/frontend/parser.h +++ b/xls/dslx/frontend/parser.h @@ -37,6 +37,7 @@ #include "xls/common/strong_int.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/bindings.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/frontend/scanner.h" #include "xls/dslx/frontend/token.h" diff --git a/xls/dslx/frontend/proc.cc b/xls/dslx/frontend/proc.cc new file mode 100644 index 0000000000..51bd79e2f2 --- /dev/null +++ b/xls/dslx/frontend/proc.cc @@ -0,0 +1,141 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/frontend/proc.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "xls/common/indent.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/pos.h" + +namespace xls::dslx { + +// -- class Proc + +Proc::Proc(Module* owner, Span span, NameDef* name_def, + NameDef* config_name_def, NameDef* next_name_def, + const std::vector& parametric_bindings, + std::vector members, Function* config, Function* next, + Function* init, bool is_public) + : AstNode(owner), + span_(std::move(span)), + name_def_(name_def), + config_name_def_(config_name_def), + next_name_def_(next_name_def), + parametric_bindings_(parametric_bindings), + config_(config), + next_(next), + init_(init), + members_(std::move(members)), + is_public_(is_public) {} + +Proc::~Proc() = default; + +std::vector Proc::GetChildren(bool want_types) const { + std::vector results = {name_def()}; + for (ParametricBinding* pb : parametric_bindings_) { + results.push_back(pb); + } + for (ProcMember* p : members_) { + results.push_back(p); + } + results.push_back(config_); + results.push_back(next_); + results.push_back(init_); + return results; +} + +std::string Proc::ToString() const { + std::string pub_str = is_public() ? "pub " : ""; + std::string parametric_str; + if (!parametric_bindings().empty()) { + parametric_str = absl::StrFormat( + "<%s>", + absl::StrJoin( + parametric_bindings(), ", ", + [](std::string* out, ParametricBinding* parametric_binding) { + absl::StrAppend(out, parametric_binding->ToString()); + })); + } + auto param_append = [](std::string* out, const Param* p) { + out->append(absl::StrCat(p->ToString(), ";")); + }; + auto member_append = [](std::string* out, const ProcMember* member) { + out->append(absl::StrCat(member->ToString(), ";")); + }; + std::string config_params_str = + absl::StrJoin(config_->params(), ", ", param_append); + std::string state_params_str = + absl::StrJoin(next_->params(), ", ", param_append); + std::string members_str = absl::StrJoin(members_, "\n", member_append); + if (!members_str.empty()) { + members_str.append("\n"); + } + + // Init functions are special, since they shouldn't be printed with + // parentheses (since they can't take args). + std::string init_str = Indent( + absl::StrCat("init ", init_->body()->ToString()), kRustSpacesPerIndent); + + constexpr std::string_view kTemplate = R"(%sproc %s%s { +%s%s +%s +%s +})"; + return absl::StrFormat( + kTemplate, pub_str, name_def()->identifier(), parametric_str, + Indent(members_str, kRustSpacesPerIndent), + Indent(config_->ToUndecoratedString("config"), kRustSpacesPerIndent), + init_str, + Indent(next_->ToUndecoratedString("next"), kRustSpacesPerIndent)); +} + +std::vector Proc::GetFreeParametricKeys() const { + // TODO(rspringer): 2021-09-29: Mutants found holes in test coverage here. + std::vector results; + for (ParametricBinding* b : parametric_bindings_) { + if (b->expr() == nullptr) { + results.push_back(b->name_def()->identifier()); + } + } + return results; +} + +// -- class TestProc + +TestProc::~TestProc() = default; + +std::string TestProc::ToString() const { + return absl::StrFormat("#[test_proc]\n%s", proc_->ToString()); +} + +// -- class ProcMember + +ProcMember::ProcMember(Module* owner, NameDef* name_def, + TypeAnnotation* type_annotation) + : AstNode(owner), + name_def_(name_def), + type_annotation_(type_annotation), + span_(name_def_->span().start(), type_annotation_->span().limit()) {} + +ProcMember::~ProcMember() = default; + +} // namespace xls::dslx diff --git a/xls/dslx/frontend/proc.h b/xls/dslx/frontend/proc.h new file mode 100644 index 0000000000..9ad9ad67c5 --- /dev/null +++ b/xls/dslx/frontend/proc.h @@ -0,0 +1,165 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_DSLX_FRONTEND_PROC_H_ +#define XLS_DSLX_FRONTEND_PROC_H_ + +#include +#include +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/pos.h" + +namespace xls::dslx { + +// A member held in a proc, e.g. a channel declaration initialized by a +// configuration block. +// +// This is very similar to a `Param` at the moment, but we make them distinct +// types for structural clarity in the AST. Params are really "parameters to +// functions". +class ProcMember : public AstNode { + public: + ProcMember(Module* owner, NameDef* name_def, TypeAnnotation* type); + + ~ProcMember() override; + + AstNodeKind kind() const override { return AstNodeKind::kProcMember; } + + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleProcMember(this); + } + + std::string_view GetNodeTypeName() const override { return "ProcMember"; } + std::string ToString() const override { + return absl::StrFormat("%s: %s", name_def_->ToString(), + type_annotation_->ToString()); + } + + std::vector GetChildren(bool want_types) const override { + return {name_def_, type_annotation_}; + } + + const Span& span() const { return span_; } + NameDef* name_def() const { return name_def_; } + TypeAnnotation* type_annotation() const { return type_annotation_; } + const std::string& identifier() const { return name_def_->identifier(); } + std::optional GetSpan() const override { return span_; } + + private: + NameDef* name_def_; + TypeAnnotation* type_annotation_; + Span span_; +}; + +// Represents a parsed 'process' specification in the DSL. +class Proc : public AstNode { + public: + static std::string_view GetDebugTypeName() { return "proc"; } + + Proc(Module* owner, Span span, NameDef* name_def, NameDef* config_name_def, + NameDef* next_name_def, + const std::vector& parametric_bindings, + std::vector members, Function* config, Function* next, + Function* init, bool is_public); + + ~Proc() override; + + AstNodeKind kind() const override { return AstNodeKind::kProc; } + + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleProc(this); + } + std::string_view GetNodeTypeName() const override { return "Proc"; } + std::string ToString() const override; + std::vector GetChildren(bool want_types) const override; + + NameDef* name_def() const { return name_def_; } + NameDef* config_name_def() const { return config_name_def_; } + NameDef* next_name_def() const { return next_name_def_; } + const Span& span() const { return span_; } + std::optional GetSpan() const override { return span_; } + + const std::string& identifier() const { return name_def_->identifier(); } + const std::vector& parametric_bindings() const { + return parametric_bindings_; + } + bool IsParametric() const { return !parametric_bindings_.empty(); } + bool is_public() const { return is_public_; } + + std::vector GetFreeParametricKeys() const; + absl::btree_set GetFreeParametricKeySet() const { + std::vector keys = GetFreeParametricKeys(); + return absl::btree_set(keys.begin(), keys.end()); + } + + Function* config() const { return config_; } + Function* next() const { return next_; } + Function* init() const { return init_; } + const std::vector& members() const { return members_; } + + private: + Span span_; + NameDef* name_def_; + NameDef* config_name_def_; + NameDef* next_name_def_; + std::vector parametric_bindings_; + + Function* config_; + Function* next_; + Function* init_; + std::vector members_; + bool is_public_; +}; + +// Represents a construct to unit test a Proc. Analogous to TestFunction, but +// for Procs. +// +// These are specified with an annotation as follows: +// ```dslx +// #[test_proc()] +// proc test_proc { ... } +// ``` +class TestProc : public AstNode { + public: + TestProc(Module* owner, Proc* proc) : AstNode(owner), proc_(proc) {} + ~TestProc() override; + + AstNodeKind kind() const override { return AstNodeKind::kTestProc; } + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleTestProc(this); + } + std::vector GetChildren(bool want_types) const override { + return {proc_}; + } + std::string_view GetNodeTypeName() const override { return "TestProc"; } + std::string ToString() const override; + + Proc* proc() const { return proc_; } + std::optional GetSpan() const override { return proc_->span(); } + + const std::string& identifier() const { return proc_->identifier(); } + + private: + Proc* proc_; +}; + +} // namespace xls::dslx + +#endif // XLS_DSLX_FRONTEND_PROC_H_ diff --git a/xls/dslx/import_data.h b/xls/dslx/import_data.h index b8c00b8a56..58cf678bdd 100644 --- a/xls/dslx/import_data.h +++ b/xls/dslx/import_data.h @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "xls/dslx/bytecode/bytecode_cache_interface.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/import_record.h" #include "xls/dslx/interp_bindings.h" #include "xls/dslx/type_system/type_info.h" diff --git a/xls/dslx/interp_value_helpers_test.cc b/xls/dslx/interp_value_helpers_test.cc index 30667ae710..212de5bdae 100644 --- a/xls/dslx/interp_value_helpers_test.cc +++ b/xls/dslx/interp_value_helpers_test.cc @@ -26,6 +26,7 @@ #include "absl/strings/str_cat.h" #include "xls/common/status/matchers.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/type_system/concrete_type.h" diff --git a/xls/dslx/ir_convert/BUILD b/xls/dslx/ir_convert/BUILD index 62b0699d91..7dad3d585c 100644 --- a/xls/dslx/ir_convert/BUILD +++ b/xls/dslx/ir_convert/BUILD @@ -269,9 +269,13 @@ cc_library( "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "//xls/common:visitor", + "//xls/common/logging", "//xls/common/status:ret_check", + "//xls/common/status:status_macros", "//xls/dslx/frontend:ast", "//xls/dslx/frontend:builtins_metadata", + "//xls/dslx/frontend:module", + "//xls/dslx/frontend:proc", "//xls/dslx/type_system:parametric_env", "//xls/dslx/type_system:type_info", ], diff --git a/xls/dslx/ir_convert/extract_conversion_order.cc b/xls/dslx/ir_convert/extract_conversion_order.cc index 0b67ede68b..b475db06b8 100644 --- a/xls/dslx/ir_convert/extract_conversion_order.cc +++ b/xls/dslx/ir_convert/extract_conversion_order.cc @@ -31,10 +31,13 @@ #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "absl/types/variant.h" +#include "xls/common/logging/logging.h" #include "xls/common/status/ret_check.h" +#include "xls/common/status/status_macros.h" #include "xls/common/visitor.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/builtins_metadata.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/type_system/parametric_env.h" namespace xls::dslx { diff --git a/xls/dslx/ir_convert/extract_conversion_order.h b/xls/dslx/ir_convert/extract_conversion_order.h index 6e8fd14bc6..485127b6e8 100644 --- a/xls/dslx/ir_convert/extract_conversion_order.h +++ b/xls/dslx/ir_convert/extract_conversion_order.h @@ -26,6 +26,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/proc.h" #include "xls/dslx/type_system/parametric_env.h" #include "xls/dslx/type_system/type_info.h" diff --git a/xls/dslx/lsp/BUILD b/xls/dslx/lsp/BUILD index ef56240fb9..d93ed8a7c2 100644 --- a/xls/dslx/lsp/BUILD +++ b/xls/dslx/lsp/BUILD @@ -32,7 +32,9 @@ cc_library( srcs = ["find_definition.cc"], hdrs = ["find_definition.h"], deps = [ + "//xls/common/logging", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", ], ) @@ -99,6 +101,7 @@ cc_library( "@com_google_absl//absl/types:variant", "//xls/common:visitor", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", "@verible//common/lsp:lsp-protocol", "@verible//common/lsp:lsp-protocol-enums", ], diff --git a/xls/dslx/lsp/document_symbols.h b/xls/dslx/lsp/document_symbols.h index 452f686c2e..2f091af405 100644 --- a/xls/dslx/lsp/document_symbols.h +++ b/xls/dslx/lsp/document_symbols.h @@ -19,6 +19,7 @@ #include "external/verible/common/lsp/lsp-protocol.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" namespace xls::dslx { diff --git a/xls/dslx/lsp/find_definition.cc b/xls/dslx/lsp/find_definition.cc index 8ac5f30554..deec45946a 100644 --- a/xls/dslx/lsp/find_definition.cc +++ b/xls/dslx/lsp/find_definition.cc @@ -18,6 +18,10 @@ #include #include +#include "xls/common/logging/logging.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/pos.h" + namespace xls::dslx { std::optional FindDefinition(const Module& m, const Pos& selected) { diff --git a/xls/dslx/lsp/find_definition.h b/xls/dslx/lsp/find_definition.h index 2ecd4b5205..e6718c49a6 100644 --- a/xls/dslx/lsp/find_definition.h +++ b/xls/dslx/lsp/find_definition.h @@ -18,6 +18,7 @@ #include #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" namespace xls::dslx { diff --git a/xls/dslx/run_comparator.cc b/xls/dslx/run_comparator.cc index 329a3d7be4..39310c11a7 100644 --- a/xls/dslx/run_comparator.cc +++ b/xls/dslx/run_comparator.cc @@ -14,11 +14,26 @@ #include "xls/dslx/run_comparator.h" +#include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xls/common/logging/logging.h" +#include "xls/common/status/ret_check.h" +#include "xls/common/status/status_macros.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" +#include "xls/dslx/interp_value.h" #include "xls/dslx/mangle.h" +#include "xls/dslx/type_system/parametric_env.h" #include "xls/interpreter/function_interpreter.h" +#include "xls/ir/function.h" +#include "xls/ir/value.h" +#include "xls/jit/function_jit.h" namespace xls::dslx { diff --git a/xls/dslx/type_system/BUILD b/xls/dslx/type_system/BUILD index b33dfce03a..d7c91d39f1 100644 --- a/xls/dslx/type_system/BUILD +++ b/xls/dslx/type_system/BUILD @@ -53,6 +53,8 @@ cc_test( "//xls/common:xls_gunit", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", + "//xls/dslx/frontend:module", + "//xls/dslx/frontend:pos", ], ) @@ -248,7 +250,9 @@ cc_library( "//xls/dslx/frontend:ast", "//xls/dslx/frontend:ast_utils", "//xls/dslx/frontend:builtins_metadata", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", + "//xls/dslx/frontend:proc", "@com_googlesource_code_re2//:re2", ], ) @@ -433,6 +437,7 @@ cc_library( "//xls/common/status:ret_check", "//xls/dslx:interp_value", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", ], ) @@ -445,6 +450,7 @@ cc_test( "//xls/common:xls_gunit", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", + "//xls/dslx/frontend:module", ], ) diff --git a/xls/dslx/type_system/concrete_type_test.cc b/xls/dslx/type_system/concrete_type_test.cc index 326dddc230..241b2af2a7 100644 --- a/xls/dslx/type_system/concrete_type_test.cc +++ b/xls/dslx/type_system/concrete_type_test.cc @@ -27,6 +27,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xls/common/status/matchers.h" +#include "xls/dslx/frontend/module.h" +#include "xls/dslx/frontend/pos.h" namespace xls::dslx { namespace { diff --git a/xls/dslx/type_system/type_info.cc b/xls/dslx/type_system/type_info.cc index c35b0bfba3..499884bd52 100644 --- a/xls/dslx/type_system/type_info.cc +++ b/xls/dslx/type_system/type_info.cc @@ -31,6 +31,7 @@ #include "xls/common/logging/logging.h" #include "xls/common/status/ret_check.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/interp_value.h" namespace xls::dslx { diff --git a/xls/dslx/type_system/type_info_test.cc b/xls/dslx/type_system/type_info_test.cc index 845950a420..0c65aafe98 100644 --- a/xls/dslx/type_system/type_info_test.cc +++ b/xls/dslx/type_system/type_info_test.cc @@ -19,6 +19,7 @@ #include "gtest/gtest.h" #include "absl/status/statusor.h" #include "xls/common/status/matchers.h" +#include "xls/dslx/frontend/module.h" namespace xls::dslx { namespace { diff --git a/xls/dslx/type_system/typecheck.cc b/xls/dslx/type_system/typecheck.cc index 2a4824a9a5..13a6f74e66 100644 --- a/xls/dslx/type_system/typecheck.cc +++ b/xls/dslx/type_system/typecheck.cc @@ -48,7 +48,9 @@ #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/ast_utils.h" #include "xls/dslx/frontend/builtins_metadata.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" +#include "xls/dslx/frontend/proc.h" #include "xls/dslx/import_data.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/type_system/concrete_type.h" diff --git a/xls/fuzzer/BUILD b/xls/fuzzer/BUILD index 95020318c8..668870783b 100644 --- a/xls/fuzzer/BUILD +++ b/xls/fuzzer/BUILD @@ -644,6 +644,7 @@ cc_library( "//xls/dslx/frontend:ast", "//xls/dslx/frontend:ast_cloner", "//xls/dslx/frontend:ast_utils", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/frontend:token", "//xls/ir:bits", @@ -831,6 +832,7 @@ cc_library( "//xls/data_structures:inline_bitmap", "//xls/dslx:interp_value", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/type_system:concrete_type", "//xls/ir:bits", @@ -855,6 +857,7 @@ cc_test( "//xls/common/status:matchers", "//xls/dslx:interp_value", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/type_system:concrete_type", ], diff --git a/xls/fuzzer/ast_generator.h b/xls/fuzzer/ast_generator.h index 6338307726..6da6c1e433 100644 --- a/xls/fuzzer/ast_generator.h +++ b/xls/fuzzer/ast_generator.h @@ -22,7 +22,6 @@ #include #include #include -#include #include #include @@ -36,9 +35,9 @@ #include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xls/common/logging/logging.h" #include "xls/common/test_macros.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/frontend/token.h" #include "xls/fuzzer/ast_generator_options.pb.h" diff --git a/xls/fuzzer/value_generator.cc b/xls/fuzzer/value_generator.cc index 5f62e058e8..ee48188f8c 100644 --- a/xls/fuzzer/value_generator.cc +++ b/xls/fuzzer/value_generator.cc @@ -34,6 +34,7 @@ #include "xls/common/visitor.h" #include "xls/data_structures/inline_bitmap.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/type_system/concrete_type.h" diff --git a/xls/fuzzer/value_generator_test.cc b/xls/fuzzer/value_generator_test.cc index d73ae85fd0..9a8c40d23e 100644 --- a/xls/fuzzer/value_generator_test.cc +++ b/xls/fuzzer/value_generator_test.cc @@ -23,6 +23,7 @@ #include "gtest/gtest.h" #include "xls/common/status/matchers.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/type_system/concrete_type.h" diff --git a/xls/tools/BUILD b/xls/tools/BUILD index 2858a5aba8..ce6fc4e275 100644 --- a/xls/tools/BUILD +++ b/xls/tools/BUILD @@ -323,6 +323,7 @@ cc_binary( "//xls/common:init_xls", "//xls/common/file:filesystem", "//xls/common/logging", + "//xls/dslx/frontend:module", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", @@ -341,6 +342,7 @@ cc_test( "//xls/common/file:temp_file", "//xls/common/status:matchers", "//xls/common/status:status_macros", + "//xls/dslx/frontend:module", "@com_google_absl//absl/status:statusor", ], ) diff --git a/xls/tools/proto_to_dslx_main.cc b/xls/tools/proto_to_dslx_main.cc index 3337488890..ce6037ca19 100644 --- a/xls/tools/proto_to_dslx_main.cc +++ b/xls/tools/proto_to_dslx_main.cc @@ -23,6 +23,7 @@ #include "xls/common/file/filesystem.h" #include "xls/common/init_xls.h" #include "xls/common/logging/logging.h" +#include "xls/dslx/frontend/module.h" #include "xls/tools/proto_to_dslx.h" ABSL_FLAG(std::string, proto_def_path, "", diff --git a/xls/tools/proto_to_dslx_test.cc b/xls/tools/proto_to_dslx_test.cc index b98004de0b..85895fbfa1 100644 --- a/xls/tools/proto_to_dslx_test.cc +++ b/xls/tools/proto_to_dslx_test.cc @@ -25,6 +25,7 @@ #include "xls/common/file/temp_file.h" #include "xls/common/status/matchers.h" #include "xls/common/status/status_macros.h" +#include "xls/dslx/frontend/module.h" namespace xls { namespace {