From d8a8ca0cb6ac4ea1436d44450d4856be10395f45 Mon Sep 17 00:00:00 2001 From: Alex Light Date: Fri, 4 Oct 2024 16:57:12 -0700 Subject: [PATCH] Multi-block eval_proc_main support PiperOrigin-RevId: 682496973 --- xls/examples/dslx_module/BUILD | 42 ++++++++ xls/ir/block.cc | 9 ++ xls/ir/block.h | 2 + xls/ir/function_builder_test.cc | 11 +++ xls/tools/BUILD | 6 ++ xls/tools/eval_proc_main.cc | 67 +++++++++++-- xls/tools/eval_proc_main_test.py | 164 +++++++++++++++++++++++++++++++ 7 files changed, 291 insertions(+), 10 deletions(-) diff --git a/xls/examples/dslx_module/BUILD b/xls/examples/dslx_module/BUILD index 2a7e3be706..39c2dd5db8 100644 --- a/xls/examples/dslx_module/BUILD +++ b/xls/examples/dslx_module/BUILD @@ -22,6 +22,7 @@ load( "xls_dslx_library", "xls_dslx_opt_ir", "xls_dslx_test", + "xls_ir_verilog", ) package( @@ -63,6 +64,12 @@ xls_dslx_library( xls_dslx_ir( name = "some_caps_streaming_configured_ir", dslx_top = "some_caps_specialized", + ir_conv_args = { + # Set fifo config for multi-proc codegen. + # It needs to know what configuration to use for the fifo. + "default_fifo_config": "depth: 4, bypass: true, " + + "register_push_outputs: false, register_pop_outputs: false", + }, ir_file = "some_caps_streaming_configured.ir", library = ":some_caps_streaming_configured", ) @@ -83,11 +90,46 @@ xls_dslx_opt_ir( # Note: The optimized ir has different top since the channels are used to determine proc # liveness instead of spawn tree. This is done to avoid having to deal with mangled names. dslx_top = "manual_chan_caps_specialized", + ir_conv_args = { + # Set fifo config for multi-proc codegen. + # It needs to know what configuration to use for the fifo. + "default_fifo_config": "depth: 4, bypass: true, " + + "register_push_outputs: false, register_pop_outputs: false", + }, ir_file = "manual_chan_caps_streaming_configured.ir", library = ":some_caps_streaming_configured", visibility = ["//xls:xls_internal"], ) +xls_ir_verilog( + name = "manual_chan_caps_streaming_configured_multiproc_verilog", + src = ":manaul_chan_caps_streaming_configured_opt_ir", + block_ir_file = "manual_chan_caps_streaming_configured_multiproc.block.ir", + codegen_args = { + "module_name": "manual_chan_caps_streaming", + "generator": "pipeline", + "pipeline_stages": "4", + "delay_model": "unit", + "reset": "rst", + "reset_data_path": "false", + "reset_active_low": "false", + "reset_asynchronous": "false", + "flop_inputs": "false", + "flop_single_value_channels": "false", + "flop_outputs": "false", + "add_idle_output": "false", + "streaming_channel_data_suffix": "_data", + "streaming_channel_ready_suffix": "_ready", + "streaming_channel_valid_suffix": "_valid", + "use_system_verilog": "true", + "assert_format": "\\;", + "multi_proc": "true", + }, + module_sig_file = "manual_chan_caps_streaming_configured_multiproc.sig.textproto", + verilog_file = "manual_chan_caps_streaming_configured_multiproc.sv", + visibility = ["//xls:xls_internal"], +) + cc_xls_ir_jit_wrapper( name = "some_caps_opt_jit_wrapper", src = ":manaul_chan_caps_streaming_configured_opt_ir", diff --git a/xls/ir/block.cc b/xls/ir/block.cc index 0b801d40e9..91f594261f 100644 --- a/xls/ir/block.cc +++ b/xls/ir/block.cc @@ -284,6 +284,15 @@ absl::Status Block::SetPortNameExactly(std::string_view name, Node* node) { return absl::OkStatus(); } +bool Block::HasInputPort(std::string_view name) const { + return ports_by_name_.contains(name) && + std::holds_alternative(ports_by_name_.at(name)); +} +bool Block::HasOutputPort(std::string_view name) const { + return ports_by_name_.contains(name) && + std::holds_alternative(ports_by_name_.at(name)); +} + absl::StatusOr Block::GetInputPort(std::string_view name) const { auto port_iter = ports_by_name_.find(name); if (port_iter == ports_by_name_.end()) { diff --git a/xls/ir/block.h b/xls/ir/block.h index 48f2d02f78..56b63b13a3 100644 --- a/xls/ir/block.h +++ b/xls/ir/block.h @@ -78,7 +78,9 @@ class Block : public FunctionBase { // Returns a given input/output port by name. absl::StatusOr GetInputPort(std::string_view name) const; + bool HasInputPort(std::string_view name) const; absl::StatusOr GetOutputPort(std::string_view name) const; + bool HasOutputPort(std::string_view name) const; // Adds an input/output port to the block. These methods should be used to add // ports rather than FunctionBase::AddNode and FunctionBase::MakeNode (checked diff --git a/xls/ir/function_builder_test.cc b/xls/ir/function_builder_test.cc index 7146e67304..21252674c8 100644 --- a/xls/ir/function_builder_test.cc +++ b/xls/ir/function_builder_test.cc @@ -1109,6 +1109,17 @@ TEST(FunctionBuilderTest, NaryBitwiseAnd) { m::And(m::Param("a"), m::Param("b"), m::Param("c"))); } +TEST(FunctionBuilderTest, Ports) { + Package p("p"); + BlockBuilder b("b", &p); + b.OutputPort("bar", b.InputPort("foo", p.GetBitsType(32))); + XLS_ASSERT_OK_AND_ASSIGN(Block * blk, b.Build()); + EXPECT_TRUE(blk->HasInputPort("foo")); + EXPECT_FALSE(blk->HasOutputPort("foo")); + EXPECT_TRUE(blk->HasOutputPort("bar")); + EXPECT_FALSE(blk->HasInputPort("bar")); +} + TEST(FunctionBuilderTest, Registers) { Package p("p"); BlockBuilder b("b", &p); diff --git a/xls/tools/BUILD b/xls/tools/BUILD index b2e6e14709..b5b9a41645 100644 --- a/xls/tools/BUILD +++ b/xls/tools/BUILD @@ -298,6 +298,7 @@ cc_binary( "//xls/interpreter:serial_proc_runtime", "//xls/ir", "//xls/ir:bits", + "//xls/ir:block_elaboration", "//xls/ir:channel", "//xls/ir:channel_cc_proto", "//xls/ir:events", @@ -792,11 +793,16 @@ py_test( "//xls/tools/testdata:eval_proc_main_zero_size_test.block.ir", "//xls/tools/testdata:eval_proc_main_zero_size_test.sig.textproto", ":eval_proc_main", + "//xls/examples:delay.block.ir", + "//xls/examples:delay.sig.textproto", + "//xls/examples/dslx_module:manual_chan_caps_streaming_configured_multiproc.block.ir", + "//xls/examples/dslx_module:manual_chan_caps_streaming_configured_multiproc.sig.textproto", ], python_version = "PY3", srcs_version = "PY3", deps = [ ":node_coverage_stats_py_pb2", + ":proc_channel_values_py_pb2", "//xls/common:runfiles", "//xls/ir:xls_value_py_pb2", "@com_google_absl_py//absl/logging", diff --git a/xls/tools/eval_proc_main.cc b/xls/tools/eval_proc_main.cc index 2d53991dec..0b01445dea 100644 --- a/xls/tools/eval_proc_main.cc +++ b/xls/tools/eval_proc_main.cc @@ -42,6 +42,7 @@ #include "absl/status/statusor.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -62,6 +63,8 @@ #include "xls/interpreter/interpreter_proc_runtime.h" #include "xls/interpreter/serial_proc_runtime.h" #include "xls/ir/bits.h" +#include "xls/ir/block.h" +#include "xls/ir/block_elaboration.h" #include "xls/ir/channel.h" #include "xls/ir/channel.pb.h" #include "xls/ir/events.h" @@ -69,6 +72,7 @@ #include "xls/ir/ir_parser.h" #include "xls/ir/nodes.h" #include "xls/ir/package.h" +#include "xls/ir/proc.h" #include "xls/ir/register.h" #include "xls/ir/value.h" #include "xls/ir/value_utils.h" @@ -87,6 +91,12 @@ value of each proc will be printed to the terminal upon completion. Initial states are set according to their declarations inside the IR itself. )"; +ABSL_FLAG( + std::optional, top, std::nullopt, + "If present the top construct to simulate. Must be an exact match to " + "the name of an appropriate proc/block. Until new-style-procs are " + "available this is mostly just to support module-name for block " + "simulations as the specified top must be the actual top of the design."); ABSL_FLAG(std::vector, ticks, {}, "Can be a comma-separated list of runs. " "Number of clock ticks to execute for each, with proc state " @@ -207,6 +217,7 @@ struct EvaluateProcsOptions { bool use_jit = false; bool fail_on_assert = false; std::vector ticks = {-1}; + std::optional top = std::nullopt; }; static absl::Status EvaluateProcs( @@ -222,6 +233,13 @@ static absl::Status EvaluateProcs( bool uses_observers = absl::GetFlag(FLAGS_output_node_coverage_stats_proto).has_value() || absl::GetFlag(FLAGS_output_node_coverage_stats_textproto).has_value(); + if (options.top) { + XLS_ASSIGN_OR_RETURN(Proc * proc, package->GetProc(*options.top)); + if (proc != package->GetTop()) { + return absl::UnimplementedError( + "Simulating subsets of the proc network is not implemented yet."); + } + } evaluator_options.set_support_observers(uses_observers); if (options.use_jit) { XLS_ASSIGN_OR_RETURN( @@ -674,6 +692,7 @@ struct RunBlockOptions { bool use_jit = false; std::vector ticks = {-1}; int64_t max_cycles_no_output = 100; + std::optional top; int random_seed; double prob_input_valid_assert; bool show_trace; @@ -734,15 +753,31 @@ static absl::Status RunBlock( const absl::flat_hash_map>& model_memories_param, std::string_view output_stats_path, const RunBlockOptions& options = {}) { - if (package->blocks().size() != 1) { + Block* block; + if (options.top) { + XLS_ASSIGN_OR_RETURN(block, package->GetBlock(*options.top)); + } else if (package->HasTop()) { + if (package->GetTop().value()->IsBlock()) { + XLS_ASSIGN_OR_RETURN(block, package->GetTopAsBlock()); + } else if (package->blocks().size() == 1) { + block = package->blocks().front().get(); + } else { + // This is result of codegen-ing a proc so use the block for the top proc + // as top. + XLS_ASSIGN_OR_RETURN(Proc * top_proc, package->GetTopAsProc()); + XLS_ASSIGN_OR_RETURN( + block, package->GetBlock(top_proc->name()), + _ << "Unable to determine top. Pass --top to select one manually."); + } + } else if (package->blocks().size() == 1) { + block = package->blocks().front().get(); + } else { return absl::InvalidArgumentError( - "Input IR should contain exactly one block"); + "Input IR should contain exactly one block or a top"); } std::mt19937_64 bit_gen(options.random_seed); - Block* block = package->blocks()[0].get(); - // TODO: Support multiple resets CHECK_EQ(options.ticks.size(), 1); @@ -777,13 +812,23 @@ static absl::Status RunBlock( /*read_disabled_value=*/XsOfType(port->GetType()), options.show_trace); } - // Initial register state is one for all registers. - // Ideally this would be randomized, but at least 1s are more likely to - // expose bad behavior than 0s. absl::flat_hash_map reg_state; - for (Register* reg : block->GetRegisters()) { - Value def = ZeroOfType(reg->type()); - reg_state[reg->name()] = XsOfType(reg->type()); + { + XLS_ASSIGN_OR_RETURN(BlockElaboration elab, + BlockElaboration::Elaborate(block)); + for (BlockInstance* inst : elab.instances()) { + if (!inst->block()) { + // Actually a fifo or something without real registers. + continue; + } + for (Register* reg : (*inst->block())->GetRegisters()) { + // Initial register state is one for all registers. + // Ideally this would be randomized, but at least 1s are more likely to + // expose bad behavior than 0s. + reg_state[absl::StrCat(inst->RegisterPrefix(), reg->name())] = + XsOfType(reg->type()); + } + } } bool needs_observer = @@ -1187,6 +1232,7 @@ static absl::Status RealMain( RunBlockOptions block_options = { .ticks = ticks, .max_cycles_no_output = max_cycles_no_output, + .top = absl::GetFlag(FLAGS_top), .random_seed = random_seed, .prob_input_valid_assert = prob_input_valid_assert, .show_trace = show_trace, @@ -1209,6 +1255,7 @@ static absl::Status RealMain( EvaluateProcsOptions evaluate_procs_options = { .fail_on_assert = fail_on_assert, .ticks = ticks, + .top = absl::GetFlag(FLAGS_top), }; if (backend == "serial_jit") { diff --git a/xls/tools/eval_proc_main_test.py b/xls/tools/eval_proc_main_test.py index d93150823c..ea79babc22 100644 --- a/xls/tools/eval_proc_main_test.py +++ b/xls/tools/eval_proc_main_test.py @@ -24,6 +24,7 @@ from xls.common import runfiles from xls.ir import xls_value_pb2 from xls.tools import node_coverage_stats_pb2 +from xls.tools import proc_channel_values_pb2 EVAL_PROC_MAIN_PATH = runfiles.get_path("xls/tools/eval_proc_main") @@ -256,6 +257,69 @@ } """ +MULTI_BLOCK_IR_FILE = runfiles.get_path( + "xls/examples/dslx_module/manual_chan_caps_streaming_configured_multiproc.block.ir" +) +MULTI_BLOCK_SIG_FILE = runfiles.get_path( + "xls/examples/dslx_module/manual_chan_caps_streaming_configured_multiproc.sig.textproto" +) +MULTI_BLOCK_MEMORY_IR_FILE = runfiles.get_path("xls/examples/delay.block.ir") +MULTI_BLOCK_MEMORY_SIG_FILE = runfiles.get_path( + "xls/examples/delay.sig.textproto" +) + + +def _eight_chars(val: bytes) -> xls_value_pb2.ValueProto: + assert len(val) == 8 + return xls_value_pb2.ValueProto( + array=xls_value_pb2.ValueProto.Array( + elements=[ + xls_value_pb2.ValueProto( + bits=xls_value_pb2.ValueProto.Bits( + bit_count=8, data=bytes([v]) + ) + ) + for v in val + ] + ) + ) + + +MULTI_BLOCK_INPUT_CHANNEL_VALUES = ( + proc_channel_values_pb2.ProcChannelValuesProto( + channels=[ + proc_channel_values_pb2.ProcChannelValuesProto.Channel( + name="some_caps_streaming_configured__external_input_wire", + entry=[ + _eight_chars(b"abcdabcd"), + _eight_chars(b"abcdabcd"), + _eight_chars(b"abcdabcd"), + _eight_chars(b"abcdabcd"), + _eight_chars(b"abcdabcd"), + _eight_chars(b"abcdabcd"), + ], + ) + ] + ) +) +MULTI_BLOCK_OUTPUT_CHANNEL_VALUES = ( + proc_channel_values_pb2.ProcChannelValuesProto( + channels=[ + proc_channel_values_pb2.ProcChannelValuesProto.Channel( + name="some_caps_streaming_configured__external_output_wire", + entry=[ + _eight_chars(b"ABCDABCD"), + _eight_chars(b"abcdabcd"), + _eight_chars(b"AbCdAbCd"), + _eight_chars(b"ABCDABCD"), + _eight_chars(b"abcdabcd"), + _eight_chars(b"AbCdAbCd"), + ], + ) + ] + ) +) + TOKEN = xls_value_pb2.ValueProto(token=xls_value_pb2.ValueProto.Token()) _ONE_BIT_TRUE = xls_value_pb2.ValueProto( bits=xls_value_pb2.ValueProto.Bits(bit_count=1, data=b"\1") @@ -836,6 +900,59 @@ def test_block_memory(self, backend): "Memory Model: Initiated read mem[3] = bits[32]:6", output.stderr ) + @parameterized_block_backends + def test_multi_block_memory(self, backend): + tick_count = 3 * 2048 + ir_file = MULTI_BLOCK_MEMORY_IR_FILE + signature_file = MULTI_BLOCK_MEMORY_SIG_FILE + input_channel = proc_channel_values_pb2.ProcChannelValuesProto.Channel( + name="delay__data_in" + ) + output_channel = proc_channel_values_pb2.ProcChannelValuesProto.Channel( + name="delay__data_out" + ) + # Make a little oracle to get the results. + buffer = [3] * 2048 + for t in range(tick_count): + input_channel.entry.append(_value_32_bits(t)) + buffer.append(t) + output_channel.entry.append(_value_32_bits(buffer.pop(0))) + + # Create input and output args + input_data = proc_channel_values_pb2.ProcChannelValuesProto( + channels=[input_channel] + ) + output_data = proc_channel_values_pb2.ProcChannelValuesProto( + channels=[output_channel] + ) + channels_in_ir_file = self.create_tempfile( + content=input_data.SerializeToString() + ) + channels_out_ir_file = self.create_tempfile( + content=output_data.SerializeToString() + ) + + shared_args = [ + EVAL_PROC_MAIN_PATH, + ir_file, + "--top=delay_top", + "--proto_inputs_for_all_channels", + channels_in_ir_file, + "--expected_proto_outputs_for_all_channels", + channels_out_ir_file, + "--block_signature_proto", + signature_file, + "--model_memories", + "ram=1024/(bits[64]:0)", + "--alsologtostderr", + "--show_trace", + "--ticks", + "-1", + # f"{tick_count + 1}", + ] + backend + + run_command(shared_args) + @parameterized_block_backends def test_observe_block(self, backend): ir_file = self.create_tempfile(content=OBSERVER_IR) @@ -952,6 +1069,53 @@ def test_observe_proc(self, backend): ) self.assertLen(node_coverage.nodes, 7) + @parameterized_proc_backends + def test_multi_proc(self, backend): + ir_file = MULTI_BLOCK_IR_FILE + channels_in_file = self.create_tempfile( + content=MULTI_BLOCK_INPUT_CHANNEL_VALUES.SerializeToString() + ) + channels_out_file = self.create_tempfile( + content=MULTI_BLOCK_OUTPUT_CHANNEL_VALUES.SerializeToString() + ) + run_command( + [ + EVAL_PROC_MAIN_PATH, + ir_file, + f"--proto_inputs_for_all_channels={channels_in_file.full_path}", + f"--expected_proto_outputs_for_all_channels={channels_out_file.full_path}", + "--alsologtostderr", + "--show_trace", + "--ticks=6", + ] + + backend + ) + + @parameterized_block_backends + def test_multi_block(self, backend): + ir_file = MULTI_BLOCK_IR_FILE + sig_file = MULTI_BLOCK_SIG_FILE + channels_in_file = self.create_tempfile( + content=MULTI_BLOCK_INPUT_CHANNEL_VALUES.SerializeToString() + ) + channels_out_file = self.create_tempfile( + content=MULTI_BLOCK_OUTPUT_CHANNEL_VALUES.SerializeToString() + ) + run_command( + [ + EVAL_PROC_MAIN_PATH, + ir_file, + f"--block_signature_proto={sig_file}", + "--top=manual_chan_caps_streaming", + f"--proto_inputs_for_all_channels={channels_in_file.full_path}", + f"--expected_proto_outputs_for_all_channels={channels_out_file.full_path}", + "--alsologtostderr", + "--show_trace", + "--ticks=6", + ] + + backend + ) + @parameterized_proc_backends def test_zero_size_proc(self, backend): input_file = self.create_tempfile(content=textwrap.dedent("""