diff --git a/externals/llvm-project b/externals/llvm-project index 6f289294ba0f..700015696464 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 6f289294ba0fee610ec9e6c736a9fb03686eb23b +Subproject commit 700015696464c13c746d83b02ea2f5e84639c973 diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index ec80d21ef20b..e4893440b6dd 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -49,16 +49,15 @@ using namespace mlir::torch::Torch; /// a single module. If we had to support complex nested symbol references, we /// would probably want to go through the effort to indirect through the symbol /// tables to make things clearer. -class FlatSymbolRefProgramPoint - : public GenericProgramPointBase { +class FlatSymbolRefLatticeAnchor + : public GenericLatticeAnchorBase { public: using Base::Base; void print(raw_ostream &os) const override { - os << "FlatSymbolRefProgramPoint(" << getValue() << ")"; + os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")"; } Location getLoc() const override { - return UnknownLoc::get(getValue().getContext()); + return UnknownLoc::get(getValue()->getContext()); } }; @@ -84,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) { /// State tracking if an IR construct is "safe". /// /// This state is tracked on Value's and also on global slots (via a -/// FlatSymbolRefProgramPoint). +/// FlatSymbolRefLatticeAnchor). /// /// In this context, "safe" means that the object is safe to inline. 
/// This covers a few concepts @@ -93,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) { /// unsafe class InlineGlobalSlotsAnalysisState : public AnalysisState { public: - InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) { + InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) { (void)setSafe(); } @@ -147,33 +146,33 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis { InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver) : DataFlowAnalysis(solver) { - registerPointKind(); + registerAnchorKind(); } LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { auto walkResult = top->walk([this](Operation *op) { if (auto globalSlot = dyn_cast(op)) { auto *state = getOrCreate( - getProgramPoint( - FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()))); + getLatticeAnchor(globalSlot)); propagateIfChanged(state, state->setSafe(globalSlot.getVisibility() != SymbolTable::Visibility::Public)); } if (auto globalSlotSet = dyn_cast(op)) { + auto globalSlot = SymbolTable::lookupNearestSymbolFrom( + globalSlotSet, globalSlotSet.getSlotAttr()); + auto *state = getOrCreate( - getProgramPoint( - globalSlotSet.getSlotAttr())); + getLatticeAnchor(globalSlot)); propagateIfChanged(state, state->setSafe(false)); } // Save the InitializeGlobalSlotsOp for later reference if (auto initialize = dyn_cast(op)) { initializeGlobalSlotsOp = initialize; } - for (Value result : op->getResults()) { - if (failed(visit(result))) - return WalkResult::interrupt(); - } + if (failed(visit(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); }); if (walkResult.wasInterrupted()) @@ -182,50 +181,32 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { } LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { - if (Value value = dyn_cast(point)) { - bool isSafe = isValueSafeTransferFunction(value); - auto *state = getOrCreate(value); - 
propagateIfChanged(state, state->setSafe(isSafe)); - - // Handle GlobalSlotGetOp's. - if (auto opResult = dyn_cast(value)) { - if (auto globalSlotGet = - dyn_cast(opResult.getOwner())) { - auto *flatSymbolRefPoint = getProgramPoint( - globalSlotGet.getSlotAttr()); - auto *valueState = getOrCreateFor( - flatSymbolRefPoint, globalSlotGet.getResult()); - auto *globalState = - getOrCreate(flatSymbolRefPoint); - propagateIfChanged(globalState, - globalState->incorporateSafetyOfUse(valueState)); - } - } - - return success(); - } - if (auto *genericProgramPoint = dyn_cast(point)) { - if (auto *flatSymbolRefPoint = - dyn_cast(genericProgramPoint)) { - if (initializeGlobalSlotsOp) { - auto it = - llvm::find(initializeGlobalSlotsOp.getSlotSymNames(), - static_cast(flatSymbolRefPoint->getValue())); - Value value = initializeGlobalSlotsOp->getOperand(std::distance( - initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); - auto *flatSymbolRefState = - getOrCreateFor(value, - flatSymbolRefPoint); - auto *valueState = getOrCreate(value); - propagateIfChanged(valueState, - valueState->setSafe(flatSymbolRefState->isSafe)); + if (auto op = dyn_cast(point)) { + for (auto value : op->getResults()) { + bool isSafe = isValueSafeTransferFunction(value); + auto *state = getOrCreate(value); + propagateIfChanged(state, state->setSafe(isSafe)); + + // Handle GlobalSlotGetOp's. 
+ if (auto opResult = dyn_cast(value)) { + if (auto globalSlotGet = + dyn_cast(opResult.getOwner())) { + auto globalSlot = SymbolTable::lookupNearestSymbolFrom( + globalSlotGet, globalSlotGet.getSlotAttr()); + auto *flatSymbolRefPoint = + getLatticeAnchor(globalSlot); + auto *valueState = getOrCreateFor( + globalSlot, globalSlotGet.getResult()); + auto *globalState = + getOrCreate(flatSymbolRefPoint); + propagateIfChanged(globalState, + globalState->incorporateSafetyOfUse(valueState)); + } } - return success(); } } - LLVM_DEBUG( - { llvm::dbgs() << "visit failing because of: " << point << "\n"; }); - return failure(); + + return success(); } // This is only a member function to access protected get* functions. @@ -241,16 +222,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) { // safe. This covers, for example, view-like ops that create aliases. if ((op->hasTrait() || isMemoryEffectFree(op)) && llvm::all_of(op->getResults(), [&](Value result) { - auto *state = - getOrCreateFor(value, result); + auto *state = getOrCreateFor( + value.getDefiningOp(), result); return state->isSafe; })) continue; if (auto initialize = dyn_cast(op)) { auto symName = cast( initialize.getSlotSymNames()[use.getOperandNumber()]); + auto globalSlot = + SymbolTable::lookupNearestSymbolFrom(op, symName); + auto *state = getOrCreateFor( - value, getProgramPoint(symName)); + value.getDefiningOp(), + getLatticeAnchor(globalSlot)); if (state->isSafe) continue; } @@ -299,8 +284,7 @@ class InlineGlobalSlotsPass module->walk([&](Operation *op) { if (auto globalSlot = dyn_cast(op)) { auto *state = solver.lookupState( - solver.getProgramPoint( - FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()))); + solver.getLatticeAnchor(globalSlot)); state->print(llvm::dbgs()); llvm::dbgs() << ": " << FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()) @@ -334,13 +318,16 @@ class InlineGlobalSlotsPass auto slotSymName = cast(initialize.getSlotSymNames()[i]); Value operand = 
initialize.getOperand(i); - auto symbolRefPoint = solver.getProgramPoint( - cast(initialize.getSlotSymNames()[i])); + auto globalSlot = SymbolTable::lookupNearestSymbolFrom( + initialize, slotSymName); + + auto symbolRefPoint = + solver.getLatticeAnchor(globalSlot); auto *state = solver.lookupState(symbolRefPoint); // We roll the analysis of whether a slot is set or public into the // main dataflow analysis, so we need to check the slot's - // FlatSymbolRefProgramPoint itself to see if it is safe to inline. + // FlatSymbolRefLatticeAnchor itself to see if it is safe to inline. // For example, a public !torch.int is not safe to inline, even though // it is a value-semantic type and so the actual initializer value // itself is conceptually safe to inline. diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir index ec6bfee2248b..69ec4e2410eb 100644 --- a/test/Conversion/TorchToStablehlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -259,7 +259,6 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // CHECK: %[[T_5:.*]] = torch.constant.int 1 // CHECK: %[[T_6:.*]] = torch.constant.int 4 // CHECK: %[[T_7:.*]] = torch.constant.int 3 -// CHECK: %[[T_8:.*]] = arith.constant 3 : i64 // CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list @@ -295,7 +294,6 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! 
// CHECK: %int2 = torch.constant.int 2 // CHECK: %int1 = torch.constant.int 1 // CHECK: %int4 = torch.constant.int 4 -// CHECK: %[[T_3:.*]] = arith.constant 3 : i64 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -336,7 +334,6 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = arith.constant 1 : i64 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32> @@ -367,7 +364,6 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = arith.constant 1 : i64 // CHECK: %int2 = torch.constant.int 2 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -402,7 +398,6 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = arith.constant 1 : i64 // CHECK: %int2 = torch.constant.int 2 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // 
CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -438,10 +433,6 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 // CHECK: %int2 = torch.constant.int 2 -// CHECK: %[[T_2:.*]] = arith.constant 2 : i64 -// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32> // CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32> // CHECK: %c0 = arith.constant 0 : index