Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 8, 2025
1 parent 832a20c commit a93e455
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
20 changes: 16 additions & 4 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
mlir::ModuleOp source,
mlir::ModuleOp target,
unsigned &lastUsedID) {
unsigned &lastUsedID,
bool &shouldRemove) {
using namespace llvm;
using namespace mlir;

Expand All @@ -639,6 +640,13 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
return success();
}

if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) {
if (func.isExternal()) {
shouldRemove = true;
return success();
}
}

StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target);

if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
Expand All @@ -658,7 +666,7 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,

unsigned lastUsedID = 0;

for (auto &op : *newMod.getBody()) {
for (auto &op : make_early_inc_range(*newMod.getBody())) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
Expand All @@ -669,10 +677,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
entryFn = &op;
}

if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID))) {
bool shouldRemove = false;
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) {
assert(0 && "failed to update all uses");
}
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
if (shouldRemove)
op.erase();
else
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
}
prevMod.getBody()->getOperations().splice(
prevMod.getBody()->getOperations().end(),
Expand Down
9 changes: 7 additions & 2 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,14 @@ function get_field_offset(T::Type, path)

for field in path
# Get the field index
field_idx = findfirst(==(field), fieldnames(current_type))
field_idx = if field isa Integer
field
else
@assert field isa Symbol
findfirst(==(field), fieldnames(current_type))
end
if field_idx === nothing
error("Field $field not found in type $current_type")
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
end

# Add the offset of this field
Expand Down

0 comments on commit a93e455

Please sign in to comment.