From 658f05eb1ef497f8b6160de70e86685637d086e5 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sat, 21 Dec 2024 11:21:55 +0100 Subject: [PATCH 1/4] Fix variable accesses when generating RK equation functions --- lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp b/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp index 188c09ee9..65f95440b 100644 --- a/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp +++ b/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp @@ -974,16 +974,18 @@ FunctionOp RungeKuttaPass::createEquationFunction( rewriter.setInsertionPointToStart(algorithmOp.getBody()); mlir::IRMapping mapping; + llvm::DenseSet mappedInductions; // Get the values of the induction variables. auto originalInductions = explicitEquationOp.getInductionVariables(); for (size_t i = 0, e = originalInductions.size(); i < e; ++i) { - mlir::Value mappedInduction = rewriter.create( + auto mappedInduction = rewriter.create( inductionVariablesOps[i].getLoc(), inductionVariablesOps[i].getVariableType().unwrap(), inductionVariablesOps[i].getSymName()); + mappedInductions.insert(mappedInduction); mapping.map(originalInductions[i], mappedInduction); } @@ -1050,6 +1052,12 @@ FunctionOp RungeKuttaPass::createEquationFunction( } for (VariableGetOp variableGetOp : variableGetOps) { + if (mappedInductions.contains(variableGetOp)) { + // Skip the variables that have been introduced to map the original + // inductions. + continue; + } + rewriter.setInsertionPoint(variableGetOp); if (variableGetOp.getVariable() == mappedStateVariableOp.getSymName()) { @@ -1059,6 +1067,7 @@ FunctionOp RungeKuttaPass::createEquationFunction( VariableOp variableOp = symbolTableCollection.lookupSymbolIn( modelOp, variableGetOp.getVariableAttr()); + assert(variableOp && "Variable not found"); auto futureVariableIt = futureVariables.find(variableOp); if (futureVariableIt == futureVariables.end()) { From fad7ab0e6e2585883d822f2c6033ed6f785c6fea Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sat, 21 Dec 2024 11:40:48 +0100 Subject: [PATCH 2/4] Avoid duplicated function names --- lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp b/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp index 65f95440b..a1f0e319c 100644 --- a/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp +++ b/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp @@ -934,6 +934,7 @@ FunctionOp RungeKuttaPass::createEquationFunction( auto functionOp = rewriter.create(explicitEquationOp.getLoc(), "rk_eq"); + symbolTableCollection.getSymbolTable(moduleOp).insert(functionOp); rewriter.createBlock(&functionOp.getBodyRegion()); // Declare the variables. From 9b5428f37c93e462147c160d10851108ca121394 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sat, 21 Dec 2024 13:23:03 +0100 Subject: [PATCH 3/4] Add expression interface to TensorInsertOp --- .../EquationExpressionOpInterfaceImpl.cpp | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp b/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp index a98722e94..6e719a6b4 100644 --- a/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp +++ b/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp @@ -718,6 +718,29 @@ struct TensorExtractOpInterface } }; +struct TensorInsertOpInterface + : public EquationExpressionOpInterface::ExternalModel< + ::TensorInsertOpInterface, TensorInsertOp> { + void printExpression( + mlir::Operation *op, llvm::raw_ostream &os, + const llvm::DenseMap &inductions) const { + auto castedOp = mlir::cast(op); + + ::printExpression(os, castedOp.getValue(), inductions); + os << " into "; + ::printExpression(os, castedOp.getDestination(), inductions); + os << "["; + + llvm::interleaveComma(castedOp.getIndices(), os, [&](mlir::Value exp) { + ::printExpression(os, exp, inductions); + }); + + os << "]"; + } + + DEFINE_DEFAULT_IS_EQUIVALENT(TensorInsertOp) +}; + struct ArrayFromElementsOpInterface : public EquationExpressionOpInterface::ExternalModel< ::ArrayFromElementsOpInterface, ArrayFromElementsOp> { @@ -1980,6 +2003,7 @@ void registerEquationExpressionOpInterfaceExternalModels( TensorBroadcastOp::attachInterface<::TensorBroadcastOpInterface>(*context); TensorViewOp::attachInterface<::TensorViewOpInterface>(*context); TensorExtractOp::attachInterface<::TensorExtractOpInterface>(*context); + TensorInsertOp::attachInterface<::TensorInsertOpInterface>(*context); // Array operations. ArrayFromElementsOp::attachInterface<::ArrayFromElementsOpInterface>(*context); From 77488e67b06a6983c92cdec80569d455a9038b17 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sat, 21 Dec 2024 15:28:21 +0100 Subject: [PATCH 4/4] Fix variable type for additional Runge-Kutta variables --- lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp b/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp index a1f0e319c..9a1e0ef4a 100644 --- a/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp +++ b/lib/Dialect/BaseModelica/Transforms/RungeKutta.cpp @@ -795,7 +795,8 @@ VariableOp declareSlopeVariable(mlir::OpBuilder &builder, VariableOp variableOp, "__rk_k" + std::to_string(order) + "_" + variableOp.getSymName().str(); auto variableType = - VariableType::get(std::nullopt, RealType::get(builder.getContext()), + VariableType::get(variableOp.getVariableType().getShape(), + RealType::get(builder.getContext()), VariabilityProperty::none, IOProperty::none); return builder.create(variableOp.getLoc(), name, variableType); @@ -806,7 +807,8 @@ VariableOp declareErrorVariable(mlir::OpBuilder &builder, std::string name = "__rk_e_" + variableOp.getSymName().str(); auto variableType = - VariableType::get(std::nullopt, RealType::get(builder.getContext()), + VariableType::get(variableOp.getVariableType().getShape(), + RealType::get(builder.getContext()), VariabilityProperty::none, IOProperty::none); return builder.create(variableOp.getLoc(), name, variableType); @@ -1293,6 +1295,7 @@ mlir::LogicalResult createSlopeEquation( } callArgs.push_back(secondArg); + llvm::append_range(callArgs, templateOp.getInductionVariables()); auto callOp = rewriter.create(loc, eqRhsFunc.functionOp, callArgs); mlir::Value rhs = callOp.getResult(0);