diff --git a/toolchain/check/import_ref.cpp b/toolchain/check/import_ref.cpp index 795e915192fc..d78b438c22bc 100644 --- a/toolchain/check/import_ref.cpp +++ b/toolchain/check/import_ref.cpp @@ -328,12 +328,6 @@ class ImportRefResolver { llvm::SmallVector indirect_insts = {}; }; - // Local information associated with an imported parameter. - struct ParamData { - SemIR::ConstantId type_const_id; - SemIR::ConstantId bind_const_id; - }; - // Local information associated with an imported generic. struct GenericData { llvm::SmallVector bindings; @@ -469,6 +463,13 @@ class ImportRefResolver { return GetLocalConstantId(import_ir_.types().GetConstantId(type_id)); } + template + auto GetLocalConstantIdChecked(Id id) { + auto result = GetLocalConstantId(id); + CARBON_CHECK(result.is_valid()); + return result; + } + // Gets the local constant values corresponding to an imported inst block. auto GetLocalInstBlockContents(SemIR::InstBlockId import_block_id) -> llvm::SmallVector { @@ -643,21 +644,16 @@ class ImportRefResolver { return specific_id; } - // Returns the ConstantId for each parameter's type. Adds unresolved constants - // to work_stack_. - auto GetLocalParamConstantIds(SemIR::InstBlockId param_refs_id) - -> llvm::SmallVector { - llvm::SmallVector param_data; + // Adds unresolved constants for each parameter's type to work_stack_. + auto LoadLocalParamConstantIds(SemIR::InstBlockId param_refs_id) -> void { if (!param_refs_id.is_valid() || param_refs_id == SemIR::InstBlockId::Empty) { - return param_data; + return; } const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id); - param_data.reserve(param_refs.size()); for (auto inst_id : param_refs) { - auto type_const_id = - GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id()); + GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id()); // If the parameter is a symbolic binding, build the BindSymbolicName // constant. @@ -667,19 +663,25 @@ class ImportRefResolver { bind_id = addr->inner_id; bind_inst = import_ir_.insts().Get(bind_id); } - auto bind_const_id = bind_inst.Is() - ? GetLocalConstantId(bind_id) - : SemIR::ConstantId::Invalid; - param_data.push_back( - {.type_const_id = type_const_id, .bind_const_id = bind_const_id}); + if (bind_inst.Is()) { + GetLocalConstantId(bind_id); + } } - return param_data; } - // Given a param_refs_id and const_ids from GetLocalParamConstantIds, returns - // a version of param_refs_id localized to the current IR. - auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id, - const llvm::SmallVector& params_data) + // Returns a version of param_refs_id localized to the current IR. + // + // Must only be called after a call to GetLocalParamConstantIds(param_refs_id) + // has completed without adding any new work to work_stack_. + // + // TODO: This is inconsistent with the rest of this class, which expects + // the relevant constants to be explicitly passed in. That makes it + // easier to statically detect when an input isn't loaded, but makes it + // harder to support importing more complex inst structures. We should + // take a holistic look at how to balance those concerns. For example, + // could the same function be used to load the constants and use them, with + // a parameter to select between the two? + auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id) -> SemIR::InstBlockId { if (!param_refs_id.is_valid() || param_refs_id == SemIR::InstBlockId::Empty) { @@ -687,7 +689,7 @@ class ImportRefResolver { } const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id); llvm::SmallVector new_param_refs; - for (auto [ref_id, param_data] : llvm::zip(param_refs, params_data)) { + for (auto ref_id : param_refs) { // Figure out the param structure. This echoes // Function::GetParamFromParamRefId. // TODO: Consider a different parameter handling to simplify import logic. @@ -712,8 +714,8 @@ class ImportRefResolver { // Rebuild the param instruction. auto name_id = GetLocalNameId(param_inst.name_id); - auto type_id = - context_.GetTypeIdForTypeConstant(param_data.type_const_id); + auto type_id = context_.GetTypeIdForTypeConstant( + GetLocalConstantIdChecked(param_inst.type_id)); auto new_param_id = context_.AddInstInNoBlock( AddImportIRInst(param_id), @@ -740,12 +742,12 @@ class ImportRefResolver { auto new_bind_inst = context_.insts().GetAs( context_.constant_values().GetInstId( - param_data.bind_const_id)); + GetLocalConstantIdChecked(bind_id))); new_bind_inst.value_id = new_param_id; new_param_id = context_.AddInstInNoBlock(AddImportIRInst(bind_id), new_bind_inst); context_.constant_values().Set(new_param_id, - param_data.bind_const_id); + GetLocalConstantIdChecked(bind_id)); break; } default: { @@ -1322,9 +1324,8 @@ class ImportRefResolver { // Load constants for the definition. auto parent_scope_id = GetLocalNameScopeId(import_class.parent_scope_id); - auto implicit_param_const_ids = - GetLocalParamConstantIds(import_class.implicit_param_refs_id); - auto param_const_ids = GetLocalParamConstantIds(import_class.param_refs_id); + LoadLocalParamConstantIds(import_class.implicit_param_refs_id); + LoadLocalParamConstantIds(import_class.param_refs_id); auto generic_data = GetLocalGenericData(import_class.generic_id); auto self_const_id = GetLocalConstantId(import_class.self_type_id); auto complete_type_witness_id = @@ -1341,10 +1342,9 @@ class ImportRefResolver { auto& new_class = context_.classes().Get(class_id); new_class.parent_scope_id = parent_scope_id; - new_class.implicit_param_refs_id = GetLocalParamRefsId( - import_class.implicit_param_refs_id, implicit_param_const_ids); - new_class.param_refs_id = - GetLocalParamRefsId(import_class.param_refs_id, param_const_ids); + new_class.implicit_param_refs_id = + GetLocalParamRefsId(import_class.implicit_param_refs_id); + new_class.param_refs_id = GetLocalParamRefsId(import_class.param_refs_id); SetGenericData(import_class.generic_id, new_class.generic_id, generic_data); new_class.self_type_id = context_.GetTypeIdForTypeConstant(self_const_id); @@ -1495,10 +1495,8 @@ class ImportRefResolver { import_ir_.insts().Get(import_function.return_storage_id).type_id()); } auto parent_scope_id = GetLocalNameScopeId(import_function.parent_scope_id); - auto implicit_param_const_ids = - GetLocalParamConstantIds(import_function.implicit_param_refs_id); - auto param_const_ids = - GetLocalParamConstantIds(import_function.param_refs_id); + LoadLocalParamConstantIds(import_function.implicit_param_refs_id); + LoadLocalParamConstantIds(import_function.param_refs_id); auto generic_data = GetLocalGenericData(import_function.generic_id); if (HasNewWork()) { @@ -1508,10 +1506,10 @@ class ImportRefResolver { // Add the function declaration. auto& new_function = context_.functions().Get(function_id); new_function.parent_scope_id = parent_scope_id; - new_function.implicit_param_refs_id = GetLocalParamRefsId( - import_function.implicit_param_refs_id, implicit_param_const_ids); + new_function.implicit_param_refs_id = + GetLocalParamRefsId(import_function.implicit_param_refs_id); new_function.param_refs_id = - GetLocalParamRefsId(import_function.param_refs_id, param_const_ids); + GetLocalParamRefsId(import_function.param_refs_id); SetGenericData(import_function.generic_id, new_function.generic_id, generic_data); @@ -1654,8 +1652,7 @@ class ImportRefResolver { // Load constants for the definition. auto parent_scope_id = GetLocalNameScopeId(import_impl.parent_scope_id); - auto implicit_param_const_ids = - GetLocalParamConstantIds(import_impl.implicit_param_refs_id); + LoadLocalParamConstantIds(import_impl.implicit_param_refs_id); auto generic_data = GetLocalGenericData(import_impl.generic_id); auto self_const_id = GetLocalConstantId(import_impl.self_id); auto constraint_const_id = GetLocalConstantId(import_impl.constraint_id); @@ -1666,8 +1663,8 @@ class ImportRefResolver { auto& new_impl = context_.impls().Get(impl_id); new_impl.parent_scope_id = parent_scope_id; - new_impl.implicit_param_refs_id = GetLocalParamRefsId( - import_impl.implicit_param_refs_id, implicit_param_const_ids); + new_impl.implicit_param_refs_id = + GetLocalParamRefsId(import_impl.implicit_param_refs_id); CARBON_CHECK(!import_impl.param_refs_id.is_valid() && !new_impl.param_refs_id.is_valid()); SetGenericData(import_impl.generic_id, new_impl.generic_id, generic_data); @@ -1811,10 +1808,8 @@ class ImportRefResolver { auto parent_scope_id = GetLocalNameScopeId(import_interface.parent_scope_id); - auto implicit_param_const_ids = - GetLocalParamConstantIds(import_interface.implicit_param_refs_id); - auto param_const_ids = - GetLocalParamConstantIds(import_interface.param_refs_id); + LoadLocalParamConstantIds(import_interface.implicit_param_refs_id); + LoadLocalParamConstantIds(import_interface.param_refs_id); auto generic_data = GetLocalGenericData(import_interface.generic_id); std::optional self_param_id; @@ -1828,10 +1823,10 @@ class ImportRefResolver { auto& new_interface = context_.interfaces().Get(interface_id); new_interface.parent_scope_id = parent_scope_id; - new_interface.implicit_param_refs_id = GetLocalParamRefsId( - import_interface.implicit_param_refs_id, implicit_param_const_ids); + new_interface.implicit_param_refs_id = + GetLocalParamRefsId(import_interface.implicit_param_refs_id); new_interface.param_refs_id = - GetLocalParamRefsId(import_interface.param_refs_id, param_const_ids); + GetLocalParamRefsId(import_interface.param_refs_id); SetGenericData(import_interface.generic_id, new_interface.generic_id, generic_data);