Skip to content

Commit

Permalink
Remove parameter-constant arrays from import_ref (carbon-language#4360)
Browse files Browse the repository at this point in the history
This makes the code more resilient to changes in the structure of
parameter insts, and could help avoid bugs by making the
ImportRefResolver's data structures the single source of truth.
  • Loading branch information
geoffromer authored Oct 3, 2024
1 parent afbea6a commit e617d64
Showing 1 changed file with 50 additions and 55 deletions.
105 changes: 50 additions & 55 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,6 @@ class ImportRefResolver {
llvm::SmallVector<SemIR::ImportIRInst> indirect_insts = {};
};

// Local information associated with an imported parameter.
struct ParamData {
SemIR::ConstantId type_const_id;
SemIR::ConstantId bind_const_id;
};

// Local information associated with an imported generic.
struct GenericData {
llvm::SmallVector<SemIR::InstId> bindings;
Expand Down Expand Up @@ -469,6 +463,13 @@ class ImportRefResolver {
return GetLocalConstantId(import_ir_.types().GetConstantId(type_id));
}

template <typename Id>
auto GetLocalConstantIdChecked(Id id) {
auto result = GetLocalConstantId(id);
CARBON_CHECK(result.is_valid());
return result;
}

// Gets the local constant values corresponding to an imported inst block.
auto GetLocalInstBlockContents(SemIR::InstBlockId import_block_id)
-> llvm::SmallVector<SemIR::InstId> {
Expand Down Expand Up @@ -643,21 +644,16 @@ class ImportRefResolver {
return specific_id;
}

// Returns the ConstantId for each parameter's type. Adds unresolved constants
// to work_stack_.
auto GetLocalParamConstantIds(SemIR::InstBlockId param_refs_id)
-> llvm::SmallVector<ParamData> {
llvm::SmallVector<ParamData> param_data;
// Adds unresolved constants for each parameter's type to work_stack_.
auto LoadLocalParamConstantIds(SemIR::InstBlockId param_refs_id) -> void {
if (!param_refs_id.is_valid() ||
param_refs_id == SemIR::InstBlockId::Empty) {
return param_data;
return;
}

const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id);
param_data.reserve(param_refs.size());
for (auto inst_id : param_refs) {
auto type_const_id =
GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id());
GetLocalConstantId(import_ir_.insts().Get(inst_id).type_id());

// If the parameter is a symbolic binding, build the BindSymbolicName
// constant.
Expand All @@ -667,27 +663,33 @@ class ImportRefResolver {
bind_id = addr->inner_id;
bind_inst = import_ir_.insts().Get(bind_id);
}
auto bind_const_id = bind_inst.Is<SemIR::BindSymbolicName>()
? GetLocalConstantId(bind_id)
: SemIR::ConstantId::Invalid;
param_data.push_back(
{.type_const_id = type_const_id, .bind_const_id = bind_const_id});
if (bind_inst.Is<SemIR::BindSymbolicName>()) {
GetLocalConstantId(bind_id);
}
}
return param_data;
}

// Given a param_refs_id and const_ids from GetLocalParamConstantIds, returns
// a version of param_refs_id localized to the current IR.
auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id,
const llvm::SmallVector<ParamData>& params_data)
// Returns a version of param_refs_id localized to the current IR.
//
// Must only be called after a call to GetLocalParamConstantIds(param_refs_id)
// has completed without adding any new work to work_stack_.
//
// TODO: This is inconsistent with the rest of this class, which expects
// the relevant constants to be explicitly passed in. That makes it
// easier to statically detect when an input isn't loaded, but makes it
// harder to support importing more complex inst structures. We should
// take a holistic look at how to balance those concerns. For example,
// could the same function be used to load the constants and use them, with
// a parameter to select between the two?
auto GetLocalParamRefsId(SemIR::InstBlockId param_refs_id)
-> SemIR::InstBlockId {
if (!param_refs_id.is_valid() ||
param_refs_id == SemIR::InstBlockId::Empty) {
return param_refs_id;
}
const auto& param_refs = import_ir_.inst_blocks().Get(param_refs_id);
llvm::SmallVector<SemIR::InstId> new_param_refs;
for (auto [ref_id, param_data] : llvm::zip(param_refs, params_data)) {
for (auto ref_id : param_refs) {
// Figure out the param structure. This echoes
// Function::GetParamFromParamRefId.
// TODO: Consider a different parameter handling to simplify import logic.
Expand All @@ -712,8 +714,8 @@ class ImportRefResolver {

// Rebuild the param instruction.
auto name_id = GetLocalNameId(param_inst.name_id);
auto type_id =
context_.GetTypeIdForTypeConstant(param_data.type_const_id);
auto type_id = context_.GetTypeIdForTypeConstant(
GetLocalConstantIdChecked(param_inst.type_id));

auto new_param_id = context_.AddInstInNoBlock<SemIR::Param>(
AddImportIRInst(param_id),
Expand All @@ -740,12 +742,12 @@ class ImportRefResolver {
auto new_bind_inst =
context_.insts().GetAs<SemIR::BindSymbolicName>(
context_.constant_values().GetInstId(
param_data.bind_const_id));
GetLocalConstantIdChecked(bind_id)));
new_bind_inst.value_id = new_param_id;
new_param_id = context_.AddInstInNoBlock(AddImportIRInst(bind_id),
new_bind_inst);
context_.constant_values().Set(new_param_id,
param_data.bind_const_id);
GetLocalConstantIdChecked(bind_id));
break;
}
default: {
Expand Down Expand Up @@ -1322,9 +1324,8 @@ class ImportRefResolver {

// Load constants for the definition.
auto parent_scope_id = GetLocalNameScopeId(import_class.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_class.implicit_param_refs_id);
auto param_const_ids = GetLocalParamConstantIds(import_class.param_refs_id);
LoadLocalParamConstantIds(import_class.implicit_param_refs_id);
LoadLocalParamConstantIds(import_class.param_refs_id);
auto generic_data = GetLocalGenericData(import_class.generic_id);
auto self_const_id = GetLocalConstantId(import_class.self_type_id);
auto complete_type_witness_id =
Expand All @@ -1341,10 +1342,9 @@ class ImportRefResolver {

auto& new_class = context_.classes().Get(class_id);
new_class.parent_scope_id = parent_scope_id;
new_class.implicit_param_refs_id = GetLocalParamRefsId(
import_class.implicit_param_refs_id, implicit_param_const_ids);
new_class.param_refs_id =
GetLocalParamRefsId(import_class.param_refs_id, param_const_ids);
new_class.implicit_param_refs_id =
GetLocalParamRefsId(import_class.implicit_param_refs_id);
new_class.param_refs_id = GetLocalParamRefsId(import_class.param_refs_id);
SetGenericData(import_class.generic_id, new_class.generic_id, generic_data);
new_class.self_type_id = context_.GetTypeIdForTypeConstant(self_const_id);

Expand Down Expand Up @@ -1495,10 +1495,8 @@ class ImportRefResolver {
import_ir_.insts().Get(import_function.return_storage_id).type_id());
}
auto parent_scope_id = GetLocalNameScopeId(import_function.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_function.implicit_param_refs_id);
auto param_const_ids =
GetLocalParamConstantIds(import_function.param_refs_id);
LoadLocalParamConstantIds(import_function.implicit_param_refs_id);
LoadLocalParamConstantIds(import_function.param_refs_id);
auto generic_data = GetLocalGenericData(import_function.generic_id);

if (HasNewWork()) {
Expand All @@ -1508,10 +1506,10 @@ class ImportRefResolver {
// Add the function declaration.
auto& new_function = context_.functions().Get(function_id);
new_function.parent_scope_id = parent_scope_id;
new_function.implicit_param_refs_id = GetLocalParamRefsId(
import_function.implicit_param_refs_id, implicit_param_const_ids);
new_function.implicit_param_refs_id =
GetLocalParamRefsId(import_function.implicit_param_refs_id);
new_function.param_refs_id =
GetLocalParamRefsId(import_function.param_refs_id, param_const_ids);
GetLocalParamRefsId(import_function.param_refs_id);
SetGenericData(import_function.generic_id, new_function.generic_id,
generic_data);

Expand Down Expand Up @@ -1654,8 +1652,7 @@ class ImportRefResolver {

// Load constants for the definition.
auto parent_scope_id = GetLocalNameScopeId(import_impl.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_impl.implicit_param_refs_id);
LoadLocalParamConstantIds(import_impl.implicit_param_refs_id);
auto generic_data = GetLocalGenericData(import_impl.generic_id);
auto self_const_id = GetLocalConstantId(import_impl.self_id);
auto constraint_const_id = GetLocalConstantId(import_impl.constraint_id);
Expand All @@ -1666,8 +1663,8 @@ class ImportRefResolver {

auto& new_impl = context_.impls().Get(impl_id);
new_impl.parent_scope_id = parent_scope_id;
new_impl.implicit_param_refs_id = GetLocalParamRefsId(
import_impl.implicit_param_refs_id, implicit_param_const_ids);
new_impl.implicit_param_refs_id =
GetLocalParamRefsId(import_impl.implicit_param_refs_id);
CARBON_CHECK(!import_impl.param_refs_id.is_valid() &&
!new_impl.param_refs_id.is_valid());
SetGenericData(import_impl.generic_id, new_impl.generic_id, generic_data);
Expand Down Expand Up @@ -1811,10 +1808,8 @@ class ImportRefResolver {

auto parent_scope_id =
GetLocalNameScopeId(import_interface.parent_scope_id);
auto implicit_param_const_ids =
GetLocalParamConstantIds(import_interface.implicit_param_refs_id);
auto param_const_ids =
GetLocalParamConstantIds(import_interface.param_refs_id);
LoadLocalParamConstantIds(import_interface.implicit_param_refs_id);
LoadLocalParamConstantIds(import_interface.param_refs_id);
auto generic_data = GetLocalGenericData(import_interface.generic_id);

std::optional<SemIR::InstId> self_param_id;
Expand All @@ -1828,10 +1823,10 @@ class ImportRefResolver {

auto& new_interface = context_.interfaces().Get(interface_id);
new_interface.parent_scope_id = parent_scope_id;
new_interface.implicit_param_refs_id = GetLocalParamRefsId(
import_interface.implicit_param_refs_id, implicit_param_const_ids);
new_interface.implicit_param_refs_id =
GetLocalParamRefsId(import_interface.implicit_param_refs_id);
new_interface.param_refs_id =
GetLocalParamRefsId(import_interface.param_refs_id, param_const_ids);
GetLocalParamRefsId(import_interface.param_refs_id);
SetGenericData(import_interface.generic_id, new_interface.generic_id,
generic_data);

Expand Down

0 comments on commit e617d64

Please sign in to comment.