[AutoBump] Merge with fixes of 6934ab81 (Sep 10, needs llvm bump) (45) #370

Open · wants to merge 3 commits into base: bump_to_b35675a7
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 2351 files
117 changes: 52 additions & 65 deletions lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
@@ -49,16 +49,15 @@ using namespace mlir::torch::Torch;
 /// a single module. If we had to support complex nested symbol references, we
 /// would probably want to go through the effort to indirect through the symbol
 /// tables to make things clearer.
-class FlatSymbolRefProgramPoint
-    : public GenericProgramPointBase<FlatSymbolRefProgramPoint,
-                                     FlatSymbolRefAttr> {
+class FlatSymbolRefLatticeAnchor
+    : public GenericLatticeAnchorBase<FlatSymbolRefLatticeAnchor, Operation *> {
 public:
   using Base::Base;
   void print(raw_ostream &os) const override {
-    os << "FlatSymbolRefProgramPoint(" << getValue() << ")";
+    os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")";
   }
   Location getLoc() const override {
-    return UnknownLoc::get(getValue().getContext());
+    return UnknownLoc::get(getValue()->getContext());
   }
 };

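For readers tracking the API migration: the lattice state for a global slot is now keyed on the GlobalSlotOp operation itself rather than on a uniqued FlatSymbolRefAttr. A standalone sketch of what keying state on an anchor means in practice (the Solver and Operation types here are illustrative stand-ins, not MLIR's API):

```cpp
#include <cassert>
#include <map>

struct Operation {};               // stand-in for mlir::Operation
struct SafetyState { bool isSafe = true; };

// One state object per anchor: the anchor's identity is the operation
// pointer, so every query for the same GlobalSlotOp sees the same state.
struct Solver {
  std::map<const Operation *, SafetyState> states;
  SafetyState &getOrCreate(const Operation *anchor) { return states[anchor]; }
};

int main() {
  Solver solver;
  Operation globalSlot;
  solver.getOrCreate(&globalSlot).isSafe = false;
  assert(!solver.getOrCreate(&globalSlot).isSafe); // same anchor, same state
}
```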
@@ -84,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
 /// State tracking if an IR construct is "safe".
 ///
 /// This state is tracked on Value's and also on global slots (via a
-/// FlatSymbolRefProgramPoint).
+/// FlatSymbolRefLatticeAnchor).
///
/// In this context, "safe" means that the object is safe to inline.
/// This covers a few concepts
@@ -93,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
 /// unsafe
 class InlineGlobalSlotsAnalysisState : public AnalysisState {
 public:
-  InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
+  InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) {
     (void)setSafe();
   }

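The constructor starts every state as safe, and the analysis can only revoke safety. A minimal sketch of that lattice, assuming setSafe joins with logical AND and reports whether the state changed (mirroring MLIR's ChangeResult contract; the bodies here are guesses, not the pass's actual code):

```cpp
#include <cassert>

enum class ChangeResult { NoChange, Change };

struct SafetyLattice {
  bool isSafe = true;

  // Monotone update: the state may fall from safe to unsafe but never
  // rise back, which is what guarantees the fixpoint terminates.
  ChangeResult setSafe(bool safe = true) {
    bool joined = isSafe && safe;
    if (joined == isSafe)
      return ChangeResult::NoChange;
    isSafe = joined;
    return ChangeResult::Change;
  }

  // A slot is only as safe as each use of its value.
  ChangeResult incorporateSafetyOfUse(const SafetyLattice &use) {
    return setSafe(use.isSafe);
  }
};

int main() {
  SafetyLattice s;
  assert(s.incorporateSafetyOfUse(SafetyLattice{false}) == ChangeResult::Change);
  assert(s.setSafe(true) == ChangeResult::NoChange); // cannot become safe again
}
```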
@@ -147,33 +146,33 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis {

 InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
     : DataFlowAnalysis(solver) {
-  registerPointKind<FlatSymbolRefProgramPoint>();
+  registerAnchorKind<FlatSymbolRefLatticeAnchor>();
 }

 LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
   auto walkResult = top->walk([this](Operation *op) {
     if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
       auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
-          getProgramPoint<FlatSymbolRefProgramPoint>(
-              FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
+          getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
       propagateIfChanged(state,
                          state->setSafe(globalSlot.getVisibility() !=
                                         SymbolTable::Visibility::Public));
     }
     if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
+      auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
+          globalSlotSet, globalSlotSet.getSlotAttr());
+
       auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
-          getProgramPoint<FlatSymbolRefProgramPoint>(
-              globalSlotSet.getSlotAttr()));
+          getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
       propagateIfChanged(state, state->setSafe(false));
     }
     // Save the InitializeGlobalSlotsOp for later reference.
     if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
       initializeGlobalSlotsOp = initialize;
     }
-    for (Value result : op->getResults()) {
-      if (failed(visit(result)))
-        return WalkResult::interrupt();
-    }
+    if (failed(visit(op)))
+      return WalkResult::interrupt();

     return WalkResult::advance();
   });
   if (walkResult.wasInterrupted())
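initialize() seeds states while walking the module, and the solver only re-runs visit() for states that actually changed. A framework-free sketch of the propagateIfChanged contract that drives this (hypothetical Driver type, not MLIR's DataFlowSolver):

```cpp
#include <deque>
#include <functional>

enum class ChangeResult { NoChange, Change };

// Worklist driver: work is re-enqueued only when a state changed, so the
// loop drains once all states are stable (the fixpoint).
struct Driver {
  std::deque<std::function<void()>> worklist;

  void propagateIfChanged(ChangeResult result, std::function<void()> revisit) {
    if (result == ChangeResult::Change)
      worklist.push_back(std::move(revisit));
  }

  void run() {
    while (!worklist.empty()) {
      auto work = std::move(worklist.front());
      worklist.pop_front();
      work(); // may enqueue more work via propagateIfChanged
    }
  }
};

int main() {
  Driver driver;
  driver.propagateIfChanged(ChangeResult::Change, [] { /* revisit users */ });
  driver.run();
}
```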
@@ -182,50 +181,32 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
 }

 LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
-  if (Value value = dyn_cast<Value>(point)) {
-    bool isSafe = isValueSafeTransferFunction(value);
-    auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
-    propagateIfChanged(state, state->setSafe(isSafe));
-
-    // Handle GlobalSlotGetOp's.
-    if (auto opResult = dyn_cast<OpResult>(value)) {
-      if (auto globalSlotGet =
-              dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
-        auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
-            globalSlotGet.getSlotAttr());
-        auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
-            flatSymbolRefPoint, globalSlotGet.getResult());
-        auto *globalState =
-            getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
-        propagateIfChanged(globalState,
-                           globalState->incorporateSafetyOfUse(valueState));
-      }
-    }
-
-    return success();
-  }
-  if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
-    if (auto *flatSymbolRefPoint =
-            dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
-      if (initializeGlobalSlotsOp) {
-        auto it =
-            llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
-                       static_cast<Attribute>(flatSymbolRefPoint->getValue()));
-        Value value = initializeGlobalSlotsOp->getOperand(std::distance(
-            initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
-        auto *flatSymbolRefState =
-            getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
-                                                           flatSymbolRefPoint);
-        auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
-        propagateIfChanged(valueState,
-                           valueState->setSafe(flatSymbolRefState->isSafe));
-      }
-      return success();
-    }
-  }
-  LLVM_DEBUG(
-      { llvm::dbgs() << "visit failing because of: " << point << "\n"; });
-  return failure();
+  if (auto op = dyn_cast<Operation *>(point)) {
+    for (auto value : op->getResults()) {
+      bool isSafe = isValueSafeTransferFunction(value);
+      auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
+      propagateIfChanged(state, state->setSafe(isSafe));
+
+      // Handle GlobalSlotGetOp's.
+      if (auto opResult = dyn_cast<OpResult>(value)) {
+        if (auto globalSlotGet =
+                dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
+          auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
+              globalSlotGet, globalSlotGet.getSlotAttr());
+          auto *flatSymbolRefPoint =
+              getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
+          auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
+              globalSlot, globalSlotGet.getResult());
+          auto *globalState =
+              getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
+          propagateIfChanged(globalState,
+                             globalState->incorporateSafetyOfUse(valueState));
+        }
+      }
+    }
+  }
+
+  return success();
 }
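The structural change in visit(): it no longer receives a Value directly; it receives a program point that may hold an Operation *, and the per-Value transfer runs in a loop over the op's results. A standalone sketch of that dispatch shape, using std::variant as a stand-in for MLIR's ProgramPoint union:

```cpp
#include <variant>
#include <vector>

struct Value {};
struct Operation { std::vector<Value> results; };

// Stand-in for ProgramPoint: visit() dispatches on what the point holds.
using ProgramPoint = std::variant<Operation *, Value *>;

bool isValueSafeTransferFunction(const Value &) { return true; }

bool visit(ProgramPoint point) {
  if (auto *op = std::get_if<Operation *>(&point)) {
    // Previously each Value was its own work item; now all results of
    // the visited operation go through the same transfer function.
    for (Value &result : (*op)->results)
      (void)isValueSafeTransferFunction(result);
  }
  return true; // unhandled point kinds are no longer an error
}

int main() {
  Operation op{{Value{}, Value{}}};
  return visit(&op) ? 0 : 1;
}
```

Note the old failure path (LLVM_DEBUG plus return failure()) disappears: a point the analysis does not care about is simply ignored rather than treated as an error.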

// This is only a member function to access protected get* functions.
@@ -241,16 +222,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
     // safe. This covers, for example, view-like ops that create aliases.
     if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
         llvm::all_of(op->getResults(), [&](Value result) {
-          auto *state =
-              getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
+          auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
+              value.getDefiningOp(), result);
           return state->isSafe;
         }))
       continue;
     if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
       auto symName = cast<FlatSymbolRefAttr>(
           initialize.getSlotSymNames()[use.getOperandNumber()]);
+      auto globalSlot =
+          SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
+
       auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
-          value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
+          value.getDefiningOp(),
+          getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
       if (state->isSafe)
         continue;
     }
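Restated, the transfer function says: a value remains safe only if every use either feeds an op that cannot mutate memory and whose own results are all safe, or initializes a slot that the analysis has already proven safe. A framework-free restatement under those assumptions (all types here are hypothetical):

```cpp
#include <vector>

// Hypothetical stand-ins: each use records what the analysis needs to
// know about its user operation.
struct Use {
  bool userIsReadOnlyOrEffectFree;     // e.g. view-like ops creating aliases
  bool initializesSafeSlot;            // safe InitializeGlobalSlotsOp operand
  std::vector<bool> userResultSafety;  // safety of the user op's results
};
struct Value { std::vector<Use> uses; };

bool isValueSafe(const Value &value) {
  for (const Use &use : value.uses) {
    bool allResultsSafe = true;
    for (bool safe : use.userResultSafety)
      allResultsSafe = allResultsSafe && safe;
    if (use.userIsReadOnlyOrEffectFree && allResultsSafe)
      continue; // this use cannot leak or mutate the value
    if (use.initializesSafeSlot)
      continue; // writing the initial value of a safe slot is fine
    return false; // any other use is conservatively unsafe
  }
  return true;
}

int main() {
  Value v{{Use{true, false, {true}}}};
  return isValueSafe(v) ? 0 : 1;
}
```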
@@ -299,8 +284,7 @@ class InlineGlobalSlotsPass
     module->walk([&](Operation *op) {
       if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
         auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
-            solver.getProgramPoint<FlatSymbolRefProgramPoint>(
-                FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
+            solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
         state->print(llvm::dbgs());
         llvm::dbgs() << ": "
                      << FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())
@@ -334,13 +318,16 @@ class InlineGlobalSlotsPass
       auto slotSymName =
           cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
       Value operand = initialize.getOperand(i);
-      auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
-          cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
+      auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
+          initialize, slotSymName);
+
+      auto symbolRefPoint =
+          solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
       auto *state =
           solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
       // We roll the analysis of whether a slot is set or public into the
       // main dataflow analysis, so we need to check the slot's
-      // FlatSymbolRefProgramPoint itself to see if it is safe to inline.
+      // FlatSymbolRefLatticeAnchor itself to see if it is safe to inline.
       // For example, a public !torch.int is not safe to inline, even though
       // it is a value-semantic type and so the actual initializer value
       // itself is conceptually safe to inline.
9 changes: 0 additions & 9 deletions test/Conversion/TorchToStablehlo/linear.mlir
@@ -259,7 +259,6 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK: %[[T_5:.*]] = torch.constant.int 1
// CHECK: %[[T_6:.*]] = torch.constant.int 4
// CHECK: %[[T_7:.*]] = torch.constant.int 3
-// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
@@ -295,7 +294,6 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int4 = torch.constant.int 4
-// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -336,7 +334,6 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
@@ -367,7 +364,6 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -402,7 +398,6 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -438,10 +433,6 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
-// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
// CHECK: %c0 = arith.constant 0 : index