From 3ac18992f345d9d028fcc15bdc815b80d75ca993 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Wed, 6 Nov 2024 16:21:50 +0100 Subject: [PATCH] Minor code fixes & code formatting --- .../marco/Modeling/AccessFunctionEmpty.h | 67 +- .../marco/Modeling/DimensionAccessConstant.h | 92 +- .../marco/Modeling/DimensionAccessDimension.h | 92 +- .../marco/Modeling/DimensionAccessIndices.h | 98 +- include/public/marco/Modeling/Matching.h | 2790 ++- .../BaseModelicaToTensor.cpp | 880 +- lib/Dialect/BaseModelica/IR/Ops.cpp | 17312 ++++++++-------- lib/Dialect/BaseModelica/IR/Types.cpp | 894 +- .../BaseModelica/Transforms/EulerForward.cpp | 351 +- .../Transforms/FunctionInlining.cpp | 1175 +- lib/Dialect/BaseModelica/Transforms/IDA.cpp | 3 - .../Transforms/RecordInlining.cpp | 2026 +- .../Transforms/SCCSolvingBySubstitution.cpp | 3 +- .../Transforms/SCCSolvingWithKINSOL.cpp | 3 - lib/Frontend/CompilerInvocation.cpp | 32 +- lib/Modeling/AccessFunctionConstant.cpp | 2 +- lib/Modeling/DimensionAccess.cpp | 329 +- lib/Modeling/DimensionAccessConstant.cpp | 163 +- lib/Modeling/DimensionAccessDimension.cpp | 162 +- lib/Modeling/DimensionAccessIndices.cpp | 210 +- 20 files changed, 12289 insertions(+), 14395 deletions(-) diff --git a/include/public/marco/Modeling/AccessFunctionEmpty.h b/include/public/marco/Modeling/AccessFunctionEmpty.h index abf2d62a8..0c355abd2 100644 --- a/include/public/marco/Modeling/AccessFunctionEmpty.h +++ b/include/public/marco/Modeling/AccessFunctionEmpty.h @@ -3,60 +3,53 @@ #include "marco/Modeling/AccessFunctionAffineMap.h" -namespace marco::modeling -{ - class AccessFunctionEmpty : public AccessFunctionAffineMap - { - public: - static bool canBeBuilt( - uint64_t numOfDimensions, - llvm::ArrayRef> results); +namespace marco::modeling { +class AccessFunctionEmpty : public AccessFunctionAffineMap { +public: + static bool + canBeBuilt(uint64_t numOfDimensions, + llvm::ArrayRef> results); - static bool canBeBuilt(mlir::AffineMap affineMap); + static bool canBeBuilt(mlir::AffineMap affineMap); - explicit AccessFunctionEmpty(mlir::AffineMap affineMap); + explicit AccessFunctionEmpty(mlir::AffineMap affineMap); - AccessFunctionEmpty( - mlir::MLIRContext* context, - uint64_t numOfDimensions, - llvm::ArrayRef> results); + AccessFunctionEmpty(mlir::MLIRContext *context, uint64_t numOfDimensions, + llvm::ArrayRef> results); - ~AccessFunctionEmpty() override; + ~AccessFunctionEmpty() override; - [[nodiscard]] std::unique_ptr clone() const override; + [[nodiscard]] std::unique_ptr clone() const override; - /// @name LLVM-style RTTI methods - /// { + /// @name LLVM-style RTTI methods + /// { - static bool classof(const AccessFunction* obj) - { - return obj->getKind() == Empty; - } + static bool classof(const AccessFunction *obj) { + return obj->getKind() == Empty; + } - /// } + /// } - [[nodiscard]] bool operator==( - const AccessFunction& other) const override; + [[nodiscard]] bool operator==(const AccessFunction &other) const override; - [[nodiscard]] bool operator==(const AccessFunctionEmpty& other) const; + [[nodiscard]] bool operator==(const AccessFunctionEmpty &other) const; - [[nodiscard]] bool operator!=( - const AccessFunction& other) const override; + [[nodiscard]] bool operator!=(const AccessFunction &other) const override; - [[nodiscard]] bool operator!=(const AccessFunctionEmpty& other) const; + [[nodiscard]] bool operator!=(const AccessFunctionEmpty &other) const; - [[nodiscard]] bool isInvertible() const override; + [[nodiscard]] bool isInvertible() const override; - 
[[nodiscard]] std::unique_ptr inverse() const override; + [[nodiscard]] std::unique_ptr inverse() const override; - [[nodiscard]] IndexSet map(const Point& point) const override; + [[nodiscard]] IndexSet map(const Point &point) const override; - [[nodiscard]] IndexSet map(const IndexSet& indices) const override; + [[nodiscard]] IndexSet map(const IndexSet &indices) const override; - [[nodiscard]] IndexSet inverseMap( - const IndexSet& accessedIndices, - const IndexSet& parentIndices) const override; + [[nodiscard]] IndexSet + inverseMap(const IndexSet &accessedIndices, + const IndexSet &parentIndices) const override; }; -} +} // namespace marco::modeling #endif // MARCO_MODELING_ACCESSFUNCTIONEMPTY_H diff --git a/include/public/marco/Modeling/DimensionAccessConstant.h b/include/public/marco/Modeling/DimensionAccessConstant.h index 80d2a7cfa..27f3c02aa 100644 --- a/include/public/marco/Modeling/DimensionAccessConstant.h +++ b/include/public/marco/Modeling/DimensionAccessConstant.h @@ -3,78 +3,54 @@ #include "marco/Modeling/DimensionAccess.h" -namespace marco::modeling -{ - class DimensionAccessConstant : public DimensionAccess - { - public: - DimensionAccessConstant(mlir::MLIRContext* context, int64_t value); +namespace marco::modeling { +class DimensionAccessConstant : public DimensionAccess { +public: + DimensionAccessConstant(mlir::MLIRContext *context, int64_t value); - DimensionAccessConstant(const DimensionAccessConstant& other); + static bool classof(const DimensionAccess *obj) { + return obj->getKind() == DimensionAccess::Kind::Constant; + } - DimensionAccessConstant(DimensionAccessConstant&& other) noexcept; + [[nodiscard]] std::unique_ptr clone() const override; - ~DimensionAccessConstant() override; + [[nodiscard]] bool operator==(const DimensionAccess &other) const override; - DimensionAccessConstant& operator=(const DimensionAccessConstant& other); + [[nodiscard]] bool operator==(const DimensionAccessConstant &other) const; - DimensionAccessConstant& operator=( - DimensionAccessConstant&& other) noexcept; + [[nodiscard]] bool operator!=(const DimensionAccess &other) const override; - friend void swap( - DimensionAccessConstant& first, DimensionAccessConstant& second); + [[nodiscard]] bool operator!=(const DimensionAccessConstant &other) const; - static bool classof(const DimensionAccess* obj) - { - return obj->getKind() == DimensionAccess::Kind::Constant; - } + llvm::raw_ostream &dump(llvm::raw_ostream &os, + const llvm::DenseMap + &iterationSpacesIds) const override; - [[nodiscard]] std::unique_ptr clone() const override; + void collectIterationSpaces( + llvm::DenseSet &iterationSpaces) const override; - [[nodiscard]] bool operator==( - const DimensionAccess& other) const override; + void collectIterationSpaces( + llvm::SmallVectorImpl &iterationSpaces, + llvm::DenseMap> + &dependentDimensions) const override; - [[nodiscard]] bool operator==( - const DimensionAccessConstant& other) const; + [[nodiscard]] bool isAffine() const override; - [[nodiscard]] bool operator!=( - const DimensionAccess& other) const override; + [[nodiscard]] mlir::AffineExpr getAffineExpr() const override; - [[nodiscard]] bool operator!=( - const DimensionAccessConstant& other) const; + [[nodiscard]] mlir::AffineExpr + getAffineExpr(unsigned int numOfDimensions, + FakeDimensionsMap &fakeDimensionsMap) const override; - llvm::raw_ostream& dump( - llvm::raw_ostream& os, - const llvm::DenseMap& iterationSpacesIds) - const override; + [[nodiscard]] IndexSet map(const Point &point, + llvm::DenseMap + 
&currentIndexSetsPoint) const override;

-      void collectIterationSpaces(
-          llvm::DenseSet& iterationSpaces) const override;
+  [[nodiscard]] int64_t getValue() const;

-      void collectIterationSpaces(
-          llvm::SmallVectorImpl& iterationSpaces,
-          llvm::DenseMap<
-              const IndexSet*,
-              llvm::DenseSet>& dependentDimensions) const override;
-
-      [[nodiscard]] bool isAffine() const override;
-
-      [[nodiscard]] mlir::AffineExpr getAffineExpr() const override;
-
-      [[nodiscard]] mlir::AffineExpr getAffineExpr(
-          unsigned int numOfDimensions,
-          FakeDimensionsMap& fakeDimensionsMap) const override;
-
-      [[nodiscard]] IndexSet map(
-          const Point& point,
-          llvm::DenseMap<
-              const IndexSet*, Point>& currentIndexSetsPoint) const override;
-
-      [[nodiscard]] int64_t getValue() const;
-
-    private:
-      int64_t value;
-  };
-}
+private:
+  int64_t value;
+};
+} // namespace marco::modeling

 #endif // MARCO_MODELING_DIMENSIONACCESSCONSTANT_H
diff --git a/include/public/marco/Modeling/DimensionAccessDimension.h b/include/public/marco/Modeling/DimensionAccessDimension.h
index e694292fc..821dbef77 100644
--- a/include/public/marco/Modeling/DimensionAccessDimension.h
+++ b/include/public/marco/Modeling/DimensionAccessDimension.h
@@ -3,78 +3,54 @@

 #include "marco/Modeling/DimensionAccess.h"

-namespace marco::modeling
-{
-  class DimensionAccessDimension : public DimensionAccess
-  {
-    public:
-      DimensionAccessDimension(mlir::MLIRContext* context, uint64_t dimension);
+namespace marco::modeling {
+class DimensionAccessDimension : public DimensionAccess {
+public:
+  DimensionAccessDimension(mlir::MLIRContext *context, uint64_t dimension);

-      DimensionAccessDimension(const DimensionAccessDimension& other);
+  static bool classof(const DimensionAccess *obj) {
+    return obj->getKind() == DimensionAccess::Kind::Dimension;
+  }

-      DimensionAccessDimension(DimensionAccessDimension&& other) noexcept;
+  [[nodiscard]] std::unique_ptr clone() const override;

-      ~DimensionAccessDimension() override;
+  [[nodiscard]] bool operator==(const DimensionAccess &other) const override;

-      DimensionAccessDimension& operator=(const DimensionAccessDimension& other);
+  [[nodiscard]] bool operator==(const DimensionAccessDimension &other) const;

-      DimensionAccessDimension& operator=(
-          DimensionAccessDimension&& other) noexcept;
+  [[nodiscard]] bool operator!=(const DimensionAccess &other) const override;

-      friend void swap(
-          DimensionAccessDimension& first, DimensionAccessDimension& second);
+  [[nodiscard]] bool operator!=(const DimensionAccessDimension &other) const;

-      static bool classof(const DimensionAccess* obj)
-      {
-        return obj->getKind() == DimensionAccess::Kind::Dimension;
-      }
+  llvm::raw_ostream &dump(llvm::raw_ostream &os,
+                          const llvm::DenseMap
+                              &iterationSpacesIds) const override;

-      [[nodiscard]] std::unique_ptr clone() const override;
+  void collectIterationSpaces(
+      llvm::DenseSet &iterationSpaces) const override;

-      [[nodiscard]] bool operator==(
-          const DimensionAccess& other) const override;
+  void collectIterationSpaces(
+      llvm::SmallVectorImpl &iterationSpaces,
+      llvm::DenseMap>
+          &dependentDimensions) const override;

-      [[nodiscard]] bool operator==(
-          const DimensionAccessDimension& other) const;
+  [[nodiscard]] bool isAffine() const override;

-      [[nodiscard]] bool operator!=(
-          const DimensionAccess& other) const override;
+  [[nodiscard]] mlir::AffineExpr getAffineExpr() const override;

-      [[nodiscard]] bool operator!=(
-          const DimensionAccessDimension& other) const;
+  [[nodiscard]] mlir::AffineExpr
+  getAffineExpr(unsigned int numOfDimensions,
FakeDimensionsMap &fakeDimensionsMap) const override; - llvm::raw_ostream& dump( - llvm::raw_ostream& os, - const llvm::DenseMap& iterationSpacesIds) - const override; + [[nodiscard]] IndexSet map(const Point &point, + llvm::DenseMap + ¤tIndexSetsPoint) const override; - void collectIterationSpaces( - llvm::DenseSet& iterationSpaces) const override; + [[nodiscard]] uint64_t getDimension() const; - void collectIterationSpaces( - llvm::SmallVectorImpl& iterationSpaces, - llvm::DenseMap< - const IndexSet*, - llvm::DenseSet>& dependentDimensions) const override; - - [[nodiscard]] bool isAffine() const override; - - [[nodiscard]] mlir::AffineExpr getAffineExpr() const override; - - [[nodiscard]] mlir::AffineExpr getAffineExpr( - unsigned int numOfDimensions, - FakeDimensionsMap& fakeDimensionsMap) const override; - - [[nodiscard]] IndexSet map( - const Point& point, - llvm::DenseMap< - const IndexSet*, Point>& currentIndexSetsPoint) const override; - - [[nodiscard]] uint64_t getDimension() const; - - private: - uint64_t dimension; - }; -} +private: + uint64_t dimension; +}; +} // namespace marco::modeling #endif // MARCO_MODELING_DIMENSIONACCESSDIMENSION_H diff --git a/include/public/marco/Modeling/DimensionAccessIndices.h b/include/public/marco/Modeling/DimensionAccessIndices.h index 1291cb570..934673065 100644 --- a/include/public/marco/Modeling/DimensionAccessIndices.h +++ b/include/public/marco/Modeling/DimensionAccessIndices.h @@ -5,80 +5,56 @@ #include "marco/Modeling/IndexSet.h" #include "llvm/ADT/DenseSet.h" -namespace marco::modeling -{ - class DimensionAccessIndices : public DimensionAccess - { - public: - DimensionAccessIndices( - mlir::MLIRContext* context, - std::shared_ptr space, - uint64_t dimension, - llvm::DenseSet dimensionDependencies); +namespace marco::modeling { +class DimensionAccessIndices : public DimensionAccess { +public: + DimensionAccessIndices(mlir::MLIRContext *context, + std::shared_ptr space, uint64_t dimension, + llvm::DenseSet dimensionDependencies); - DimensionAccessIndices(const DimensionAccessIndices& other); + static bool classof(const DimensionAccess *obj) { + return obj->getKind() == DimensionAccess::Kind::Indices; + } - DimensionAccessIndices(DimensionAccessIndices&& other) noexcept; + [[nodiscard]] std::unique_ptr clone() const override; - ~DimensionAccessIndices() override; + [[nodiscard]] bool operator==(const DimensionAccess &other) const override; - DimensionAccessIndices& operator=(const DimensionAccessIndices& other); + [[nodiscard]] bool operator==(const DimensionAccessIndices &other) const; - DimensionAccessIndices& operator=( - DimensionAccessIndices&& other) noexcept; + [[nodiscard]] bool operator!=(const DimensionAccess &other) const override; - friend void swap( - DimensionAccessIndices& first, DimensionAccessIndices& second); + [[nodiscard]] bool operator!=(const DimensionAccessIndices &other) const; - static bool classof(const DimensionAccess* obj) - { - return obj->getKind() == DimensionAccess::Kind::Indices; - } + llvm::raw_ostream &dump(llvm::raw_ostream &os, + const llvm::DenseMap + &iterationSpacesIds) const override; - [[nodiscard]] std::unique_ptr clone() const override; + void collectIterationSpaces( + llvm::DenseSet &iterationSpaces) const override; - [[nodiscard]] bool operator==( - const DimensionAccess& other) const override; + void collectIterationSpaces( + llvm::SmallVectorImpl &iterationSpaces, + llvm::DenseMap> + &dependentDimensions) const override; - [[nodiscard]] bool operator==(const DimensionAccessIndices& other) 
const; + [[nodiscard]] mlir::AffineExpr + getAffineExpr(unsigned int numOfDimensions, + FakeDimensionsMap &fakeDimensionsMap) const override; - [[nodiscard]] bool operator!=( - const DimensionAccess& other) const override; + [[nodiscard]] IndexSet map(const Point &point, + llvm::DenseMap + ¤tIndexSetsPoint) const override; - [[nodiscard]] bool operator!=(const DimensionAccessIndices& other) const; + [[nodiscard]] IndexSet &getIndices(); - llvm::raw_ostream& dump( - llvm::raw_ostream& os, - const llvm::DenseMap& iterationSpacesIds) - const override; + [[nodiscard]] const IndexSet &getIndices() const; - void collectIterationSpaces( - llvm::DenseSet& iterationSpaces) const override; - - void collectIterationSpaces( - llvm::SmallVectorImpl& iterationSpaces, - llvm::DenseMap< - const IndexSet*, - llvm::DenseSet>& dependentDimensions) const override; - - [[nodiscard]] mlir::AffineExpr getAffineExpr( - unsigned int numOfDimensions, - FakeDimensionsMap& fakeDimensionsMap) const override; - - [[nodiscard]] IndexSet map( - const Point& point, - llvm::DenseMap< - const IndexSet*, Point>& currentIndexSetsPoint) const override; - - [[nodiscard]] IndexSet& getIndices(); - - [[nodiscard]] const IndexSet& getIndices() const; - - private: - std::shared_ptr space; - uint64_t dimension; - llvm::DenseSet dimensionDependencies; - }; -} +private: + std::shared_ptr space; + uint64_t dimension; + llvm::DenseSet dimensionDependencies; +}; +} // namespace marco::modeling #endif // MARCO_MODELING_DIMENSIONACCESSINDICES_H diff --git a/include/public/marco/Modeling/Matching.h b/include/public/marco/Modeling/Matching.h index 30631dac8..b35d60407 100644 --- a/include/public/marco/Modeling/Matching.h +++ b/include/public/marco/Modeling/Matching.h @@ -1,18 +1,22 @@ #ifndef MARCO_MODELING_MATCHING_H #define MARCO_MODELING_MATCHING_H -#include "marco/Modeling/TreeOStream.h" +#ifndef DEBUG_TYPE +#define DEBUG_TYPE "matching" +#endif + #include "marco/Modeling/AccessFunction.h" #include "marco/Modeling/Dumpable.h" #include "marco/Modeling/Graph.h" #include "marco/Modeling/LocalMatchingSolutions.h" #include "marco/Modeling/MCIM.h" #include "marco/Modeling/Range.h" +#include "marco/Modeling/TreeOStream.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Threading.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -24,1840 +28,1590 @@ #include #include -#define DEBUG_TYPE "matching" +namespace marco::modeling { +namespace matching { +// This class must be specialized for the variable type that is used during +// the matching process. +template +struct VariableTraits { + // Elements to provide: + // + // typedef Id : the ID type of the variable. + // + // static Id getId(const VariableType*) + // return the ID of the variable. + // + // static size_t getRank(const VariableType*) + // return the number of dimensions. + // + // static IndexSet getIndices(const VariableType*) + // return the indices of a variable. + // + // static llvm::raw_ostream& dump(const VariableType*, llvm::raw_ostream&) + // print debug information. + + using Id = typename VariableType::UnknownVariableTypeError; +}; + +// This class must be specialized for the equation type that is used during +// the matching process. +template +struct EquationTraits { + // Elements to provide: + // + // typedef Id : the ID type of the equation. 
+ // + // static Id getId(const EquationType*) + // return the ID of the equation. + // + // static size_t getNumOfIterationVars(const EquationType*) + // return the number of induction variables. + // + // static MultidimensionalRange getIterationRanges(const EquationType*) + // return the iteration ranges. + // + // typedef VariableType : the type of the accessed variable + // + // typedef AccessProperty : the access property (this is optional, and if + // not specified an empty one is used) + // + // static std::vector> + // getAccesses(const EquationType*) + // return the accesses done by the equation. + // + // static llvm::raw_ostream& dump(const EquationType*, llvm::raw_ostream&) + // print debug information. + + using Id = typename EquationType::UnknownEquationTypeError; +}; +} // namespace matching + +namespace internal { +namespace matching { +/// Represent a generic vectorized entity whose scalar elements +/// can be matched with the scalar elements of other arrays. +/// The relationship is tracked by means of an incidence matrix. +class Matchable { +public: + explicit Matchable(IndexSet matchableIndices); + + const IndexSet &getMatched() const; + + IndexSet getUnmatched() const; + + /// Check whether all the scalar elements of this array have been matched. + bool allComponentsMatched() const; + + void addMatch(const IndexSet &newMatch); + + void removeMatch(const IndexSet &removedMatch); + +private: + IndexSet matchableIndices; + IndexSet match; +}; + +/// Graph node representing a variable. +template +class VariableVertex : public Matchable, public Dumpable { +public: + using Property = VariableProperty; + using Traits = + typename ::marco::modeling::matching::VariableTraits; + using Id = typename Traits::Id; + + explicit VariableVertex(VariableProperty property) + : Matchable(getIndices(property)), property(property), visible(true) { + // Scalar variables can be represented by means of an array with just one + // element + assert(getRank() > 0 && "Scalar variables are not supported"); + } -namespace marco::modeling -{ - namespace matching - { - // This class must be specialized for the variable type that is used during - // the matching process. - template - struct VariableTraits - { - // Elements to provide: - // - // typedef Id : the ID type of the variable. - // - // static Id getId(const VariableType*) - // return the ID of the variable. - // - // static size_t getRank(const VariableType*) - // return the number of dimensions. - // - // static IndexSet getIndices(const VariableType*) - // return the indices of a variable. - // - // static llvm::raw_ostream& dump(const VariableType*, llvm::raw_ostream&) - // print debug information. - - using Id = typename VariableType::UnknownVariableTypeError; - }; + using Dumpable::dump; + + void dump(llvm::raw_ostream &os) const override { + os << "Variable\n"; + os << " - ID: " << getId() << "\n"; + os << " - Rank: " << getRank() << "\n"; + os << " - Indices: " << getIndices() << "\n"; + os << " - Matched: " << getMatched() << "\n"; + os << " - Details: "; + Traits::dump(&property, os); + os << "\n"; + } - // This class must be specialized for the equation type that is used during - // the matching process. - template - struct EquationTraits - { - // Elements to provide: - // - // typedef Id : the ID type of the equation. - // - // static Id getId(const EquationType*) - // return the ID of the equation. - // - // static size_t getNumOfIterationVars(const EquationType*) - // return the number of induction variables. 
- // - // static MultidimensionalRange getIterationRanges(const EquationType*) - // return the iteration ranges. - // - // typedef VariableType : the type of the accessed variable - // - // typedef AccessProperty : the access property (this is optional, and if - // not specified an empty one is used) - // - // static std::vector> - // getAccesses(const EquationType*) - // return the accesses done by the equation. - // - // static llvm::raw_ostream& dump(const EquationType*, llvm::raw_ostream&) - // print debug information. - - using Id = typename EquationType::UnknownEquationTypeError; - }; + VariableProperty &getProperty() { return property; } + + const VariableProperty &getProperty() const { return property; } + + Id getId() const { return Traits::getId(&property); } + + size_t getRank() const { + auto result = getRank(property); + assert(result > 0); + return result; } - namespace internal - { - namespace matching - { - /// Represent a generic vectorized entity whose scalar elements - /// can be matched with the scalar elements of other arrays. - /// The relationship is tracked by means of an incidence matrix. - class Matchable - { - public: - explicit Matchable(IndexSet matchableIndices); - - const IndexSet& getMatched() const; - - IndexSet getUnmatched() const; - - /// Check whether all the scalar elements of this array have been matched. - bool allComponentsMatched() const; - - void addMatch(const IndexSet& newMatch); - - void removeMatch(const IndexSet& removedMatch); - - private: - IndexSet matchableIndices; - IndexSet match; - }; - - /// Graph node representing a variable. - template - class VariableVertex : public Matchable, public Dumpable - { - public: - using Property = VariableProperty; - using Traits = typename ::marco::modeling::matching::VariableTraits; - using Id = typename Traits::Id; - - explicit VariableVertex(VariableProperty property) - : Matchable(getIndices(property)), - property(property), - visible(true) - { - // Scalar variables can be represented by means of an array with just one element - assert(getRank() > 0 && "Scalar variables are not supported"); - } + IndexSet getIndices() const { + auto result = getIndices(property); + assert(!result.empty()); + return result; + } - using Dumpable::dump; - - void dump(llvm::raw_ostream& os) const override - { - os << "Variable\n"; - os << " - ID: " << getId() << "\n"; - os << " - Rank: " << getRank() << "\n"; - os << " - Indices: " << getIndices() << "\n"; - os << " - Matched: " << getMatched() << "\n"; - os << " - Details: "; - Traits::dump(&property, os); - os << "\n"; - } + unsigned int flatSize() const { return getIndices().flatSize(); } - VariableProperty& getProperty() - { - return property; - } + bool isVisible() const { return visible; } - const VariableProperty& getProperty() const - { - return property; - } + void setVisibility(bool visibility) { visible = visibility; } - Id getId() const - { - return Traits::getId(&property); - } +private: + static size_t getRank(const VariableProperty &p) { + return Traits::getRank(&p); + } - size_t getRank() const - { - auto result = getRank(property); - assert(result > 0); - return result; - } + static IndexSet getIndices(const VariableProperty &p) { + return Traits::getIndices(&p); + } - IndexSet getIndices() const - { - auto result = getIndices(property); - assert(!result.empty()); - return result; - } + // Custom equation property + VariableProperty property; - unsigned int flatSize() const - { - return getIndices().flatSize(); - } + // Whether the node is visible or 
has been erased + bool visible; +}; - bool isVisible() const - { - return visible; - } +class EmptyAccessProperty {}; - void setVisibility(bool visibility) - { - visible = visibility; - } +template +struct get_access_property { + template + using Traits = ::marco::modeling::matching::EquationTraits; - private: - static size_t getRank(const VariableProperty& p) - { - return Traits::getRank(&p); - } + template ::AccessProperty> + static typename Traits::AccessProperty property(int); - static IndexSet getIndices(const VariableProperty& p) - { - return Traits::getIndices(&p); - } + template + static EmptyAccessProperty property(...); - // Custom equation property - VariableProperty property; + using type = decltype(property(0)); +}; +} // namespace matching +} // namespace internal - // Whether the node is visible or has been erased - bool visible; - }; +namespace matching { +template +class Access { +public: + using Property = AccessProperty; - class EmptyAccessProperty - { - }; + Access(const VariableProperty &variable, + std::unique_ptr accessFunction, + AccessProperty property = {}) + : variable(VariableTraits::getId(&variable)), + accessFunction(std::move(accessFunction)), + property(std::move(property)) {} - template - struct get_access_property - { - template - using Traits = ::marco::modeling::matching::EquationTraits; + Access(const Access &other) + : variable(other.variable), accessFunction(other.accessFunction->clone()), + property(other.property) {} - template::AccessProperty> - static typename Traits::AccessProperty property(int); + ~Access() = default; - template - static EmptyAccessProperty property(...); + typename VariableTraits::Id getVariable() const { + return variable; + } - using type = decltype(property(0)); - }; - } + const AccessFunction &getAccessFunction() const { + assert(accessFunction != nullptr); + return *accessFunction; } - namespace matching - { - template< - typename VariableProperty, - typename AccessProperty = internal::matching::EmptyAccessProperty> - class Access - { - public: - using Property = AccessProperty; - - Access( - const VariableProperty& variable, - std::unique_ptr accessFunction, - AccessProperty property = {}) - : variable(VariableTraits::getId(&variable)), - accessFunction(std::move(accessFunction)), - property(std::move(property)) - { - } + const AccessProperty &getProperty() const { return property; } + +private: + typename VariableTraits::Id variable; + std::unique_ptr accessFunction; + AccessProperty property; +}; +} // namespace matching + +namespace internal::matching { +template +void insertOrAdd(std::map &map, T key, IndexSet value) { + if (auto it = map.find(key); it != map.end()) { + it->second += std::move(value); + } else { + map.emplace(key, std::move(value)); + } +} - Access(const Access& other) - : variable(other.variable), - accessFunction(other.accessFunction->clone()), - property(other.property) - { - } +/// Graph node representing an equation. 
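
// A vertex of this kind resolves the equation's ID, iteration ranges, and
// variable accesses through the user-provided EquationTraits specialization
// documented above. As a purely illustrative sketch (SimpleEquation,
// SimpleVariable, and all of their members are hypothetical, not part of
// this patch), a client would specialize the traits roughly as follows:
//
//   namespace marco::modeling::matching {
//   template <>
//   struct EquationTraits<SimpleEquation> {
//     using Id = int64_t;
//     using VariableType = SimpleVariable;
//
//     static Id getId(const SimpleEquation *equation) {
//       return equation->id;
//     }
//
//     static size_t getNumOfIterationVars(const SimpleEquation *equation) {
//       return equation->inductions.size();
//     }
//
//     static MultidimensionalRange
//     getIterationRanges(const SimpleEquation *equation) {
//       return equation->iterationRanges;
//     }
//
//     static std::vector<Access<SimpleVariable>>
//     getAccesses(const SimpleEquation *equation) {
//       return equation->accesses;
//     }
//
//     static llvm::raw_ostream &dump(const SimpleEquation *equation,
//                                    llvm::raw_ostream &os) {
//       return os << equation->id;
//     }
//   };
//   } // namespace marco::modeling::matching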
+template +class EquationVertex : public Matchable, public Dumpable { +public: + using Property = EquationProperty; + using Traits = + typename ::marco::modeling::matching::EquationTraits; + using Id = typename Traits::Id; + + using Access = ::marco::modeling::matching::Access< + typename Traits::VariableType, + typename get_access_property::type>; + + EquationVertex(EquationProperty property) + : Matchable(IndexSet(getIterationRanges(property))), property(property), + visible(true) {} + + using Dumpable::dump; + + void dump(llvm::raw_ostream &os) const override { + os << "Equation\n"; + os << " - ID: " << getId() << "\n"; + os << " - Iteration ranges: " << getIterationRanges() << "\n"; + os << " - Matched: " << getMatched() << "\n"; + os << " - Details: "; + Traits::dump(&property, os); + os << "\n"; + } - ~Access() = default; + EquationProperty &getProperty() { return property; } - typename VariableTraits::Id getVariable() const - { - return variable; - } + const EquationProperty &getProperty() const { return property; } - const AccessFunction& getAccessFunction() const - { - assert(accessFunction != nullptr); - return *accessFunction; - } + Id getId() const { return Traits::getId(&property); } - const AccessProperty& getProperty() const - { - return property; - } + size_t getNumOfIterationVars() const { + auto result = getNumOfIterationVars(property); + assert(result > 0); + return result; + } - private: - typename VariableTraits::Id variable; - std::unique_ptr accessFunction; - AccessProperty property; - }; + IndexSet getIterationRanges() const { + auto result = getIterationRanges(property); + assert(!result.empty()); + return result; } - namespace internal::matching - { - template - void insertOrAdd(std::map& map, T key, IndexSet value) - { - if (auto it = map.find(key); it != map.end()) { - it->second += std::move(value); - } else { - map.emplace(key, std::move(value)); - } - } + unsigned int flatSize() const { return getIterationRanges().flatSize(); } - /// Graph node representing an equation. 
- template - class EquationVertex : public Matchable, public Dumpable - { - public: - using Property = EquationProperty; - using Traits = typename ::marco::modeling::matching::EquationTraits; - using Id = typename Traits::Id; - - using Access = ::marco::modeling::matching::Access< - typename Traits::VariableType, - typename get_access_property::type>; - - EquationVertex(EquationProperty property) - : Matchable(IndexSet(getIterationRanges(property))), - property(property), - visible(true) - { - } + std::vector getVariableAccesses() const { + return Traits::getAccesses(&property); + } - using Dumpable::dump; - - void dump(llvm::raw_ostream& os) const override - { - os << "Equation\n"; - os << " - ID: " << getId() << "\n"; - os << " - Iteration ranges: " << getIterationRanges() << "\n"; - os << " - Matched: " << getMatched() << "\n"; - os << " - Details: "; - Traits::dump(&property, os); - os << "\n"; - } + bool isVisible() const { return visible; } - EquationProperty& getProperty() - { - return property; - } + void setVisibility(bool visibility) { visible = visibility; } - const EquationProperty& getProperty() const - { - return property; - } +private: + static size_t getNumOfIterationVars(const EquationProperty &p) { + return Traits::getNumOfIterationVars(&p); + } - Id getId() const - { - return Traits::getId(&property); - } + static IndexSet getIterationRanges(const EquationProperty &p) { + return Traits::getIterationRanges(&p); + } - size_t getNumOfIterationVars() const - { - auto result = getNumOfIterationVars(property); - assert(result > 0); - return result; - } + // Custom equation property + EquationProperty property; + + // Whether the node is visible or has been erased + bool visible; +}; + +template +class Edge : public Dumpable { +public: + using AccessProperty = typename Equation::Access::Property; + + Edge(typename Equation::Id equation, typename Variable::Id variable, + IndexSet equationRanges, IndexSet variableRanges, + typename Equation::Access access) + : equation(std::move(equation)), variable(std::move(variable)), + accessFunction(access.getAccessFunction().clone()), + accessProperty(access.getProperty()), + incidenceMatrix(equationRanges, variableRanges), + matchMatrix(equationRanges, variableRanges), visible(true) { + incidenceMatrix.apply(getAccessFunction()); + } - IndexSet getIterationRanges() const - { - auto result = getIterationRanges(property); - assert(!result.empty()); - return result; - } + using Dumpable::dump; - unsigned int flatSize() const - { - return getIterationRanges().flatSize(); - } + void dump(llvm::raw_ostream &os) const override { + os << "Edge\n"; + os << " - Equation: " << equation << "\n"; + os << " - Variable: " << variable << "\n"; + os << " - Incidence matrix:\n" << incidenceMatrix << "\n"; + os << " - Matched equations: " << getMatched().flattenColumns() << "\n"; + os << " - Matched variables: " << getMatched().flattenRows() << "\n"; + os << " - Match matrix:\n" << getMatched() << "\n"; + } - std::vector getVariableAccesses() const - { - return Traits::getAccesses(&property); - } + const AccessFunction &getAccessFunction() const { return *accessFunction; } - bool isVisible() const - { - return visible; - } + const AccessProperty &getAccessProperty() const { return accessProperty; } - void setVisibility(bool visibility) - { - visible = visibility; - } + const MCIM &getIncidenceMatrix() const { return incidenceMatrix; } - private: - static size_t getNumOfIterationVars(const EquationProperty& p) - { - return Traits::getNumOfIterationVars(&p); - 
} + void addMatch(const MCIM &match) { matchMatrix += match; } - static IndexSet getIterationRanges(const EquationProperty& p) - { - return Traits::getIterationRanges(&p); - } + void removeMatch(const MCIM &match) { matchMatrix -= match; } - // Custom equation property - EquationProperty property; + const MCIM &getMatched() const { return matchMatrix; } - // Whether the node is visible or has been erased - bool visible; - }; + MCIM getUnmatched() const { return incidenceMatrix - matchMatrix; } - template - class Edge : public Dumpable - { - public: - using AccessProperty = typename Equation::Access::Property; - - Edge(typename Equation::Id equation, - typename Variable::Id variable, - IndexSet equationRanges, - IndexSet variableRanges, - typename Equation::Access access) - : equation(std::move(equation)), - variable(std::move(variable)), - accessFunction(access.getAccessFunction().clone()), - accessProperty(access.getProperty()), - incidenceMatrix(equationRanges, variableRanges), - matchMatrix(equationRanges, variableRanges), - visible(true) - { - incidenceMatrix.apply(getAccessFunction()); - } + bool isVisible() const { return visible; } - using Dumpable::dump; - - void dump(llvm::raw_ostream& os) const override - { - os << "Edge\n"; - os << " - Equation: " << equation << "\n"; - os << " - Variable: " << variable << "\n"; - os << " - Incidence matrix:\n" << incidenceMatrix << "\n"; - os << " - Matched equations: " << getMatched().flattenColumns() << "\n"; - os << " - Matched variables: " << getMatched().flattenRows() << "\n"; - os << " - Match matrix:\n" << getMatched() << "\n"; - } + void setVisibility(bool visibility) { visible = visibility; } - const AccessFunction& getAccessFunction() const - { - return *accessFunction; - } +private: + // Equation's ID. Just for debugging purpose + typename Equation::Id equation; - const AccessProperty& getAccessProperty() const - { - return accessProperty; - } + // Variable's ID. Just for debugging purpose + typename Variable::Id variable; - const MCIM& getIncidenceMatrix() const - { - return incidenceMatrix; - } + std::unique_ptr accessFunction; + AccessProperty accessProperty; + MCIM incidenceMatrix; + MCIM matchMatrix; - void addMatch(const MCIM& match) - { - matchMatrix += match; - } + bool visible; +}; - void removeMatch(const MCIM& match) - { - matchMatrix -= match; - } +template +class BFSStep : public Dumpable { +public: + using VertexDescriptor = typename Graph::VertexDescriptor; + using EdgeDescriptor = typename Graph::EdgeDescriptor; - const MCIM& getMatched() const - { - return matchMatrix; - } + using VertexProperty = typename Graph::VertexProperty; - MCIM getUnmatched() const - { - return incidenceMatrix - matchMatrix; - } + BFSStep(const Graph &graph, VertexDescriptor node, IndexSet candidates) + : graph(&graph), previous(nullptr), node(std::move(node)), + candidates(std::move(candidates)), edge(std::nullopt), + mappedFlow(std::nullopt) {} - bool isVisible() const - { - return visible; - } + BFSStep(const Graph &graph, BFSStep previous, EdgeDescriptor edge, + VertexDescriptor node, IndexSet candidates, MCIM mappedFlow) + : graph(&graph), previous(std::make_unique(std::move(previous))), + node(std::move(node)), candidates(std::move(candidates)), + edge(std::move(edge)), mappedFlow(std::move(mappedFlow)) {} - void setVisibility(bool visibility) - { - visible = visibility; - } + BFSStep(const BFSStep &other) + : graph(other.graph), + previous(other.hasPrevious() + ? 
std::make_unique(*other.previous) + : nullptr), + node(other.node), candidates(other.candidates), edge(other.edge), + mappedFlow(other.mappedFlow) {} - private: - // Equation's ID. Just for debugging purpose - typename Equation::Id equation; + ~BFSStep() = default; - // Variable's ID. Just for debugging purpose - typename Variable::Id variable; + BFSStep &operator=(const BFSStep &other); - std::unique_ptr accessFunction; - AccessProperty accessProperty; - MCIM incidenceMatrix; - MCIM matchMatrix; + template + friend void swap(BFSStep &first, BFSStep &second); - bool visible; - }; + using Dumpable::dump; - template - class BFSStep : public Dumpable - { - public: - using VertexDescriptor = typename Graph::VertexDescriptor; - using EdgeDescriptor = typename Graph::EdgeDescriptor; - - using VertexProperty = typename Graph::VertexProperty; - - BFSStep(const Graph& graph, - VertexDescriptor node, - IndexSet candidates) - : graph(&graph), - previous(nullptr), - node(std::move(node)), - candidates(std::move(candidates)), - edge(std::nullopt), - mappedFlow(std::nullopt) - { - } + void dump(llvm::raw_ostream &os) const override { + os << "BFS step\n"; - BFSStep(const Graph& graph, - BFSStep previous, - EdgeDescriptor edge, - VertexDescriptor node, - IndexSet candidates, - MCIM mappedFlow) - : graph(&graph), - previous(std::make_unique(std::move(previous))), - node(std::move(node)), - candidates(std::move(candidates)), - edge(std::move(edge)), - mappedFlow(std::move(mappedFlow)) - { - } + os << "Node: "; + dumpId(os, getNode()); + os << "\n"; - BFSStep(const BFSStep& other) - : graph(other.graph), - previous(other.hasPrevious() ? std::make_unique(*other.previous) : nullptr), - node(other.node), - candidates(other.candidates), - edge(other.edge), - mappedFlow(other.mappedFlow) - { - } + os << "Candidates:\n" << getCandidates(); - ~BFSStep() = default; + if (hasPrevious()) { + os << "\n"; + os << "Edge: "; + dumpId(os, getEdge().from); + os << " - "; + dumpId(os, getEdge().to); + os << "\n"; - BFSStep& operator=(const BFSStep& other); + os << "Mapped flow:\n" << getMappedFlow() << "\n"; + os << "Previous:\n"; + getPrevious()->dump(os); + } + } - template - friend void swap(BFSStep& first, BFSStep& second); + bool hasPrevious() const { return previous != nullptr; } - using Dumpable::dump; + const BFSStep *getPrevious() const { return previous.get(); } - void dump(llvm::raw_ostream& os) const override - { - os << "BFS step\n"; + const VertexDescriptor &getNode() const { return node; } - os << "Node: "; - dumpId(os, getNode()); - os << "\n"; + const IndexSet &getCandidates() const { return candidates; } - os << "Candidates:\n" << getCandidates(); + const EdgeDescriptor &getEdge() const { + assert(edge.has_value()); + return *edge; + } - if (hasPrevious()) { - os << "\n"; - os << "Edge: "; - dumpId(os, getEdge().from); - os << " - "; - dumpId(os, getEdge().to); - os << "\n"; + const MCIM &getMappedFlow() const { + assert(mappedFlow.has_value()); + return *mappedFlow; + } - os << "Mapped flow:\n" << getMappedFlow() << "\n"; - os << "Previous:\n"; - getPrevious()->dump(os); - } - } +private: + void dumpId(llvm::raw_ostream &os, VertexDescriptor descriptor) const { + const VertexProperty &nodeProperty = (*graph)[descriptor]; - bool hasPrevious() const - { - return previous != nullptr; - } + if (std::holds_alternative(nodeProperty)) { + os << std::get(nodeProperty).getId(); + } else { + os << std::get(nodeProperty).getId(); + } + } - const BFSStep *getPrevious() const - { - return previous.get(); - } +private: 
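  // Taken together, a chain of BFS steps encodes one path explored while
  // searching for augmenting paths: 'previous' links back toward the source
  // of the visit, 'candidates' restricts the indices that can still be
  // matched at this node, and 'edge' / 'mappedFlow' (unset on the initial
  // step) record which edge was traversed and the portion of its incidence
  // matrix the step used.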
+ // Stored for debugging purpose + const Graph *graph; + + std::unique_ptr previous; + VertexDescriptor node; + IndexSet candidates; + std::optional edge; + std::optional mappedFlow; +}; + +template +void swap(BFSStep &first, + BFSStep &second) { + using std::swap; + + swap(first.previous, second.previous); + swap(first.node, second.node); + swap(first.candidates, second.candidates); + swap(first.edge, second.edge); + swap(first.mappedFlow, second.mappedFlow); +} - const VertexDescriptor& getNode() const - { - return node; - } +template +BFSStep & +BFSStep::operator=( + const BFSStep &other) { + BFSStep result(other); + swap(*this, result); + return *this; +} - const IndexSet& getCandidates() const - { - return candidates; - } +template +class Frontier : public Dumpable { +private: + template + using Container = std::vector; - const EdgeDescriptor& getEdge() const - { - assert(edge.has_value()); - return *edge; - } +public: + using iterator = typename Container::iterator; + using const_iterator = typename Container::const_iterator; - const MCIM& getMappedFlow() const - { - assert(mappedFlow.has_value()); - return *mappedFlow; - } + using Dumpable::dump; - private: - void dumpId(llvm::raw_ostream& os, VertexDescriptor descriptor) const - { - const VertexProperty& nodeProperty = (*graph)[descriptor]; + void dump(llvm::raw_ostream &os) const override { + os << "Frontier\n"; - if (std::holds_alternative(nodeProperty)) { - os << std::get(nodeProperty).getId(); - } else { - os << std::get(nodeProperty).getId(); - } - } + for (const auto &step : steps) { + step.dump(os); + os << "\n"; + } + } - private: - // Stored for debugging purpose - const Graph* graph; + friend void swap(Frontier &first, Frontier &second) { + using std::swap; - std::unique_ptr previous; - VertexDescriptor node; - IndexSet candidates; - std::optional edge; - std::optional mappedFlow; - }; + swap(first.steps, second.steps); + } - template - void swap(BFSStep& first, - BFSStep& second) - { - using std::swap; - - swap(first.previous, second.previous); - swap(first.node, second.node); - swap(first.candidates, second.candidates); - swap(first.edge, second.edge); - swap(first.mappedFlow, second.mappedFlow); - } + BFSStep &operator[](size_t index) { + assert(index < steps.size()); + return steps[index]; + } - template - BFSStep& - BFSStep::operator=( - const BFSStep& other) - { - BFSStep result(other); - swap(*this, result); - return *this; - } + const BFSStep &operator[](size_t index) const { + assert(index < steps.size()); + return steps[index]; + } - template - class Frontier : public Dumpable - { - private: - template using Container = std::vector; + bool empty() const { return steps.empty(); } - public: - using iterator = typename Container::iterator; - using const_iterator = typename Container::const_iterator; + template + void emplace(Args &&...args) { + steps.emplace_back(args...); + } - using Dumpable::dump; + void clear() { steps.clear(); } - void dump(llvm::raw_ostream& os) const override - { - os << "Frontier\n"; + void swap(Frontier &other) { steps.swap(other.steps); } - for (const auto& step : steps) { - step.dump(os); - os << "\n"; - } - } + iterator begin() { return steps.begin(); } - friend void swap(Frontier& first, Frontier& second) - { - using std::swap; + const_iterator begin() const { return steps.begin(); } - swap(first.steps, second.steps); - } + iterator end() { return steps.end(); } - BFSStep& operator[](size_t index) - { - assert(index < steps.size()); - return steps[index]; - } + const_iterator end() 
const { return steps.end(); } - const BFSStep& operator[](size_t index) const - { - assert(index < steps.size()); - return steps[index]; - } +private: + Container steps; +}; - bool empty() const - { - return steps.empty(); - } +template +class Flow : public Dumpable { +private: + using VertexDescriptor = typename Graph::VertexDescriptor; + using EdgeDescriptor = typename Graph::EdgeDescriptor; - template - void emplace(Args&& ... args) - { - steps.emplace_back(args...); - } + using VertexProperty = typename Graph::VertexProperty; - void clear() - { - steps.clear(); - } +public: + Flow(const Graph &graph, VertexDescriptor source, EdgeDescriptor edge, + const MCIM &delta) + : graph(&graph), source(std::move(source)), edge(std::move(edge)), + delta(std::move(delta)) { + assert(this->source == this->edge.from || this->source == this->edge.to); + } - void swap(Frontier& other) - { - steps.swap(other.steps); - } + using Dumpable::dump; - iterator begin() - { - return steps.begin(); - } + void dump(llvm::raw_ostream &os) const override { + os << "Flow\n"; - const_iterator begin() const - { - return steps.begin(); - } + os << " - Source: "; + dumpId(os, source); + os << "\n"; - iterator end() - { - return steps.end(); - } + os << " - Edge: "; + dumpId(os, edge.from); + os << " - "; + dumpId(os, edge.to); + os << "\n"; - const_iterator end() const - { - return steps.end(); - } + os << " - Delta:\n" << delta; + } - private: - Container steps; - }; +private: + void dumpId(llvm::raw_ostream &os, VertexDescriptor descriptor) const { + const VertexProperty &nodeProperty = (*graph)[descriptor]; - template - class Flow : public Dumpable - { - private: - using VertexDescriptor = typename Graph::VertexDescriptor; - using EdgeDescriptor = typename Graph::EdgeDescriptor; - - using VertexProperty = typename Graph::VertexProperty; - - public: - Flow(const Graph& graph, VertexDescriptor source, EdgeDescriptor edge, const MCIM& delta) - : graph(&graph), - source(std::move(source)), - edge(std::move(edge)), - delta(std::move(delta)) - { - assert(this->source == this->edge.from || this->source == this->edge.to); - } + if (std::holds_alternative(nodeProperty)) { + os << std::get(nodeProperty).getId(); + } else { + os << std::get(nodeProperty).getId(); + } + } - using Dumpable::dump; +private: + // Stored for debugging purpose + const Graph *graph; - void dump(llvm::raw_ostream& os) const override - { - os << "Flow\n"; +public: + const VertexDescriptor source; + const EdgeDescriptor edge; + const MCIM delta; +}; - os << " - Source: "; - dumpId(os, source); - os << "\n"; +template +class AugmentingPath : public Dumpable { +private: + template + using Container = std::vector; - os << " - Edge: "; - dumpId(os, edge.from); - os << " - "; - dumpId(os, edge.to); - os << "\n"; +public: + using const_iterator = typename Container::const_iterator; - os << " - Delta:\n" << delta; - } + template + explicit AugmentingPath(const Flows &flows) + : flows(flows.begin(), flows.end()) {} - private: - void dumpId(llvm::raw_ostream& os, VertexDescriptor descriptor) const - { - const VertexProperty& nodeProperty = (*graph)[descriptor]; + using Dumpable::dump; - if (std::holds_alternative(nodeProperty)) { - os << std::get(nodeProperty).getId(); - } else { - os << std::get(nodeProperty).getId(); - } - } + void dump(llvm::raw_ostream &os) const override { + os << "Augmenting path\n"; - private: - // Stored for debugging purpose - const Graph* graph; + for (const auto &flow : flows) { + os << " "; + flow.dump(os); + os << "\n"; + } + } - 
public: - const VertexDescriptor source; - const EdgeDescriptor edge; - const MCIM delta; - }; + const Flow &operator[](size_t index) const { + assert(index < flows.size()); + return flows[index]; + } - template - class AugmentingPath : public Dumpable - { - private: - template using Container = std::vector; + const_iterator begin() const { return flows.begin(); } - public: - using const_iterator = typename Container::const_iterator; + const_iterator end() const { return flows.end(); } - template - explicit AugmentingPath(const Flows& flows) - : flows(flows.begin(), flows.end()) - { - } +private: + Container flows; +}; - using Dumpable::dump; +/// Represents how an equation has been matched (i.e. the selected indexes and +/// access). +template +class MatchingSolution { +public: + MatchingSolution(EquationProperty equation, VariableProperty variable, + IndexSet indexes, AccessProperty access) + : equation(std::move(equation)), variable(std::move(variable)), + indexes(std::move(indexes)), access(std::move(access)) {} - void dump(llvm::raw_ostream& os) const override - { - os << "Augmenting path\n"; + EquationProperty &getEquation() { return equation; } - for (const auto& flow : flows) { - os << " "; - flow.dump(os); - os << "\n"; - } - } + const EquationProperty &getEquation() const { return equation; } - const Flow& operator[](size_t index) const - { - assert(index < flows.size()); - return flows[index]; - } + const VariableProperty &getVariable() const { return variable; } - const_iterator begin() const - { - return flows.begin(); - } + const AccessProperty &getAccess() const { return access; } - const_iterator end() const - { - return flows.end(); - } + const IndexSet &getIndexes() const { return indexes; } - private: - Container flows; - }; +private: + EquationProperty equation; + VariableProperty variable; + IndexSet indexes; + AccessProperty access; +}; +} // namespace internal::matching - /// Represents how an equation has been matched (i.e. the selected indexes and access). 
- template - class MatchingSolution - { - public: - MatchingSolution( - EquationProperty equation, - VariableProperty variable, - IndexSet indexes, - AccessProperty access) - : equation(std::move(equation)), - variable(std::move(variable)), - indexes(std::move(indexes)), - access(std::move(access)) - { - } +template +class MatchingGraph : public internal::Dumpable { +public: + using Variable = internal::matching::VariableVertex; + using Equation = internal::matching::EquationVertex; + using Vertex = std::variant; + using Edge = internal::matching::Edge; - EquationProperty& getEquation() - { - return equation; - } +private: + using Graph = internal::UndirectedGraph; - const EquationProperty& getEquation() const - { - return equation; - } + using VertexDescriptor = typename Graph::VertexDescriptor; + using EdgeDescriptor = typename Graph::EdgeDescriptor; - const VariableProperty& getVariable() const - { - return variable; - } + using VertexIterator = typename Graph::VertexIterator; + using EdgeIterator = typename Graph::EdgeIterator; + using VisibleIncidentEdgeIterator = + typename Graph::FilteredIncidentEdgeIterator; - const AccessProperty& getAccess() const - { - return access; - } + using MCIM = internal::MCIM; + using BFSStep = internal::matching::BFSStep; + using Frontier = internal::matching::Frontier; + using Flow = internal::matching::Flow; + using AugmentingPath = internal::matching::AugmentingPath; - const IndexSet& getIndexes() const - { - return indexes; - } +public: + using VariableIterator = typename Graph::FilteredVertexIterator; + using EquationIterator = typename Graph::FilteredVertexIterator; - private: - EquationProperty equation; - VariableProperty variable; - IndexSet indexes; - AccessProperty access; - }; - } + using AccessProperty = typename Equation::Access::Property; + using Access = matching::Access; + using MatchingSolution = + internal::matching::MatchingSolution; - template - class MatchingGraph : public internal::Dumpable - { - public: - using Variable = internal::matching::VariableVertex; - using Equation = internal::matching::EquationVertex; - using Vertex = std::variant; - using Edge = internal::matching::Edge; + explicit MatchingGraph(mlir::MLIRContext *context) : context(context) {} - private: - using Graph = internal::UndirectedGraph; + MatchingGraph(const MatchingGraph &other) = delete; - using VertexDescriptor = typename Graph::VertexDescriptor; - using EdgeDescriptor = typename Graph::EdgeDescriptor; + MatchingGraph(MatchingGraph &&other) { + std::lock_guard lockGuard(other.mutex); - using VertexIterator = typename Graph::VertexIterator; - using EdgeIterator = typename Graph::EdgeIterator; - using VisibleIncidentEdgeIterator = typename Graph::FilteredIncidentEdgeIterator; + context = std::move(other.context); + graph = std::move(other.graph); + variablesMap = std::move(other.variablesMap); + equationsMap = std::move(other.equationsMap); + } - using MCIM = internal::MCIM; - using BFSStep = internal::matching::BFSStep; - using Frontier = internal::matching::Frontier; - using Flow = internal::matching::Flow; - using AugmentingPath = internal::matching::AugmentingPath; + ~MatchingGraph() = default; - public: - using VariableIterator = typename Graph::FilteredVertexIterator; - using EquationIterator = typename Graph::FilteredVertexIterator; + MatchingGraph &operator=(const MatchingGraph &other) = delete; - using AccessProperty = typename Equation::Access::Property; - using Access = matching::Access; - using MatchingSolution = 
internal::matching::MatchingSolution; + MatchingGraph &operator=(MatchingGraph &&other) = default; - explicit MatchingGraph(mlir::MLIRContext* context) - : context(context) - { - } + using Dumpable::dump; - MatchingGraph(const MatchingGraph& other) = delete; + void dump(llvm::raw_ostream &os) const override { + os << "--------------------------------------------------\n"; + os << "Matching graph\n"; + os << "- Vertices:\n"; - MatchingGraph(MatchingGraph&& other) - { - std::lock_guard lockGuard(other.mutex); + for (auto vertexDescriptor : + llvm::make_range(graph.verticesBegin(), graph.verticesEnd())) { - context = std::move(other.context); - graph = std::move(other.graph); - variablesMap = std::move(other.variablesMap); - equationsMap = std::move(other.equationsMap); - } + std::visit([&](const auto &vertex) { vertex.dump(os); }, + graph[vertexDescriptor]); - ~MatchingGraph() = default; + os << "\n"; + } - MatchingGraph& operator=(const MatchingGraph& other) = delete; + os << "- Equations:\n"; - MatchingGraph& operator=(MatchingGraph&& other) = default; + for (auto equationDescriptor : + llvm::make_range(graph.edgesBegin(), graph.edgesEnd())) { + graph[equationDescriptor].dump(os); + os << "\n"; + } - using Dumpable::dump; + os << "--------------------------------------------------\n"; + } - void dump(llvm::raw_ostream& os) const override - { - os << "--------------------------------------------------\n"; - os << "Matching graph\n"; - os << "- Vertices:\n"; + mlir::MLIRContext *getContext() const { + assert(context != nullptr); + return context; + } - for (auto vertexDescriptor : llvm::make_range( - graph.verticesBegin(), graph.verticesEnd())) { + bool hasVariable(typename Variable::Id id) const { + std::lock_guard lockGuard(mutex); + return hasVariableWithId(id); + } - std::visit( - [&](const auto& vertex) { - vertex.dump(os); - }, graph[vertexDescriptor]); + VariableProperty &getVariable(typename Variable::Id id) { + std::lock_guard lockGuard(mutex); + return getVariablePropertyFromId(id); + } - os << "\n"; - } + const VariableProperty &getVariable(typename Variable::Id id) const { + std::lock_guard lockGuard(mutex); + return getVariablePropertyFromId(id); + } - os << "- Equations:\n"; + Variable &getVariable(VertexDescriptor descriptor) { + std::lock_guard lockGuard(mutex); + return getVariableFromDescriptor(descriptor); + } - for (auto equationDescriptor : llvm::make_range( - graph.edgesBegin(), graph.edgesEnd())) { - graph[equationDescriptor].dump(os); - os << "\n"; - } + const Variable &getVariable(VertexDescriptor descriptor) const { + std::lock_guard lockGuard(mutex); + return getVariableFromDescriptor(descriptor); + } - os << "--------------------------------------------------\n"; - } + VariableIterator variablesBegin() const { + std::lock_guard lockGuard(mutex); + return getVariablesBeginIt(); + } - mlir::MLIRContext* getContext() const - { - assert(context != nullptr); - return context; - } + VariableIterator variablesEnd() const { + std::lock_guard lockGuard(mutex); + return getVariablesEndIt(); + } - bool hasVariable(typename Variable::Id id) const - { - std::lock_guard lockGuard(mutex); - return hasVariableWithId(id); - } + void addVariable(VariableProperty property) { + std::lock_guard lockGuard(mutex); - VariableProperty& getVariable(typename Variable::Id id) - { - std::lock_guard lockGuard(mutex); - return getVariablePropertyFromId(id); - } + Variable variable(std::move(property)); + auto id = variable.getId(); + assert(!hasVariableWithId(id) && "Already existing 
variable"); + VertexDescriptor variableDescriptor = graph.addVertex(std::move(variable)); + variablesMap[id] = variableDescriptor; + } - const VariableProperty& getVariable(typename Variable::Id id) const - { - std::lock_guard lockGuard(mutex); - return getVariablePropertyFromId(id); - } + bool hasEquation(typename Equation::Id id) const { + std::lock_guard lockGuard(mutex); + return hasEquationWithId(id); + } - Variable& getVariable(VertexDescriptor descriptor) - { - std::lock_guard lockGuard(mutex); - return getVariableFromDescriptor(descriptor); - } + EquationProperty &getEquation(typename Equation::Id id) { + std::lock_guard lockGuard(mutex); + return getEquationPropertyFromId(id); + } - const Variable& getVariable(VertexDescriptor descriptor) const - { - std::lock_guard lockGuard(mutex); - return getVariableFromDescriptor(descriptor); - } + const EquationProperty &getEquation(typename Equation::Id id) const { + std::lock_guard lockGuard(mutex); + return getEquationPropertyFromId(id); + } - VariableIterator variablesBegin() const - { - std::lock_guard lockGuard(mutex); - return getVariablesBeginIt(); - } + Equation &getEquation(VertexDescriptor descriptor) { + std::lock_guard lockGuard(mutex); + return getEquationFromDescriptor(descriptor); + } - VariableIterator variablesEnd() const - { - std::lock_guard lockGuard(mutex); - return getVariablesEndIt(); - } + const Equation &getEquation(VertexDescriptor descriptor) const { + std::lock_guard lockGuard(mutex); + return getEquationFromDescriptor(descriptor); + } - void addVariable(VariableProperty property) - { - std::lock_guard lockGuard(mutex); + EquationIterator equationsBegin() const { + std::lock_guard lockGuard(mutex); + return getEquationsBeginIt(); + } - Variable variable(std::move(property)); - auto id = variable.getId(); - assert(!hasVariableWithId(id) && "Already existing variable"); - VertexDescriptor variableDescriptor = graph.addVertex(std::move(variable)); - variablesMap[id] = variableDescriptor; - } + EquationIterator equationsEnd() const { + std::lock_guard lockGuard(mutex); + return getEquationsEndIt(); + } - bool hasEquation(typename Equation::Id id) const - { - std::lock_guard lockGuard(mutex); - return hasEquationWithId(id); - } + void addEquation(EquationProperty property) { + Equation eq(std::move(property)); + [[maybe_unused]] auto id = eq.getId(); - EquationProperty& getEquation(typename Equation::Id id) - { - std::lock_guard lockGuard(mutex); - return getEquationPropertyFromId(id); - } + std::unique_lock lockGuard(mutex); + assert(!hasEquationWithId(id) && "Already existing equation"); - const EquationProperty& getEquation(typename Equation::Id id) const - { - std::lock_guard lockGuard(mutex); - return getEquationPropertyFromId(id); - } + // Insert the equation into the graph and get a reference to the new vertex + VertexDescriptor equationDescriptor = graph.addVertex(std::move(eq)); + equationsMap[id] = equationDescriptor; + Equation &equation = getEquationFromDescriptor(equationDescriptor); + lockGuard.unlock(); - Equation& getEquation(VertexDescriptor descriptor) - { - std::lock_guard lockGuard(mutex); - return getEquationFromDescriptor(descriptor); - } + // The equation may access multiple variables or even multiple indexes + // of the same variable. Add an edge to the graph for each of those + // accesses. 
- const Equation& getEquation(VertexDescriptor descriptor) const - { - std::lock_guard lockGuard(mutex); - return getEquationFromDescriptor(descriptor); - } + IndexSet equationRanges = equation.getIterationRanges(); - EquationIterator equationsBegin() const - { - std::lock_guard lockGuard(mutex); - return getEquationsBeginIt(); - } + for (const auto &access : equation.getVariableAccesses()) { + lockGuard.lock(); + VertexDescriptor variableDescriptor = + getVariableDescriptorFromId(access.getVariable()); + Variable &variable = getVariableFromDescriptor(variableDescriptor); + lockGuard.unlock(); - EquationIterator equationsEnd() const - { - std::lock_guard lockGuard(mutex); - return getEquationsEndIt(); + IndexSet indices = variable.getIndices().getCanonicalRepresentation(); + + for (const MultidimensionalRange &range : + llvm::make_range(indices.rangesBegin(), indices.rangesEnd())) { + Edge edge(equation.getId(), variable.getId(), equationRanges, + IndexSet(range), access); + + lockGuard.lock(); + graph.addEdge(equationDescriptor, variableDescriptor, std::move(edge)); + lockGuard.unlock(); } + } + } - void addEquation(EquationProperty property) - { - Equation eq(std::move(property)); - [[maybe_unused]] auto id = eq.getId(); + /// Get the total amount of scalar variables inside the graph. + size_t getNumberOfScalarVariables() const { + std::lock_guard lockGuard(mutex); + size_t result = 0; - std::unique_lock lockGuard(mutex); - assert(!hasEquationWithId(id) && "Already existing equation"); + auto variables = + llvm::make_range(getVariablesBeginIt(), getVariablesEndIt()); - // Insert the equation into the graph and get a reference to the new vertex - VertexDescriptor equationDescriptor = graph.addVertex(std::move(eq)); - equationsMap[id] = equationDescriptor; - Equation& equation = getEquationFromDescriptor(equationDescriptor); - lockGuard.unlock(); + for (VertexDescriptor variableDescriptor : variables) { + result += getVariableFromDescriptor(variableDescriptor).flatSize(); + } - // The equation may access multiple variables or even multiple indexes - // of the same variable. Add an edge to the graph for each of those - // accesses. + return result; + } - IndexSet equationRanges = equation.getIterationRanges(); + /// Get the total amount of scalar equations inside the graph. + /// With "scalar equations" we mean the ones generated by unrolling + /// the loops defining them. + size_t getNumberOfScalarEquations() const { + std::lock_guard lockGuard(mutex); + size_t result = 0; - for (const auto& access : equation.getVariableAccesses()) { - lockGuard.lock(); - VertexDescriptor variableDescriptor = getVariableDescriptorFromId(access.getVariable()); - Variable& variable = getVariableFromDescriptor(variableDescriptor); - lockGuard.unlock(); + auto equations = + llvm::make_range(getEquationsBeginIt(), getEquationsEndIt()); - IndexSet indices = variable.getIndices().getCanonicalRepresentation(); + for (VertexDescriptor equationDescriptor : equations) { + result += getEquationFromDescriptor(equationDescriptor).flatSize(); + } - for (const MultidimensionalRange& range : - llvm::make_range(indices.rangesBegin(), indices.rangesEnd())) { - Edge edge(equation.getId(), variable.getId(), equationRanges, IndexSet(range), access); + return result; + } - lockGuard.lock(); - graph.addEdge(equationDescriptor, variableDescriptor, std::move(edge)); - lockGuard.unlock(); - } - } - } + // Warning: highly inefficient, use for testing purposes only. 
+  bool hasEdge(typename Equation::Id equationId,
+              typename Variable::Id variableId) const {
+    std::lock_guard lockGuard(mutex);
 
-    /// Get the total amount of scalar variables inside the graph.
-    size_t getNumberOfScalarVariables() const
-    {
-      std::lock_guard lockGuard(mutex);
-      size_t result = 0;
+    if (findEdge(equationId, variableId).first) {
+      return true;
+    }
 
-      auto variables = llvm::make_range(getVariablesBeginIt(), getVariablesEndIt());
+    return findEdge(variableId, equationId).first;
+  }
 
-      for (VertexDescriptor variableDescriptor : variables) {
-        result += getVariableFromDescriptor(variableDescriptor).flatSize();
-      }
+  /// Apply the simplification algorithm in order to perform all
+  /// the obligatory matches, that is the variables and equations
+  /// having only one incident edge.
+  ///
+  /// @return true if the simplification algorithm didn't find any inconsistency
+  bool simplify() {
+    std::lock_guard lockGuard(mutex);
 
-      return result;
-    }
+    // Vertices that are candidate for the first simplification phase.
+    // They are the ones having only one incident edge.
+    std::list candidates;
 
-    /// Get the total amount of scalar equations inside the graph.
-    /// With "scalar equations" we mean the ones generated by unrolling
-    /// the loops defining them.
-    size_t getNumberOfScalarEquations() const
-    {
-      std::lock_guard lockGuard(mutex);
-      size_t result = 0;
+    if (!collectSimplifiableNodes(candidates)) {
+      return false;
+    }
 
-      auto equations = llvm::make_range(getEquationsBeginIt(), getEquationsEndIt());
+    // Check that the list of simplifiable nodes does not contain
+    // duplicates.
+    assert(llvm::all_of(candidates,
+                        [&](const VertexDescriptor &vertex) {
+                          return llvm::count(candidates, vertex) == 1;
+                        }) &&
+           "Duplicates found in the list of simplifiable nodes");
 
-      for (VertexDescriptor equationDescriptor : equations) {
-        result += getEquationFromDescriptor(equationDescriptor).flatSize();
-      }
+    // Iterate on the candidate vertices and apply the simplification algorithm
+    auto isVisibleFn = [](const auto &obj) -> bool { return obj.isVisible(); };
 
-      return result;
-    }
+    auto allComponentsMatchedFn = [](const auto &vertex) -> bool {
+      return vertex.allComponentsMatched();
+    };
 
-    // Warning: highly inefficient, use for testing purposes only.
-    bool hasEdge(typename Equation::Id equationId, typename Variable::Id variableId) const
-    {
-      std::lock_guard lockGuard(mutex);
+    while (!candidates.empty()) {
+      VertexDescriptor v1 = candidates.front();
+      candidates.pop_front();
 
-      if (findEdge(equationId, variableId).first) {
-        return true;
-      }
+      if (const Vertex &v = graph[v1]; !std::visit(isVisibleFn, v)) {
+        // The current node, which initially had one and only one incident
+        // edge, has already been removed by simplifications performed in the
+        // previous iterations. We could have removed the vertex when the
+        // edge was removed, but that would have required iterating over
+        // the whole candidates list, thus worsening the overall complexity
+        // of the algorithm.
 
-      return findEdge(variableId, equationId).first;
+        assert(std::visit(allComponentsMatchedFn, v));
+        continue;
       }
 
-      /// Apply the simplification algorithm in order to perform all
-      /// the obligatory matches, that is the variables and equations
-      /// having only one incident edge. 
-      ///
-      /// @return true if the simplification algorithm didn't find any inconsistency
-      bool simplify()
-      {
-        std::lock_guard lockGuard(mutex);
+      EdgeDescriptor edgeDescriptor = getFirstOutVisibleEdge(v1);
+      Edge &edge = graph[edgeDescriptor];
 
-        // Vertices that are candidate for the first simplification phase.
-        // They are the ones having only one incident edge.
-        std::list candidates;
+      VertexDescriptor v2 =
+          edgeDescriptor.from == v1 ? edgeDescriptor.to : edgeDescriptor.from;
 
-        if (!collectSimplifiableNodes(candidates)) {
-          return false;
-        }
+      const auto &u = edge.getIncidenceMatrix();
 
-        // Check that the list of simplifiable nodes does not contain
-        // duplicates.
-        assert(llvm::all_of(candidates, [&](const VertexDescriptor& vertex) {
-          return llvm::count(candidates, vertex) == 1;
-        }) && "Duplicates found in the list of simplifiable nodes");
-
-        // Iterate on the candidate vertices and apply the simplification algorithm
-        auto isVisibleFn = [](const auto& obj) -> bool {
-          return obj.isVisible();
-        };
-
-        auto allComponentsMatchedFn = [](const auto& vertex) -> bool {
-          return vertex.allComponentsMatched();
-        };
-
-        while (!candidates.empty()) {
-          VertexDescriptor v1 = candidates.front();
-          candidates.pop_front();
-
-          if (const Vertex& v = graph[v1]; !std::visit(isVisibleFn, v)) {
-            // The current node, which initially had one and only one incident
-            // edge, has been removed more by simplifications performed in the
-            // previous iterations. We could just remove the vertex while the
-            // edge was removed, but that would have required iterating over
-            // the whole candidates list, thus worsening the overall complexity
-            // of the algorithm.
-
-            assert(std::visit(allComponentsMatchedFn, v));
-            continue;
-          }
+      auto matchOptions = internal::solveLocalMatchingProblem(
+          u.getEquationRanges(), u.getVariableRanges(),
+          edge.getAccessFunction().clone());
 
-          EdgeDescriptor edgeDescriptor = getFirstOutVisibleEdge(v1);
-          Edge& edge = graph[edgeDescriptor];
+      // The simplification step is executed only in case of a single
+      // matching option. In case of multiple ones, in fact, the choice
+      // would be arbitrary and may affect the feasibility of the
+      // array-aware matching problem.
 
-          VertexDescriptor v2 = edgeDescriptor.from == v1 ? edgeDescriptor.to : edgeDescriptor.from;
+      if (matchOptions.size() == 1) {
+        const MCIM &match = matchOptions[0];
 
-          const auto& u = edge.getIncidenceMatrix();
+        Variable &variable = isVariable(v1) ? getVariableFromDescriptor(v1)
+                                            : getVariableFromDescriptor(v2);
+        Equation &equation = isEquation(v1) ? getEquationFromDescriptor(v1)
+                                            : getEquationFromDescriptor(v2);
 
-          auto matchOptions = internal::solveLocalMatchingProblem(
-              u.getEquationRanges(),
-              u.getVariableRanges(),
-              edge.getAccessFunction().clone());
+        IndexSet proposedVariableMatch = match.flattenRows();
+        IndexSet proposedEquationMatch = match.flattenColumns();
 
-          // The simplification steps is executed only in case of a single
-          // matching option. In case of multiple ones, in fact, the choice
-          // would be arbitrary and may affect the feasibility of the
-          // array-aware matching problem.
+        MCIM reducedMatch =
+            match.filterColumns(proposedVariableMatch - variable.getMatched());
+        reducedMatch = reducedMatch.filterRows(proposedEquationMatch -
+                                               equation.getMatched());
 
-          if (matchOptions.size() == 1) {
-            const MCIM& match = matchOptions[0];
+        edge.addMatch(reducedMatch);
 
-            Variable& variable = isVariable(v1) ? 
getVariableFromDescriptor(v1) : getVariableFromDescriptor(v2); - Equation& equation = isEquation(v1) ? getEquationFromDescriptor(v1) : getEquationFromDescriptor(v2); + IndexSet reducedVariableMatch = reducedMatch.flattenRows(); + IndexSet reducedEquationMatch = reducedMatch.flattenColumns(); - IndexSet proposedVariableMatch = match.flattenRows(); - IndexSet proposedEquationMatch = match.flattenColumns(); + variable.addMatch(reducedVariableMatch); + equation.addMatch(reducedEquationMatch); - MCIM reducedMatch = match.filterColumns(proposedVariableMatch - variable.getMatched()); - reducedMatch = reducedMatch.filterRows(proposedEquationMatch - equation.getMatched()); + if (!std::visit(allComponentsMatchedFn, graph[v1])) { + return false; + } - edge.addMatch(reducedMatch); + bool shouldRemoveOppositeNode = + std::visit(allComponentsMatchedFn, graph[v2]); - IndexSet reducedVariableMatch = reducedMatch.flattenRows(); - IndexSet reducedEquationMatch = reducedMatch.flattenColumns(); + // Remove the edge and the current candidate vertex. + remove(edgeDescriptor); + remove(v1); - variable.addMatch(reducedVariableMatch); - equation.addMatch(reducedEquationMatch); + if (shouldRemoveOppositeNode) { + // When a node is removed, then also its incident edges are + // removed. This can lead to new obliged matches, like in the + // following example: + // |-- v3 ---- + // v1 -- v2 | + // |-- v4 -- v5 + // v1 is the current candidate and thus is removed. + // v2 is removed because fully matched. + // v3 and v4 become new candidates for the simplification pass. - if (!std::visit(allComponentsMatchedFn, graph[v1])) { - return false; - } + for (EdgeDescriptor e : llvm::make_range( + graph.outgoingEdgesBegin(v2), graph.outgoingEdgesEnd(v2))) { + remove(e); - bool shouldRemoveOppositeNode = std::visit(allComponentsMatchedFn, graph[v2]); - - // Remove the edge and the current candidate vertex. - remove(edgeDescriptor); - remove(v1); - - if (shouldRemoveOppositeNode) { - // When a node is removed, then also its incident edges are - // removed. This can lead to new obliged matches, like in the - // following example: - // |-- v3 ---- - // v1 -- v2 | - // |-- v4 -- v5 - // v1 is the current candidate and thus is removed. - // v2 is removed because fully matched. - // v3 and v4 become new candidates for the simplification pass. - - for (EdgeDescriptor e : llvm::make_range(graph.outgoingEdgesBegin(v2), graph.outgoingEdgesEnd(v2))) { - remove(e); - - VertexDescriptor v = e.from == v2 ? e.to : e.from; - - if (!std::visit(isVisibleFn, graph[v])) { - continue; - } - - size_t visibilityDegree = getVertexVisibilityDegree(v); - - if (visibilityDegree == 0) { - // Chained simplifications may have led the 'v' vertex - // without any edge. In that case, it must have been fully - // matched during the process. - - if (!std::visit(allComponentsMatchedFn, graph[v])) { - return false; - } - - // 'v' will also be present for sure in the candidates list. - // However, being fully matched and having no outgoing edge, - // we now must remove it. - remove(v); - } else if (visibilityDegree == 1) { - candidates.push_back(v); - } - } + VertexDescriptor v = e.from == v2 ? e.to : e.from; - // Remove the v2 vertex. - remove(v2); - } else { - // When an edge is removed but one of its vertices survives, we must - // check if the remaining vertex has an obliged match. 
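+            // Vertices already hidden by previous removals can be skipped
+            // right away.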
+            if (!std::visit(isVisibleFn, graph[v])) {
+              continue;
+            }
+
+            size_t visibilityDegree = getVertexVisibilityDegree(v);
 
-              size_t visibilityDegree = getVertexVisibilityDegree(v2);
+            if (visibilityDegree == 0) {
+              // Chained simplifications may have left the 'v' vertex
+              // without any edge. In that case, it must have been fully
+              // matched during the process.
 
-              if (visibilityDegree == 1) {
-                candidates.push_back(v2);
+              if (!std::visit(allComponentsMatchedFn, graph[v])) {
+                return false;
               }
+
+              // 'v' will also be present for sure in the candidates list.
+              // However, being fully matched and having no outgoing edge,
+              // we must now remove it.
+              remove(v);
+            } else if (visibilityDegree == 1) {
+              candidates.push_back(v);
             }
           }
-      }
 
-      return true;
-    }
+          // Remove the v2 vertex.
+          remove(v2);
+        } else {
+          // When an edge is removed but one of its vertices survives, we must
+          // check if the remaining vertex has an obliged match.
 
-    /// Apply the matching algorithm.
-    ///
-    /// @return true if the matching algorithm managed to fully match all the nodes
-    bool match()
-    {
-      std::lock_guard lockGuard(mutex);
+          size_t visibilityDegree = getVertexVisibilityDegree(v2);
 
-      if (allNodesMatched()) {
-        return true;
+          if (visibilityDegree == 1) {
+            candidates.push_back(v2);
+          }
         }
+      }
+    }
 
-      bool success;
-      bool complete;
-
-      do {
-        success = matchIteration();
-        complete = allNodesMatched();
-
-        LLVM_DEBUG({
-          llvm::dbgs() << "Match iteration completed\n";
-          dump(llvm::dbgs());
-        });
-      } while (success && !complete);
-
-      LLVM_DEBUG({
-        if (success) {
-          llvm::dbgs() << "Matching completed successfully\n";
-        } else {
-          llvm::dbgs() << "Matching failed\n";
-        }
-      });
+    return true;
+  }
+
+  /// Apply the matching algorithm.
+  ///
+  /// @return true if the matching algorithm managed to fully match all the
+  /// nodes
+  bool match() {
+    std::lock_guard lockGuard(mutex);
+
+    if (allNodesMatched()) {
+      return true;
+    }
+
+    bool success;
+    bool complete;
+
+    do {
+      success = matchIteration();
+      complete = allNodesMatched();
 
-      return success;
+      LLVM_DEBUG({
+        llvm::dbgs() << "Match iteration completed\n";
+        dump(llvm::dbgs());
+      });
+    } while (success && !complete);
+
+    LLVM_DEBUG({
+      if (success) {
+        llvm::dbgs() << "Matching completed successfully\n";
+      } else {
+        llvm::dbgs() << "Matching failed\n";
       }
+    });
 
-    /// Get the solution of the matching problem.
-    bool getMatch(llvm::SmallVectorImpl& result) const
-    {
-      std::lock_guard lockGuard(mutex);
+    return success;
+  }
 
-      if (!allNodesMatched()) {
-        return false;
-      }
+  /// Get the solution of the matching problem.
+  bool getMatch(llvm::SmallVectorImpl &result) const {
+    std::lock_guard lockGuard(mutex);
 
-      auto equations = llvm::make_range(
-          getEquationsBeginIt(), getEquationsEndIt());
+    if (!allNodesMatched()) {
+      return false;
+    }
 
-      for (VertexDescriptor equationDescriptor : equations) {
-        auto edges = llvm::make_range(
-            edgesBegin(equationDescriptor), edgesEnd(equationDescriptor));
+    auto equations =
+        llvm::make_range(getEquationsBeginIt(), getEquationsEndIt());
 
-        for (EdgeDescriptor edgeDescriptor : edges) {
-          const Edge& edge = graph[edgeDescriptor];
+    for (VertexDescriptor equationDescriptor : equations) {
+      auto edges = llvm::make_range(edgesBegin(equationDescriptor),
+                                    edgesEnd(equationDescriptor));
 
-          if (const auto& matched = edge.getMatched(); !matched.empty()) {
-            auto variableDescriptor =
-                edgeDescriptor.from == equationDescriptor
-                ? 
edgeDescriptor.to : edgeDescriptor.from; + for (EdgeDescriptor edgeDescriptor : edges) { + const Edge &edge = graph[edgeDescriptor]; - IndexSet matchedEquationIndices = matched.flattenColumns(); - assert(!matchedEquationIndices.empty()); + if (const auto &matched = edge.getMatched(); !matched.empty()) { + auto variableDescriptor = edgeDescriptor.from == equationDescriptor + ? edgeDescriptor.to + : edgeDescriptor.from; - result.emplace_back( - getEquationFromDescriptor(equationDescriptor).getProperty(), - getVariableFromDescriptor(variableDescriptor).getProperty(), - std::move(matchedEquationIndices), - edge.getAccessProperty()); - } - } - } + IndexSet matchedEquationIndices = matched.flattenColumns(); + assert(!matchedEquationIndices.empty()); - return true; + result.emplace_back( + getEquationFromDescriptor(equationDescriptor).getProperty(), + getVariableFromDescriptor(variableDescriptor).getProperty(), + std::move(matchedEquationIndices), edge.getAccessProperty()); + } } + } - private: - /// Check if a variable with a given ID exists. - bool hasVariableWithId(typename Variable::Id id) const - { - return variablesMap.find(id) != variablesMap.end(); - } + return true; + } - bool isVariable(VertexDescriptor vertex) const - { - return std::holds_alternative(graph[vertex]); - } +private: + /// Check if a variable with a given ID exists. + bool hasVariableWithId(typename Variable::Id id) const { + return variablesMap.find(id) != variablesMap.end(); + } - VertexDescriptor getVariableDescriptorFromId(typename Variable::Id id) const - { - auto it = variablesMap.find(id); - assert(it != variablesMap.end() && "Variable not found"); - return it->second; - } + bool isVariable(VertexDescriptor vertex) const { + return std::holds_alternative(graph[vertex]); + } - VariableProperty& getVariablePropertyFromId(typename Variable::Id id) - { - VertexDescriptor vertex = getVariableDescriptorFromId(id); - return std::get(graph[vertex]).getProperty(); - } + VertexDescriptor getVariableDescriptorFromId(typename Variable::Id id) const { + auto it = variablesMap.find(id); + assert(it != variablesMap.end() && "Variable not found"); + return it->second; + } - const VariableProperty& getVariablePropertyFromId(typename Variable::Id id) const - { - VertexDescriptor vertex = getVariableDescriptorFromId(id); - return std::get(graph[vertex]).getProperty(); - } + VariableProperty &getVariablePropertyFromId(typename Variable::Id id) { + VertexDescriptor vertex = getVariableDescriptorFromId(id); + return std::get(graph[vertex]).getProperty(); + } - Variable& getVariableFromDescriptor(VertexDescriptor descriptor) - { - Vertex& vertex = graph[descriptor]; - assert(std::holds_alternative(vertex)); - return std::get(vertex); - } + const VariableProperty & + getVariablePropertyFromId(typename Variable::Id id) const { + VertexDescriptor vertex = getVariableDescriptorFromId(id); + return std::get(graph[vertex]).getProperty(); + } - const Variable& getVariableFromDescriptor(VertexDescriptor descriptor) const - { - const Vertex& vertex = graph[descriptor]; - assert(std::holds_alternative(vertex)); - return std::get(vertex); - } + Variable &getVariableFromDescriptor(VertexDescriptor descriptor) { + Vertex &vertex = graph[descriptor]; + assert(std::holds_alternative(vertex)); + return std::get(vertex); + } - /// Check if an equation with a given ID exists. 
- bool hasEquationWithId(typename Equation::Id id) const - { - return equationsMap.find(id) != equationsMap.end(); - } + const Variable &getVariableFromDescriptor(VertexDescriptor descriptor) const { + const Vertex &vertex = graph[descriptor]; + assert(std::holds_alternative(vertex)); + return std::get(vertex); + } - bool isEquation(VertexDescriptor vertex) const - { - return std::holds_alternative(graph[vertex]); - } + /// Check if an equation with a given ID exists. + bool hasEquationWithId(typename Equation::Id id) const { + return equationsMap.find(id) != equationsMap.end(); + } - VertexDescriptor getEquationDescriptorFromId(typename Equation::Id id) const - { - auto it = equationsMap.find(id); - assert(it != equationsMap.end() && "Equation not found"); - return it->second; - } + bool isEquation(VertexDescriptor vertex) const { + return std::holds_alternative(graph[vertex]); + } - EquationProperty& getEquationPropertyFromId(typename Equation::Id id) - { - VertexDescriptor vertex = getEquationDescriptorFromId(id); - return std::get(graph[vertex]).getProperty(); - } + VertexDescriptor getEquationDescriptorFromId(typename Equation::Id id) const { + auto it = equationsMap.find(id); + assert(it != equationsMap.end() && "Equation not found"); + return it->second; + } - const EquationProperty& getEquationPropertyFromId(typename Equation::Id id) const - { - VertexDescriptor vertex = getEquationDescriptorFromId(id); - return std::get(graph[vertex]).getProperty(); - } + EquationProperty &getEquationPropertyFromId(typename Equation::Id id) { + VertexDescriptor vertex = getEquationDescriptorFromId(id); + return std::get(graph[vertex]).getProperty(); + } - Equation& getEquationFromDescriptor(VertexDescriptor descriptor) - { - Vertex& vertex = graph[descriptor]; - assert(std::holds_alternative(vertex)); - return std::get(vertex); - } + const EquationProperty & + getEquationPropertyFromId(typename Equation::Id id) const { + VertexDescriptor vertex = getEquationDescriptorFromId(id); + return std::get(graph[vertex]).getProperty(); + } - const Equation& getEquationFromDescriptor(VertexDescriptor descriptor) const - { - const Vertex& vertex = graph[descriptor]; - assert(std::holds_alternative(vertex)); - return std::get(vertex); - } + Equation &getEquationFromDescriptor(VertexDescriptor descriptor) { + Vertex &vertex = graph[descriptor]; + assert(std::holds_alternative(vertex)); + return std::get(vertex); + } - /// Get the begin iterator for the variables of the graph. - VariableIterator getVariablesBeginIt() const - { - auto filter = [](const Vertex& vertex) -> bool { - return std::holds_alternative(vertex); - }; + const Equation &getEquationFromDescriptor(VertexDescriptor descriptor) const { + const Vertex &vertex = graph[descriptor]; + assert(std::holds_alternative(vertex)); + return std::get(vertex); + } - return graph.verticesBegin(filter); - } + /// Get the begin iterator for the variables of the graph. + VariableIterator getVariablesBeginIt() const { + auto filter = [](const Vertex &vertex) -> bool { + return std::holds_alternative(vertex); + }; - /// Get the end iterator for the variables of the graph. - VariableIterator getVariablesEndIt() const - { - auto filter = [](const Vertex& vertex) -> bool { - return std::holds_alternative(vertex); - }; + return graph.verticesBegin(filter); + } - return graph.verticesEnd(filter); - } + /// Get the end iterator for the variables of the graph. 
+ VariableIterator getVariablesEndIt() const { + auto filter = [](const Vertex &vertex) -> bool { + return std::holds_alternative(vertex); + }; - /// Get the begin iterator for the equations of the graph. - EquationIterator getEquationsBeginIt() const - { - auto filter = [](const Vertex& vertex) -> bool { - return std::holds_alternative(vertex); - }; + return graph.verticesEnd(filter); + } - return graph.verticesBegin(filter); - } + /// Get the begin iterator for the equations of the graph. + EquationIterator getEquationsBeginIt() const { + auto filter = [](const Vertex &vertex) -> bool { + return std::holds_alternative(vertex); + }; - /// Get the end iterator for the equations of the graph. - EquationIterator getEquationsEndIt() const - { - auto filter = [](const Vertex& vertex) -> bool { - return std::holds_alternative(vertex); - }; + return graph.verticesBegin(filter); + } - return graph.verticesEnd(filter); - } + /// Get the end iterator for the equations of the graph. + EquationIterator getEquationsEndIt() const { + auto filter = [](const Vertex &vertex) -> bool { + return std::holds_alternative(vertex); + }; - /// Check if all the scalar variables and equations have been matched. - bool allNodesMatched() const - { - auto allComponentsMatchedFn = [](const auto& obj) { - return obj.allComponentsMatched(); - }; - - return mlir::succeeded(mlir::failableParallelForEach( - getContext(), graph.verticesBegin(), graph.verticesEnd(), - [&](VertexDescriptor vertex) -> mlir::LogicalResult { - return mlir::LogicalResult::success( - std::visit(allComponentsMatchedFn, graph[vertex])); - })); - } + return graph.verticesEnd(filter); + } - size_t getVertexVisibilityDegree(VertexDescriptor vertex) const - { - auto edges = llvm::make_range(visibleEdgesBegin(vertex), visibleEdgesEnd(vertex)); - return std::distance(edges.begin(), edges.end()); - } + /// Check if all the scalar variables and equations have been matched. + bool allNodesMatched() const { + auto allComponentsMatchedFn = [](const auto &obj) { + return obj.allComponentsMatched(); + }; - void remove(VertexDescriptor vertex) - { - std::visit( - [](auto& obj) -> void { - obj.setVisibility(false); - }, graph[vertex]); - } + return mlir::succeeded(mlir::failableParallelForEach( + getContext(), graph.verticesBegin(), graph.verticesEnd(), + [&](VertexDescriptor vertex) -> mlir::LogicalResult { + return mlir::LogicalResult::success( + std::visit(allComponentsMatchedFn, graph[vertex])); + })); + } - // Warning: highly inefficient, use for testing purposes only. - template - std::pair findEdge(typename From::Id from, typename To::Id to) const - { - auto edges = llvm::make_range(graph.edgesBegin(), graph.edgesEnd()); + size_t getVertexVisibilityDegree(VertexDescriptor vertex) const { + auto edges = + llvm::make_range(visibleEdgesBegin(vertex), visibleEdgesEnd(vertex)); + return std::distance(edges.begin(), edges.end()); + } - EdgeIterator it = std::find_if( - edges.begin(), edges.end(), [&](const EdgeDescriptor& e) { - const Vertex& source = graph[e.from]; - const Vertex& target = graph[e.to]; + void remove(VertexDescriptor vertex) { + std::visit([](auto &obj) -> void { obj.setVisibility(false); }, + graph[vertex]); + } - if (!std::holds_alternative(source) || !std::holds_alternative(target)) { - return false; - } + // Warning: highly inefficient, use for testing purposes only. 
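+  // It linearly scans every edge of the graph.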
+  template
+  std::pair findEdge(typename From::Id from,
+                     typename To::Id to) const {
+    auto edges = llvm::make_range(graph.edgesBegin(), graph.edgesEnd());
 
-      EdgeIterator it = std::find_if(
-          edges.begin(), edges.end(), [&](const EdgeDescriptor& e) {
-            const Vertex& source = graph[e.from];
-            const Vertex& target = graph[e.to];
+    EdgeIterator it =
+        std::find_if(edges.begin(), edges.end(), [&](const EdgeDescriptor &e) {
+          const Vertex &source = graph[e.from];
+          const Vertex &target = graph[e.to];
 
-            if (!std::holds_alternative(source) || !std::holds_alternative(target)) {
-              return false;
-            }
+          if (!std::holds_alternative(source) ||
+              !std::holds_alternative(target)) {
+            return false;
+          }
 
-            return std::get(source).getId() == from && std::get(target).getId() == to;
-          });
+          return std::get(source).getId() == from &&
+                 std::get(target).getId() == to;
+        });
 
-      return std::make_pair(it != edges.end(), it);
-    }
+    return std::make_pair(it != edges.end(), it);
+  }
 
-    auto edgesBegin(VertexDescriptor vertex) const
-    {
-      return graph.outgoingEdgesBegin(vertex);
-    }
+  auto edgesBegin(VertexDescriptor vertex) const {
+    return graph.outgoingEdgesBegin(vertex);
+  }
 
-    auto edgesEnd(VertexDescriptor vertex) const
-    {
-      return graph.outgoingEdgesEnd(vertex);
-    }
+  auto edgesEnd(VertexDescriptor vertex) const {
+    return graph.outgoingEdgesEnd(vertex);
+  }
 
-    VisibleIncidentEdgeIterator visibleEdgesBegin(VertexDescriptor vertex) const
-    {
-      auto filter = [&](const Edge& edge) -> bool {
-        return edge.isVisible();
-      };
+  VisibleIncidentEdgeIterator visibleEdgesBegin(VertexDescriptor vertex) const {
+    auto filter = [&](const Edge &edge) -> bool { return edge.isVisible(); };
 
-      return graph.outgoingEdgesBegin(vertex, filter);
-    }
+    return graph.outgoingEdgesBegin(vertex, filter);
+  }
 
-    VisibleIncidentEdgeIterator visibleEdgesEnd(VertexDescriptor vertex) const
-    {
-      auto filter = [&](const Edge& edge) -> bool {
-        return edge.isVisible();
-      };
+  VisibleIncidentEdgeIterator visibleEdgesEnd(VertexDescriptor vertex) const {
+    auto filter = [&](const Edge &edge) -> bool { return edge.isVisible(); };
 
-      return graph.outgoingEdgesEnd(vertex, filter);
-    }
+    return graph.outgoingEdgesEnd(vertex, filter);
+  }
 
-    EdgeDescriptor getFirstOutVisibleEdge(VertexDescriptor vertex) const
-    {
-      auto edges = llvm::make_range(visibleEdgesBegin(vertex), visibleEdgesEnd(vertex));
-      assert(edges.begin() != edges.end() && "Vertex doesn't belong to any edge");
-      return *edges.begin();
-    }
+  EdgeDescriptor getFirstOutVisibleEdge(VertexDescriptor vertex) const {
+    auto edges =
+        llvm::make_range(visibleEdgesBegin(vertex), visibleEdgesEnd(vertex));
+    assert(edges.begin() != edges.end() && "Vertex doesn't belong to any edge");
+    return *edges.begin();
+  }
 
-    void remove(EdgeDescriptor edge)
-    {
-      graph[edge].setVisibility(false);
-    }
+  void remove(EdgeDescriptor edge) { graph[edge].setVisibility(false); }
 
-    /// Collect the list of vertices with exactly one incident edge.
-    /// The function returns 'false' if there exist a node with no incident
-    /// edges (which would make the matching process to fail in aby case).
+  /// Collect the list of vertices with exactly one incident edge.
+  /// The function returns 'false' if there exists a node with no incident
+  /// edges (which would make the matching process fail in any case).
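+  /// For example, a variable appearing in a single equation can be matched
+  /// only to that equation, so its vertex constitutes an obligatory match.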
+ bool collectSimplifiableNodes(std::list &nodes) const { + std::mutex resultMutex; - if (incidentEdges == 1) { - std::lock_guard resultLockGuard(resultMutex); - nodes.push_back(vertex); - } + auto collectFn = [&](VertexDescriptor vertex) -> mlir::LogicalResult { + size_t incidentEdges = getVertexVisibilityDegree(vertex); - return mlir::success(); - }; + if (incidentEdges == 0) { + return mlir::failure(); + } - return mlir::succeeded(mlir::failableParallelForEach( - getContext(), - graph.verticesBegin(), graph.verticesEnd(), - collectFn)); + if (incidentEdges == 1) { + std::lock_guard resultLockGuard(resultMutex); + nodes.push_back(vertex); } - bool matchIteration() - { - llvm::SmallVector augmentingPaths; - getAugmentingPaths(augmentingPaths); + return mlir::success(); + }; - if (augmentingPaths.empty()) { - return false; - } + return mlir::succeeded(mlir::failableParallelForEach( + getContext(), graph.verticesBegin(), graph.verticesEnd(), collectFn)); + } - for (auto& path : augmentingPaths) { - applyPath(path); - } + bool matchIteration() { + llvm::SmallVector augmentingPaths; + getAugmentingPaths(augmentingPaths); - return true; - } + if (augmentingPaths.empty()) { + return false; + } - void getAugmentingPaths(llvm::SmallVectorImpl& paths) const - { - auto sortHeuristic = [](const BFSStep& first, const BFSStep& second) { - return first.getCandidates().flatSize() > second.getCandidates().flatSize(); - }; + for (auto &path : augmentingPaths) { + applyPath(path); + } - Frontier frontier; + return true; + } - // Calculation of the initial frontier - auto equations = llvm::make_range(getEquationsBeginIt(), getEquationsEndIt()); + void getAugmentingPaths(llvm::SmallVectorImpl &paths) const { + auto sortHeuristic = [](const BFSStep &first, const BFSStep &second) { + return first.getCandidates().flatSize() > + second.getCandidates().flatSize(); + }; - for (VertexDescriptor equationDescriptor : equations) { - const Equation& equation = getEquationFromDescriptor(equationDescriptor); + Frontier frontier; - if (auto unmatchedEquations = equation.getUnmatched(); !unmatchedEquations.empty()) { - frontier.emplace(BFSStep(graph, equationDescriptor, std::move(unmatchedEquations))); - } - } + // Calculation of the initial frontier + auto equations = + llvm::make_range(getEquationsBeginIt(), getEquationsEndIt()); - llvm::sort(frontier, sortHeuristic); - - // Breadth-first search - Frontier newFrontier; - Frontier foundPaths; - - while (!frontier.empty() && foundPaths.empty()) { - for (const BFSStep& step : frontier) { - const VertexDescriptor& vertexDescriptor = step.getNode(); - - for (EdgeDescriptor edgeDescriptor : llvm::make_range(edgesBegin(vertexDescriptor), edgesEnd(vertexDescriptor))) { - assert(edgeDescriptor.from == vertexDescriptor); - VertexDescriptor nextNode = edgeDescriptor.to; - const Edge& edge = graph[edgeDescriptor]; - - if (isEquation(vertexDescriptor)) { - assert(isVariable(nextNode)); - auto unmatchedMatrix = edge.getUnmatched(); - auto filteredMatrix = unmatchedMatrix.filterRows(step.getCandidates()); - internal::LocalMatchingSolutions solutions = internal::solveLocalMatchingProblem(filteredMatrix); - - for (auto solution : solutions) { - Variable var = getVariableFromDescriptor(edgeDescriptor.to); - auto unmatchedScalarVariables = var.getUnmatched(); - auto matched = solution.filterColumns(unmatchedScalarVariables); - - if (!matched.empty()) { - foundPaths.emplace(graph, step, edgeDescriptor, nextNode, matched.flattenRows(), matched); - } else { - newFrontier.emplace(graph, step, 
edgeDescriptor, nextNode, solution.flattenRows(), solution); - } - } - } else { - assert(isEquation(nextNode)); - auto filteredMatrix = edge.getMatched().filterColumns(step.getCandidates()); - internal::LocalMatchingSolutions solutions = internal::solveLocalMatchingProblem(filteredMatrix); + for (VertexDescriptor equationDescriptor : equations) { + const Equation &equation = getEquationFromDescriptor(equationDescriptor); + + if (auto unmatchedEquations = equation.getUnmatched(); + !unmatchedEquations.empty()) { + frontier.emplace( + BFSStep(graph, equationDescriptor, std::move(unmatchedEquations))); + } + } - for (auto solution : solutions) { - newFrontier.emplace(graph, step, edgeDescriptor, nextNode, solution.flattenColumns(), solution); - } + llvm::sort(frontier, sortHeuristic); + + // Breadth-first search + Frontier newFrontier; + Frontier foundPaths; + + while (!frontier.empty() && foundPaths.empty()) { + for (const BFSStep &step : frontier) { + const VertexDescriptor &vertexDescriptor = step.getNode(); + + for (EdgeDescriptor edgeDescriptor : llvm::make_range( + edgesBegin(vertexDescriptor), edgesEnd(vertexDescriptor))) { + assert(edgeDescriptor.from == vertexDescriptor); + VertexDescriptor nextNode = edgeDescriptor.to; + const Edge &edge = graph[edgeDescriptor]; + + if (isEquation(vertexDescriptor)) { + assert(isVariable(nextNode)); + auto unmatchedMatrix = edge.getUnmatched(); + auto filteredMatrix = + unmatchedMatrix.filterRows(step.getCandidates()); + internal::LocalMatchingSolutions solutions = + internal::solveLocalMatchingProblem(filteredMatrix); + + for (auto solution : solutions) { + Variable var = getVariableFromDescriptor(edgeDescriptor.to); + auto unmatchedScalarVariables = var.getUnmatched(); + auto matched = solution.filterColumns(unmatchedScalarVariables); + + if (!matched.empty()) { + foundPaths.emplace(graph, step, edgeDescriptor, nextNode, + matched.flattenRows(), matched); + } else { + newFrontier.emplace(graph, step, edgeDescriptor, nextNode, + solution.flattenRows(), solution); } } + } else { + assert(isEquation(nextNode)); + auto filteredMatrix = + edge.getMatched().filterColumns(step.getCandidates()); + internal::LocalMatchingSolutions solutions = + internal::solveLocalMatchingProblem(filteredMatrix); + + for (auto solution : solutions) { + newFrontier.emplace(graph, step, edgeDescriptor, nextNode, + solution.flattenColumns(), solution); + } } - - // Set the new frontier for the next iteration - frontier.clear(); - frontier.swap(newFrontier); - - llvm::sort(frontier, sortHeuristic); } + } - llvm::sort(foundPaths, sortHeuristic); - - // For each traversed node, keep track of the indexes that have already - // been traversed by some augmenting path. A new candidate path can be - // accepted only if it does not traverse any of them. - std::map visited; - - for (const BFSStep& pathEnd : foundPaths) { - // All the candidate paths consist in at least two nodes by construction - assert(pathEnd.hasPrevious()); + // Set the new frontier for the next iteration + frontier.clear(); + frontier.swap(newFrontier); - std::list flows; + llvm::sort(frontier, sortHeuristic); + } - // The path's validity is unknown, so we must avoid polluting the - // list of visited scalar nodes. If the path will be marked as valid, - // then the new visits will be merged with the already found ones. 
-          std::map newVisits;
+    llvm::sort(foundPaths, sortHeuristic);
 
-          const BFSStep* curStep = &pathEnd;
-          MCIM map = curStep->getMappedFlow();
-          bool validPath = true;
+    // For each traversed node, keep track of the indexes that have already
+    // been traversed by some augmenting path. A new candidate path can be
+    // accepted only if it does not traverse any of them.
+    std::map visited;
 
-          while (curStep && validPath) {
-            if (curStep->hasPrevious()) {
-              if (!flows.empty()) {
-                // Restrict the flow
-                const auto& prevMap = flows.front().delta;
+    for (const BFSStep &pathEnd : foundPaths) {
+      // All the candidate paths consist of at least two nodes by construction
+      assert(pathEnd.hasPrevious());
 
-                if (isVariable(curStep->getNode())) {
-                  map = curStep->getMappedFlow().filterColumns(prevMap.flattenRows());
-                } else {
-                  map = curStep->getMappedFlow().filterRows(prevMap.flattenColumns());
-                }
-              }
+      std::list flows;
 
-              flows.emplace(flows.begin(), graph, curStep->getPrevious()->getNode(), curStep->getEdge(), map);
-            }
+      // The path's validity is unknown, so we must avoid polluting the
+      // list of visited scalar nodes. If the path is eventually marked as
+      // valid, the new visits will be merged with the already found ones.
+      std::map newVisits;
 
-            auto touchedIndexes = isVariable(curStep->getNode()) ? map.flattenRows() : map.flattenColumns();
+      const BFSStep *curStep = &pathEnd;
+      MCIM map = curStep->getMappedFlow();
+      bool validPath = true;
 
-            if (auto it = visited.find(curStep->getNode()); it != visited.end()) {
-              auto& alreadyTouchedIndices = it->second;
+      while (curStep && validPath) {
+        if (curStep->hasPrevious()) {
+          if (!flows.empty()) {
+            // Restrict the flow
+            const auto &prevMap = flows.front().delta;
 
-              if (touchedIndexes.overlaps(alreadyTouchedIndices)) {
-                // The current path intersects another one, so we need to discard it
-                validPath = false;
-              } else {
-                insertOrAdd(newVisits, curStep->getNode(), alreadyTouchedIndices + touchedIndexes);
-              }
+            if (isVariable(curStep->getNode())) {
+              map =
+                  curStep->getMappedFlow().filterColumns(prevMap.flattenRows());
             } else {
-              insertOrAdd(newVisits, curStep->getNode(), touchedIndexes);
+              map =
+                  curStep->getMappedFlow().filterRows(prevMap.flattenColumns());
             }
-
-            // Move backwards inside the candidate augmenting path
-            curStep = curStep->getPrevious();
           }
 
-            if (validPath) {
-              paths.emplace_back(std::move(flows));
-
-              for (auto& p : newVisits) {
-                visited.insert_or_assign(p.first, p.second);
-              }
-            }
+          flows.emplace(flows.begin(), graph, curStep->getPrevious()->getNode(),
+                        curStep->getEdge(), map);
         }
-      }
 
-      /// Apply an augmenting path to the graph.
-      void applyPath(const AugmentingPath& path)
-      {
-        // In order to preserve consistency of the match information among
-        // edges and nodes, we need to separately track the modifications
-        // created by the augmenting path on the vertices and apply all the
-        // removals before the additions.
-        // Consider in fact the example path [eq1 -> x -> eq2]: the first
-        // move would add some match information to eq1 and x, while the
-        // subsequent x -> eq2 would remove some from x. However, being the
-        // match matrices made of booleans, the components of x that are
-        // matched by eq1 would result as unmatched. If we instead first
-        // apply the removals, the new matches are not wrongly erased anymore.
 
-        std::map removedMatches;
-        std::map newMatches;
 
-        // Update the match matrices on the edges and store the information
-        // about the vertices to be updated later. 
-
-        for (auto& flow : path) {
-          Edge& edge = graph[flow.edge];
-
-          VertexDescriptor from = flow.source;
-          VertexDescriptor to = flow.edge.from == from ? flow.edge.to : flow.edge.from;
-
-          auto deltaEquations = flow.delta.flattenColumns();
-          auto deltaVariables = flow.delta.flattenRows();
-
-          if (isVariable(from)) {
-            // Backward node
-            insertOrAdd(removedMatches, from, deltaVariables);
-            insertOrAdd(removedMatches, to, deltaEquations);
-            edge.removeMatch(flow.delta);
+      auto touchedIndexes = isVariable(curStep->getNode())
+                                ? map.flattenRows()
+                                : map.flattenColumns();
+
+      if (auto it = visited.find(curStep->getNode()); it != visited.end()) {
+        auto &alreadyTouchedIndices = it->second;
+
+        if (touchedIndexes.overlaps(alreadyTouchedIndices)) {
+          // The current path intersects another one, so we need to discard it
+          validPath = false;
          } else {
-            // Forward node
-            insertOrAdd(newMatches, from, deltaEquations);
-            insertOrAdd(newMatches, to, deltaVariables);
-            edge.addMatch(flow.delta);
+          insertOrAdd(newVisits, curStep->getNode(),
+                      alreadyTouchedIndices + touchedIndexes);
          }
+      } else {
+        insertOrAdd(newVisits, curStep->getNode(), touchedIndexes);
        }
 
-        // Update the match information stored on the vertices
+      // Move backwards inside the candidate augmenting path
+      curStep = curStep->getPrevious();
+    }
 
-        for (const auto& match : removedMatches) {
-          std::visit(
-              [&match](auto& node) {
-                node.removeMatch(match.second);
-              }, graph[match.first]);
-        }
+    if (validPath) {
+      paths.emplace_back(std::move(flows));
 
-        for (const auto& match : newMatches) {
-          std::visit(
-              [&match](auto& node) {
-                node.addMatch(match.second);
-              }, graph[match.first]);
+      for (auto &p : newVisits) {
+        visited.insert_or_assign(p.first, p.second);
        }
      }
+    }
+  }
+
+  /// Apply an augmenting path to the graph.
+  void applyPath(const AugmentingPath &path) {
+    // In order to preserve consistency of the match information among
+    // edges and nodes, we need to separately track the modifications
+    // created by the augmenting path on the vertices and apply all the
+    // removals before the additions.
+    // Consider, for example, the path [eq1 -> x -> eq2]: the first
+    // move would add some match information to eq1 and x, while the
+    // subsequent x -> eq2 would remove some from x. However, since the
+    // match matrices are made of booleans, the components of x that are
+    // matched by eq1 would end up as unmatched. If we instead first
+    // apply the removals, the new matches are not wrongly erased anymore.
+
+    std::map removedMatches;
+    std::map newMatches;
+
+    // Update the match matrices on the edges and store the information
+    // about the vertices to be updated later.
+
+    for (auto &flow : path) {
+      Edge &edge = graph[flow.edge];
+
+      VertexDescriptor from = flow.source;
+      VertexDescriptor to =
+          flow.edge.from == from ? flow.edge.to : flow.edge.from;
+
+      auto deltaEquations = flow.delta.flattenColumns();
+      auto deltaVariables = flow.delta.flattenRows();
+
+      if (isVariable(from)) {
+        // Backward node
+        insertOrAdd(removedMatches, from, deltaVariables);
+        insertOrAdd(removedMatches, to, deltaEquations);
+        edge.removeMatch(flow.delta);
+      } else {
+        // Forward node
+        insertOrAdd(newMatches, from, deltaEquations);
+        insertOrAdd(newMatches, to, deltaVariables);
+        edge.addMatch(flow.delta);
+      }
+    }
 
-      private:
-        mlir::MLIRContext* context;
-        Graph graph;
+    // Update the match information stored on the vertices
 
-        // Maps user for faster lookups. 
-        std::map variablesMap;
-        std::map equationsMap;
+    for (const auto &match : removedMatches) {
+      std::visit([&match](auto &node) { node.removeMatch(match.second); },
+                 graph[match.first]);
+    }
 
-        // Multithreading.
-        mutable std::mutex mutex;
-  };
-}
+    for (const auto &match : newMatches) {
+      std::visit([&match](auto &node) { node.addMatch(match.second); },
+                 graph[match.first]);
+    }
+  }
+
+private:
+  mlir::MLIRContext *context;
+  Graph graph;
+
+  // Maps used for faster lookups.
+  std::map variablesMap;
+  std::map equationsMap;
+
+  // Multithreading.
+  mutable std::mutex mutex;
+};
+} // namespace marco::modeling
 
-#endif // MARCO_MODELING_MATCHING_H
+#endif // MARCO_MODELING_MATCHING_H
diff --git a/lib/Codegen/Conversion/BaseModelicaToTensor/BaseModelicaToTensor.cpp b/lib/Codegen/Conversion/BaseModelicaToTensor/BaseModelicaToTensor.cpp
index 278909d16..839042ffc 100644
--- a/lib/Codegen/Conversion/BaseModelicaToTensor/BaseModelicaToTensor.cpp
+++ b/lib/Codegen/Conversion/BaseModelicaToTensor/BaseModelicaToTensor.cpp
@@ -5,591 +5,527 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-namespace mlir
-{
+namespace mlir {
 #define GEN_PASS_DEF_BASEMODELICATOTENSORCONVERSIONPASS
 #include "marco/Codegen/Conversion/Passes.h.inc"
-}
+} // namespace mlir
 
 using namespace ::mlir::bmodelica;
 
-namespace
-{
-  class BaseModelicaToTensorConversionPass
-      : public mlir::impl::BaseModelicaToTensorConversionPassBase<
-          BaseModelicaToTensorConversionPass>
-  {
-    public:
-      using BaseModelicaToTensorConversionPassBase<
-          BaseModelicaToTensorConversionPass>
-          ::BaseModelicaToTensorConversionPassBase;
-
-      void runOnOperation() override;
-
-    private:
-      mlir::LogicalResult convertOperations();
-  };
-}
+namespace {
+class BaseModelicaToTensorConversionPass
+    : public mlir::impl::BaseModelicaToTensorConversionPassBase<
+          BaseModelicaToTensorConversionPass> {
+public:
+  using BaseModelicaToTensorConversionPassBase<
+      BaseModelicaToTensorConversionPass>::
+      BaseModelicaToTensorConversionPassBase;
+
+  void runOnOperation() override;
+
+private:
+  mlir::LogicalResult convertOperations();
+};
+} // namespace
 
-void BaseModelicaToTensorConversionPass::runOnOperation()
-{
+void BaseModelicaToTensorConversionPass::runOnOperation() {
   if (mlir::failed(convertOperations())) {
     return signalPassFailure();
   }
 }
 
-namespace
-{
-  struct TensorFromElementsOpLowering
-      : public mlir::OpConversionPattern
-  {
-    using mlir::OpConversionPattern
-        ::OpConversionPattern;
-
-    mlir::LogicalResult matchAndRewrite(
-        TensorFromElementsOp op,
-        OpAdaptor adaptor,
-        mlir::ConversionPatternRewriter& rewriter) const override
-    {
-      auto resultType = getTypeConverter()->convertType(
-          op.getResult().getType());
-
-      auto resultTensorType = resultType.cast();
-      auto resultElementType = resultTensorType.getElementType();
-
-      llvm::SmallVector operands;
-
-      for (mlir::Value operand : adaptor.getValues()) {
-        if (operand.getType() != resultElementType) {
-          operand = rewriter.create(
-              op.getLoc(), resultElementType, operand);
-        }
+namespace {
+struct TensorFromElementsOpLowering
+    : public mlir::OpConversionPattern {
+  using mlir::OpConversionPattern::OpConversionPattern;
 
-        operands.push_back(operand);
-      }
+  mlir::LogicalResult
+  matchAndRewrite(TensorFromElementsOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto resultType = getTypeConverter()->convertType(op.getResult().getType());
 
-      rewriter.replaceOpWithNewOp(
-          op, resultTensorType, operands);
+    auto resultTensorType = 
resultType.cast(); + auto resultElementType = resultTensorType.getElementType(); - return mlir::success(); + llvm::SmallVector operands; + + for (mlir::Value operand : adaptor.getValues()) { + if (operand.getType() != resultElementType) { + operand = + rewriter.create(op.getLoc(), resultElementType, operand); + } + + operands.push_back(operand); } - }; - struct TensorBroadcastOpLowering - : public mlir::OpConversionPattern - { - using mlir::OpConversionPattern - ::OpConversionPattern; + rewriter.replaceOpWithNewOp( + op, resultTensorType, operands); - mlir::LogicalResult matchAndRewrite( - TensorBroadcastOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override - { - auto resultType = getTypeConverter()->convertType( - op.getResult().getType()); + return mlir::success(); + } +}; - auto resultTensorType = resultType.cast(); - auto resultElementType = resultTensorType.getElementType(); +struct TensorBroadcastOpLowering + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; - mlir::Value operand = adaptor.getValue(); + mlir::LogicalResult + matchAndRewrite(TensorBroadcastOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto resultType = getTypeConverter()->convertType(op.getResult().getType()); - if (operand.getType() != resultElementType) { - operand = rewriter.create( - op.getLoc(), resultElementType, operand); - } + auto resultTensorType = resultType.cast(); + auto resultElementType = resultTensorType.getElementType(); - rewriter.replaceOpWithNewOp( - op, resultTensorType, operand); + mlir::Value operand = adaptor.getValue(); - return mlir::success(); + if (operand.getType() != resultElementType) { + operand = + rewriter.create(op.getLoc(), resultElementType, operand); } - }; - - struct TensorViewOpLowering - : public mlir::OpConversionPattern - { - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite( - TensorViewOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override - { - mlir::Location loc = op.getLoc(); - - llvm::SmallVector offsets; - llvm::SmallVector sizes; - llvm::SmallVector strides; - - for (mlir::Value subscript : adaptor.getSubscriptions()) { - if (subscript.getType().isa()) { - mlir::Value begin = rewriter.create(loc, subscript); - mlir::Value size = rewriter.create(loc, subscript); - mlir::Value step = rewriter.create(loc, subscript); - - if (!begin.getType().isa()) { - begin = rewriter.create( - begin.getLoc(), rewriter.getIndexType(), begin); - } - - if (!size.getType().isa()) { - size = rewriter.create( - size.getLoc(), rewriter.getIndexType(), size); - } - - if (!step.getType().isa()) { - step = rewriter.create( - step.getLoc(), rewriter.getIndexType(), step); - } - - offsets.push_back(begin); - sizes.push_back(size); - strides.push_back(step); - } else { - offsets.push_back(subscript); - sizes.push_back(rewriter.getI64IntegerAttr(1)); - strides.push_back(rewriter.getI64IntegerAttr(1)); - } - } - auto numOfSubscripts = static_cast( - adaptor.getSubscriptions().size()); + rewriter.replaceOpWithNewOp(op, resultTensorType, + operand); - auto sourceTensorType = - adaptor.getSource().getType().cast(); + return mlir::success(); + } +}; - int64_t sourceRank = sourceTensorType.getRank(); +struct TensorViewOpLowering : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; - for (int64_t i = numOfSubscripts; i < sourceRank; ++i) { - 
offsets.push_back(rewriter.getI64IntegerAttr(0)); - int64_t sourceDimension = sourceTensorType.getDimSize(i); + mlir::LogicalResult + matchAndRewrite(TensorViewOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); - if (sourceDimension == mlir::ShapedType::kDynamic) { - mlir::Value dimensionSize = rewriter.create( - loc, adaptor.getSource(), i); + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; - sizes.push_back(dimensionSize); - } else { - sizes.push_back(rewriter.getI64IntegerAttr(sourceDimension)); + for (mlir::Value subscript : adaptor.getSubscriptions()) { + if (subscript.getType().isa()) { + mlir::Value begin = rewriter.create(loc, subscript); + mlir::Value size = rewriter.create(loc, subscript); + mlir::Value step = rewriter.create(loc, subscript); + + if (!begin.getType().isa()) { + begin = rewriter.create(begin.getLoc(), + rewriter.getIndexType(), begin); } + if (!size.getType().isa()) { + size = rewriter.create(size.getLoc(), rewriter.getIndexType(), + size); + } + + if (!step.getType().isa()) { + step = rewriter.create(step.getLoc(), rewriter.getIndexType(), + step); + } + + offsets.push_back(begin); + sizes.push_back(size); + strides.push_back(step); + } else { + offsets.push_back(subscript); + sizes.push_back(rewriter.getI64IntegerAttr(1)); strides.push_back(rewriter.getI64IntegerAttr(1)); } + } + + auto numOfSubscripts = + static_cast(adaptor.getSubscriptions().size()); - mlir::Type requestedResultType = - getTypeConverter()->convertType(op.getResult().getType()); + auto sourceTensorType = + adaptor.getSource().getType().cast(); - auto requestedResultTensorType = - requestedResultType.cast(); + int64_t sourceRank = sourceTensorType.getRank(); - auto resultType = mlir::RankedTensorType::get( - requestedResultTensorType.getShape(), - sourceTensorType.getElementType()); + for (int64_t i = numOfSubscripts; i < sourceRank; ++i) { + offsets.push_back(rewriter.getI64IntegerAttr(0)); + int64_t sourceDimension = sourceTensorType.getDimSize(i); - mlir::Value result = rewriter.create( - loc, resultType, adaptor.getSource(), offsets, sizes, strides); + if (sourceDimension == mlir::ShapedType::kDynamic) { + mlir::Value dimensionSize = + rewriter.create(loc, adaptor.getSource(), i); - if (result.getType() != requestedResultType) { - result = rewriter.create( - result.getLoc(), requestedResultType, result); + sizes.push_back(dimensionSize); + } else { + sizes.push_back(rewriter.getI64IntegerAttr(sourceDimension)); } - rewriter.replaceOp(op, result); - return mlir::success(); + strides.push_back(rewriter.getI64IntegerAttr(1)); } - }; - - struct TensorExtractOpLowering - : public mlir::OpConversionPattern - { - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite( - TensorExtractOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override - { - mlir::Location loc = op.getLoc(); - llvm::SmallVector indices; - - for (mlir::Value index : adaptor.getIndices()) { - if (!index.getType().isa()) { - index = rewriter.create( - loc, rewriter.getIndexType(), index); - } - indices.push_back(index); - } + mlir::Type requestedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + + auto requestedResultTensorType = + requestedResultType.cast(); + + auto resultType = + mlir::RankedTensorType::get(requestedResultTensorType.getShape(), + sourceTensorType.getElementType()); + + mlir::Value result = rewriter.create( + 
loc, resultType, adaptor.getSource(), offsets, sizes, strides); + + if (result.getType() != requestedResultType) { + result = rewriter.create( + result.getLoc(), requestedResultType, result); + } + + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; - mlir::Value result = rewriter.create( - loc, adaptor.getTensor(), indices); +struct TensorExtractOpLowering + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; - mlir::Type requestedResultType = getTypeConverter()->convertType( - op.getResult().getType()); + mlir::LogicalResult + matchAndRewrite(TensorExtractOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + llvm::SmallVector indices; - if (result.getType() != requestedResultType) { - result = rewriter.create(loc, requestedResultType, result); + for (mlir::Value index : adaptor.getIndices()) { + if (!index.getType().isa()) { + index = rewriter.create(loc, rewriter.getIndexType(), index); } - rewriter.replaceOp(op, result); - return mlir::success(); + indices.push_back(index); } - }; - struct TensorInsertOpLowering - : public mlir::OpConversionPattern - { - using mlir::OpConversionPattern::OpConversionPattern; + mlir::Value result = rewriter.create( + loc, adaptor.getTensor(), indices); - mlir::LogicalResult matchAndRewrite( - TensorInsertOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override - { - mlir::Location loc = op.getLoc(); + mlir::Type requestedResultType = + getTypeConverter()->convertType(op.getResult().getType()); - mlir::Value tensor = adaptor.getDestination(); - auto tensorType = tensor.getType().cast(); - mlir::Type elementType = tensorType.getElementType(); + if (result.getType() != requestedResultType) { + result = rewriter.create(loc, requestedResultType, result); + } - mlir::Value value = adaptor.getValue(); + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; - if (value.getType() != elementType) { - value = rewriter.create(loc, elementType, value); - } +struct TensorInsertOpLowering + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; - llvm::SmallVector indices; + mlir::LogicalResult + matchAndRewrite(TensorInsertOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); - for (mlir::Value index : adaptor.getIndices()) { - if (!index.getType().isa()) { - index = rewriter.create( - loc, rewriter.getIndexType(), index); - } + mlir::Value tensor = adaptor.getDestination(); + auto tensorType = tensor.getType().cast(); + mlir::Type elementType = tensorType.getElementType(); - indices.push_back(index); - } + mlir::Value value = adaptor.getValue(); - mlir::Value result = rewriter.create( - loc, value, tensor, indices); + if (value.getType() != elementType) { + value = rewriter.create(loc, elementType, value); + } - mlir::Type requestedResultType = getTypeConverter()->convertType( - op.getResult().getType()); + llvm::SmallVector indices; - if (result.getType() != requestedResultType) { - result = rewriter.create(loc, requestedResultType, result); + for (mlir::Value index : adaptor.getIndices()) { + if (!index.getType().isa()) { + index = rewriter.create(loc, rewriter.getIndexType(), index); } - rewriter.replaceOp(op, result); - return mlir::success(); + indices.push_back(index); } - }; - - struct TensorInsertSliceOpLowering - : public mlir::OpConversionPattern - { - using 
mlir::OpConversionPattern - ::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite( - TensorInsertSliceOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override - { - mlir::Location loc = op.getLoc(); - - mlir::Value source = adaptor.getValue(); - auto sourceTensorType = source.getType().cast(); - - mlir::Value destination = adaptor.getDestination(); - - auto destinationTensorType = - destination.getType().cast(); - - auto subscriptions = adaptor.getSubscriptions(); - size_t numOfSubscriptions = subscriptions.size(); - size_t subscriptionsIndex = 0; - - int64_t sourceDimension = 0; - int64_t destinationRank = destinationTensorType.getRank(); - - llvm::SmallVector offsets; - llvm::SmallVector sizes; - llvm::SmallVector strides; - - // Utility function to get a known dimension or create the operation to - // obtain it at runtime. - auto getDimSizeFn = - [&](mlir::Value tensor, int64_t dim) -> mlir::OpFoldResult { - auto tensorType = tensor.getType().cast(); - assert(dim < tensorType.getRank()); - - if (auto dimSize = tensorType.getDimSize(dim); - dimSize != mlir::ShapedType::kDynamic) { - return rewriter.getI64IntegerAttr(dimSize); - } - - auto dimOp = rewriter.create( - loc, tensor, dim); - - return dimOp.getResult(); - }; - - // The source may have a rank smaller than the destination, so we iterate - // on the destination rank. - for (int64_t destinationDim = 0; destinationDim < destinationRank; - ++destinationDim) { - if (subscriptionsIndex < numOfSubscriptions) { - mlir::Value subscription = subscriptions[subscriptionsIndex]; - - if (subscription.getType().isa()) { - // The offset is either the begin or the end of the range, - // depending on the step value. - // The size is given by the source dimension size. - // The stride is given by the step. - assert(sourceDimension < sourceTensorType.getRank()); - - mlir::Value beginValue = - rewriter.create(loc, subscription); - - mlir::Value endValue = - rewriter.create(loc, subscription); - - mlir::Value step = rewriter.create(loc, subscription); - - mlir::Value zero = rewriter.create( - loc, rewriter.getIndexAttr(0)); - - mlir::Value nonNegative = rewriter.create( - loc, mlir::arith::CmpIPredicate::sge, step, zero); - - mlir::Value offset = rewriter.create( - loc, nonNegative, beginValue, endValue); - - offsets.push_back(offset); - sizes.push_back(getDimSizeFn(source, sourceDimension++)); - strides.push_back(step); - } else { - // Use the subscription for reducing the rank of the destination - // and add additional unitary dimensions to the source. - offsets.push_back(subscription); - sizes.push_back(rewriter.getI64IntegerAttr(1)); - strides.push_back(rewriter.getI64IntegerAttr(1)); - } - - ++subscriptionsIndex; - } else { - // No more subscriptions available. - // The remaining dimensions are copied from the source into the - // destination. 
-          assert(sourceDimension < sourceTensorType.getRank());
-          offsets.push_back(rewriter.getI64IntegerAttr(0));
+    mlir::Value result =
+        rewriter.create<mlir::tensor::InsertOp>(loc, value, tensor, indices);
+
+    mlir::Type requestedResultType =
+        getTypeConverter()->convertType(op.getResult().getType());
+
+    if (result.getType() != requestedResultType) {
+      result = rewriter.create<CastOp>(loc, requestedResultType, result);
+    }
+
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+};
+
+struct TensorInsertSliceOpLowering
+    : public mlir::OpConversionPattern<TensorInsertSliceOp> {
+  using mlir::OpConversionPattern<TensorInsertSliceOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(TensorInsertSliceOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
+
+    mlir::Value source = adaptor.getValue();
+
+    mlir::Value destination = adaptor.getDestination();
+
+    auto destinationTensorType = destination.getType().cast<mlir::TensorType>();
+
+    auto subscriptions = adaptor.getSubscriptions();
+    size_t numOfSubscriptions = subscriptions.size();
+    size_t subscriptionsIndex = 0;
+
+    int64_t sourceDimension = 0;
+    int64_t destinationRank = destinationTensorType.getRank();
+
+    llvm::SmallVector<mlir::OpFoldResult> offsets;
+    llvm::SmallVector<mlir::OpFoldResult> sizes;
+    llvm::SmallVector<mlir::OpFoldResult> strides;
+
+    // Utility function to get a known dimension or create the operation to
+    // obtain it at runtime.
+    auto getDimSizeFn = [&](mlir::Value tensor,
+                            int64_t dim) -> mlir::OpFoldResult {
+      auto tensorType = tensor.getType().cast<mlir::TensorType>();
+      assert(dim < tensorType.getRank());
+
+      if (auto dimSize = tensorType.getDimSize(dim);
+          dimSize != mlir::ShapedType::kDynamic) {
+        return rewriter.getI64IntegerAttr(dimSize);
+      }
+
+      auto dimOp = rewriter.create<mlir::tensor::DimOp>(loc, tensor, dim);
+
+      return dimOp.getResult();
+    };
+
+    // The source may have a rank smaller than the destination, so we iterate
+    // on the destination rank.
+    for (int64_t destinationDim = 0; destinationDim < destinationRank;
+         ++destinationDim) {
+      if (subscriptionsIndex < numOfSubscriptions) {
+        mlir::Value subscription = subscriptions[subscriptionsIndex];
+
+        if (subscription.getType().isa<RangeType>()) {
+          // The offset is either the begin or the end of the range,
+          // depending on the step value.
+          // The size is given by the source dimension size.
+          // The stride is given by the step.
+          assert(sourceDimension <
+                 source.getType().cast<mlir::TensorType>().getRank());
+
+          mlir::Value beginValue =
+              rewriter.create<RangeBeginOp>(loc, subscription);
+
+          mlir::Value endValue = rewriter.create<RangeEndOp>(loc, subscription);
+
+          mlir::Value step = rewriter.create<RangeStepOp>(loc, subscription);
+
+          mlir::Value zero = rewriter.create<mlir::arith::ConstantOp>(
+              loc, rewriter.getIndexAttr(0));
+
+          mlir::Value nonNegative = rewriter.create<mlir::arith::CmpIOp>(
+              loc, mlir::arith::CmpIPredicate::sge, step, zero);
+
+          mlir::Value offset = rewriter.create<mlir::arith::SelectOp>(
+              loc, nonNegative, beginValue, endValue);
+
+          offsets.push_back(offset);
           sizes.push_back(getDimSizeFn(source, sourceDimension++));
+          strides.push_back(step);
+        } else {
+          // Use the subscription for reducing the rank of the destination
+          // and add additional unitary dimensions to the source.
+          offsets.push_back(subscription);
+          sizes.push_back(rewriter.getI64IntegerAttr(1));
           strides.push_back(rewriter.getI64IntegerAttr(1));
         }
-      }
+        ++subscriptionsIndex;
+      } else {
+        // No more subscriptions available.
+        // The remaining dimensions are copied from the source into the
+        // destination.
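+        // For example (shapes chosen purely for illustration): when a rank-1
+        // source is inserted into a rank-2 destination with a single scalar
+        // subscript, the last destination dimension is handled here and is
+        // copied whole, i.e. with offset 0, the (possibly dynamic) source
+        // dimension size obtained through getDimSizeFn, and stride 1.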
+ assert(sourceDimension < + source.getType().cast().getRank()); - return mlir::success(); + offsets.push_back(rewriter.getI64IntegerAttr(0)); + sizes.push_back(getDimSizeFn(source, sourceDimension++)); + strides.push_back(rewriter.getI64IntegerAttr(1)); + } } - }; - struct NDimsOpLowering : public mlir::OpConversionPattern - { - using mlir::OpConversionPattern::OpConversionPattern; + rewriter.replaceOpWithNewOp( + op, source, destination, offsets, sizes, strides); - mlir::LogicalResult matchAndRewrite( - NDimsOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override - { - mlir::Location loc = op.getLoc(); - auto tensorType = adaptor.getArray().getType().cast(); + return mlir::success(); + } +}; - mlir::Value result = rewriter.create( - loc, rewriter.getIndexAttr(tensorType.getRank())); +struct NDimsOpLowering : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; - mlir::Type requestedResultType = - getTypeConverter()->convertType(op.getResult().getType()); + mlir::LogicalResult + matchAndRewrite(NDimsOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + auto tensorType = adaptor.getArray().getType().cast(); - if (result.getType() != requestedResultType) { - result = rewriter.create(loc, requestedResultType, result); - } + mlir::Value result = rewriter.create( + loc, rewriter.getIndexAttr(tensorType.getRank())); + + mlir::Type requestedResultType = + getTypeConverter()->convertType(op.getResult().getType()); - rewriter.replaceOp(op, result); - return mlir::success(); + if (result.getType() != requestedResultType) { + result = rewriter.create(loc, requestedResultType, result); } - }; - - struct SizeOpDimensionLowering : public mlir::OpConversionPattern - { - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite( - SizeOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override - { - mlir::Location loc = op.getLoc(); - mlir::Value tensor = op.getArray(); - - if (!op.hasDimension()) { - return rewriter.notifyMatchFailure(op, "No index specified"); - } - mlir::Value index = op.getDimension(); + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; - if (!index.getType().isa()) { - index = rewriter.create(loc, rewriter.getIndexType(), index); - } +struct SizeOpDimensionLowering : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; - mlir::Value result = rewriter.create( - loc, tensor, index); + mlir::LogicalResult + matchAndRewrite(SizeOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Value tensor = op.getArray(); - mlir::Type requestedResultType = - getTypeConverter()->convertType(op.getResult().getType()); + if (!op.hasDimension()) { + return rewriter.notifyMatchFailure(op, "No index specified"); + } - if (result.getType() != requestedResultType) { - result = rewriter.create(loc, requestedResultType, result); - } + mlir::Value index = op.getDimension(); - rewriter.replaceOp(op, result); - return mlir::success(); + if (!index.getType().isa()) { + index = rewriter.create(loc, rewriter.getIndexType(), index); } - }; - - struct SizeOpArrayLowering : public mlir::OpConversionPattern - { - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite( - SizeOp op, - OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) 
const override - { - mlir::Location loc = op.getLoc(); - mlir::Value tensor = adaptor.getArray(); - - if (op.hasDimension()) { - return rewriter.notifyMatchFailure(op, "Index specified"); - } - mlir::Type requestedResultType = - getTypeConverter()->convertType(op.getResult().getType()); + mlir::Value result = + rewriter.create(loc, tensor, index); - auto requestedResultTensorType = - requestedResultType.cast(); + mlir::Type requestedResultType = + getTypeConverter()->convertType(op.getResult().getType()); - mlir::Type requestedResultElementType = - requestedResultTensorType.getElementType(); + if (result.getType() != requestedResultType) { + result = rewriter.create(loc, requestedResultType, result); + } - llvm::SmallVector dynamicDimensions; + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; - if (requestedResultTensorType.getDimSize(0) == - mlir::ShapedType::kDynamic) { - dynamicDimensions.push_back( - rewriter.create(loc, tensor)); - } +struct SizeOpArrayLowering : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; - mlir::Value result = rewriter.create( - loc, requestedResultTensorType, dynamicDimensions); + mlir::LogicalResult + matchAndRewrite(SizeOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Value tensor = adaptor.getArray(); - for (int64_t dim = 0, rank = requestedResultTensorType.getRank(); - dim < rank; ++dim) { - mlir::Value index = rewriter.create( - loc, rewriter.getIndexAttr(dim)); + if (op.hasDimension()) { + return rewriter.notifyMatchFailure(op, "Index specified"); + } - mlir::Value size = rewriter.create( - loc, tensor, index); + mlir::Type requestedResultType = + getTypeConverter()->convertType(op.getResult().getType()); - if (size.getType() != requestedResultElementType) { - size = rewriter.create( - loc, requestedResultElementType, size); - } + auto requestedResultTensorType = + requestedResultType.cast(); - result = rewriter.create( - loc, size, result, index); - } + mlir::Type requestedResultElementType = + requestedResultTensorType.getElementType(); + + llvm::SmallVector dynamicDimensions; - rewriter.replaceOp(op, result); - return mlir::success(); + if (requestedResultTensorType.getDimSize(0) == mlir::ShapedType::kDynamic) { + dynamicDimensions.push_back( + rewriter.create(loc, tensor)); } - }; - struct FillOpLowering : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; + mlir::Value result = rewriter.create( + loc, requestedResultTensorType, dynamicDimensions); - mlir::LogicalResult matchAndRewrite( - FillOp op, - mlir::PatternRewriter& rewriter) const override - { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), op.getValue()); + for (int64_t dim = 0, rank = requestedResultTensorType.getRank(); + dim < rank; ++dim) { + mlir::Value index = rewriter.create( + loc, rewriter.getIndexAttr(dim)); - return mlir::success(); + mlir::Value size = + rewriter.create(loc, tensor, index); + + if (size.getType() != requestedResultElementType) { + size = rewriter.create(loc, requestedResultElementType, size); + } + + result = + rewriter.create(loc, size, result, index); } - }; -} -mlir::LogicalResult BaseModelicaToTensorConversionPass::convertOperations() -{ + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; + +struct FillOpLowering : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(FillOp op, 
mlir::PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + op.getValue()); + + return mlir::success(); + } +}; +} // namespace + +mlir::LogicalResult BaseModelicaToTensorConversionPass::convertOperations() { auto module = getOperation(); mlir::ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalDialect(); - target.addIllegalOp< - TensorFromElementsOp, - TensorBroadcastOp, - TensorViewOp, - TensorExtractOp, - TensorInsertOp, - TensorInsertSliceOp, - NDimsOp, - SizeOp, - FillOp>(); - - target.markUnknownOpDynamicallyLegal([](mlir::Operation* op) { - return true; - }); + target.addIllegalOp(); + + target.markUnknownOpDynamicallyLegal( + [](mlir::Operation *op) { return true; }); mlir::RewritePatternSet patterns(&getContext()); TypeConverter typeConverter; - populateBaseModelicaToTensorConversionPatterns( - patterns, &getContext(), typeConverter); + populateBaseModelicaToTensorConversionPatterns(patterns, &getContext(), + typeConverter); return applyPartialConversion(module, target, std::move(patterns)); } -namespace mlir -{ - void populateBaseModelicaToTensorConversionPatterns( - mlir::RewritePatternSet& patterns, - mlir::MLIRContext* context, - mlir::TypeConverter& typeConverter) - { - patterns.insert< - TensorFromElementsOpLowering, - TensorBroadcastOpLowering, - TensorViewOpLowering, - TensorExtractOpLowering, - TensorInsertOpLowering, - TensorInsertSliceOpLowering, - NDimsOpLowering, - SizeOpDimensionLowering, - SizeOpArrayLowering>(typeConverter, context); - - patterns.insert(context); - } +namespace mlir { +void populateBaseModelicaToTensorConversionPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter) { + patterns + .insert( + typeConverter, context); + + patterns.insert(context); +} - std::unique_ptr createBaseModelicaToTensorConversionPass() - { - return std::make_unique(); - } +std::unique_ptr createBaseModelicaToTensorConversionPass() { + return std::make_unique(); } +} // namespace mlir diff --git a/lib/Dialect/BaseModelica/IR/Ops.cpp b/lib/Dialect/BaseModelica/IR/Ops.cpp index f27e4147c..68ef777b0 100644 --- a/lib/Dialect/BaseModelica/IR/Ops.cpp +++ b/lib/Dialect/BaseModelica/IR/Ops.cpp @@ -17,26 +17,21 @@ using namespace ::mlir::bmodelica; // BaseModelica Dialect //===---------------------------------------------------------------------===// -namespace mlir::bmodelica -{ - void BaseModelicaDialect::registerOperations() - { - addOperations< +namespace mlir::bmodelica { +void BaseModelicaDialect::registerOperations() { + addOperations< #define GET_OP_LIST #include "marco/Dialect/BaseModelica/IR/BaseModelicaOps.cpp.inc" - >(); - } + >(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // BaseModelica Operations //===---------------------------------------------------------------------===// -static bool parseWrittenVars( - mlir::OpAsmParser& parser, VariablesList& prop) -{ - if (parser.parseKeyword("writtenVariables") || - parser.parseEqual()) { +static bool parseWrittenVars(mlir::OpAsmParser &parser, VariablesList &prop) { + if (parser.parseKeyword("writtenVariables") || parser.parseEqual()) { return true; } @@ -47,20 +42,14 @@ static bool parseWrittenVars( return false; } -static void printWrittenVars( - mlir::OpAsmPrinter& printer, - mlir::Operation* op, - const VariablesList& prop) -{ +static void printWrittenVars(mlir::OpAsmPrinter &printer, mlir::Operation *op, + const 
VariablesList &prop) { printer << "writtenVariables = "; print(printer, prop); } -static bool parseReadVars( - mlir::OpAsmParser& parser, VariablesList& prop) -{ - if (parser.parseKeyword("readVariables") || - parser.parseEqual()) { +static bool parseReadVars(mlir::OpAsmParser &parser, VariablesList &prop) { + if (parser.parseKeyword("readVariables") || parser.parseEqual()) { return true; } @@ -71,18 +60,14 @@ static bool parseReadVars( return false; } -static void printReadVars( - mlir::OpAsmPrinter& printer, - mlir::Operation* op, - const VariablesList& prop) -{ +static void printReadVars(mlir::OpAsmPrinter &printer, mlir::Operation *op, + const VariablesList &prop) { printer << "readVariables = "; print(printer, prop); } -static bool parseModelDerivativesMap( - mlir::OpAsmParser& parser, DerivativesMap& prop) -{ +static bool parseModelDerivativesMap(mlir::OpAsmParser &parser, + DerivativesMap &prop) { if (mlir::succeeded(parser.parseOptionalKeyword("der"))) { if (parser.parseEqual()) { return true; @@ -96,111 +81,92 @@ static bool parseModelDerivativesMap( return false; } -static void printModelDerivativesMap( - mlir::OpAsmPrinter& printer, - mlir::Operation* op, - const DerivativesMap& prop) -{ +static void printModelDerivativesMap(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + const DerivativesMap &prop) { if (!prop.empty()) { printer << "der = "; print(printer, prop); } } -static bool parseAbstractEquationWrittenVars( - mlir::OpAsmParser& parser, VariablesList& prop) -{ +static bool parseAbstractEquationWrittenVars(mlir::OpAsmParser &parser, + VariablesList &prop) { return parseWrittenVars(parser, prop); } -static void printAbstractEquationWrittenVars( - mlir::OpAsmPrinter& printer, - mlir::Operation* op, - const VariablesList& prop) -{ +static void printAbstractEquationWrittenVars(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + const VariablesList &prop) { return printWrittenVars(printer, op, prop); } -static bool parseAbstractEquationReadVars( - mlir::OpAsmParser& parser, VariablesList& prop) -{ +static bool parseAbstractEquationReadVars(mlir::OpAsmParser &parser, + VariablesList &prop) { return parseReadVars(parser, prop); } -static void printAbstractEquationReadVars( - mlir::OpAsmPrinter& printer, - mlir::Operation* op, - const VariablesList& prop) -{ +static void printAbstractEquationReadVars(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + const VariablesList &prop) { return printReadVars(printer, op, prop); } -static bool parseScheduleBlockWrittenVars( - mlir::OpAsmParser& parser, VariablesList& prop) -{ +static bool parseScheduleBlockWrittenVars(mlir::OpAsmParser &parser, + VariablesList &prop) { return parseWrittenVars(parser, prop); } -static void printScheduleBlockWrittenVars( - mlir::OpAsmPrinter& printer, - mlir::Operation* op, - const VariablesList& prop) -{ +static void printScheduleBlockWrittenVars(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + const VariablesList &prop) { return printWrittenVars(printer, op, prop); } -static bool parseScheduleBlockReadVars( - mlir::OpAsmParser& parser, VariablesList& prop) -{ +static bool parseScheduleBlockReadVars(mlir::OpAsmParser &parser, + VariablesList &prop) { return parseReadVars(parser, prop); } -static void printScheduleBlockReadVars( - mlir::OpAsmPrinter& printer, - mlir::Operation* op, - const VariablesList& prop) -{ +static void printScheduleBlockReadVars(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + const VariablesList &prop) { return printReadVars(printer, op, prop); } #define 
GET_OP_CLASSES #include "marco/Dialect/BaseModelica/IR/BaseModelicaOps.cpp.inc" -namespace -{ - template - std::optional getScalarAttributeValue(mlir::Attribute attribute) - { - if (isScalarIntegerLike(attribute)) { - return static_cast(getScalarIntegerLikeValue(attribute)); - } else if (isScalarFloatLike(attribute)) { - return static_cast(getScalarFloatLikeValue(attribute)); - } else { - return std::nullopt; - } +namespace { +template +std::optional getScalarAttributeValue(mlir::Attribute attribute) { + if (isScalarIntegerLike(attribute)) { + return static_cast(getScalarIntegerLikeValue(attribute)); + } else if (isScalarFloatLike(attribute)) { + return static_cast(getScalarFloatLikeValue(attribute)); + } else { + return std::nullopt; } +} - template - bool getScalarAttributesValues( - llvm::ArrayRef attributes, - llvm::SmallVectorImpl& result) - { - for (mlir::Attribute attribute : attributes) { - if (auto value = getScalarAttributeValue(attribute)) { - result.push_back(*value); - } else { - return false; - } +template +bool getScalarAttributesValues(llvm::ArrayRef attributes, + llvm::SmallVectorImpl &result) { + for (mlir::Attribute attribute : attributes) { + if (auto value = getScalarAttributeValue(attribute)) { + result.push_back(*value); + } else { + return false; } - - return true; } + + return true; } +} // namespace -static mlir::LogicalResult cleanEquationTemplates( - mlir::RewriterBase& rewriter, - llvm::ArrayRef templateOps) -{ +static mlir::LogicalResult +cleanEquationTemplates(mlir::RewriterBase &rewriter, + llvm::ArrayRef templateOps) { mlir::RewritePatternSet patterns(rewriter.getContext()); for (mlir::RegisteredOperationName registeredOp : @@ -211,7 +177,7 @@ static mlir::LogicalResult cleanEquationTemplates( mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns)); mlir::GreedyRewriteConfig config; - mlir::OpBuilder::Listener* listener = rewriter.getListener(); + mlir::OpBuilder::Listener *listener = rewriter.getListener(); mlir::RewriterBase::ForwardingListener forwardingListener(listener); if (listener != nullptr) { @@ -235,252 +201,228 @@ static mlir::LogicalResult cleanEquationTemplates( //===---------------------------------------------------------------------===// // RangeOp -namespace mlir::bmodelica -{ - mlir::LogicalResult RangeOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type lowerBoundType = adaptor.getLowerBound().getType(); - mlir::Type upperBoundType = adaptor.getUpperBound().getType(); - mlir::Type stepType = adaptor.getStep().getType(); - - if (isScalar(lowerBoundType) && - isScalar(upperBoundType) && - isScalar(stepType)) { - mlir::Type resultType = - getMostGenericScalarType(lowerBoundType, upperBoundType); - - resultType = getMostGenericScalarType(resultType, stepType); - returnTypes.push_back(RangeType::get(context, resultType)); - return mlir::success(); - } +namespace mlir::bmodelica { +mlir::LogicalResult RangeOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lowerBoundType = 
adaptor.getLowerBound().getType(); + mlir::Type upperBoundType = adaptor.getUpperBound().getType(); + mlir::Type stepType = adaptor.getStep().getType(); + + if (isScalar(lowerBoundType) && isScalar(upperBoundType) && + isScalar(stepType)) { + mlir::Type resultType = + getMostGenericScalarType(lowerBoundType, upperBoundType); + + resultType = getMostGenericScalarType(resultType, stepType); + returnTypes.push_back(RangeType::get(context, resultType)); + return mlir::success(); + } - return mlir::failure(); + return mlir::failure(); +} + +bool RangeOp::isCompatibleReturnTypes(mlir::TypeRange lhs, + mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool RangeOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } - - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } - - return true; } - mlir::OpFoldResult RangeOp::fold(FoldAdaptor adaptor) - { - auto lowerBound = adaptor.getLowerBound(); - auto upperBound = adaptor.getUpperBound(); - auto step = adaptor.getStep(); + return true; +} - if (!lowerBound || !upperBound || !step) { - return {}; - } +mlir::OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { + auto lowerBound = adaptor.getLowerBound(); + auto upperBound = adaptor.getUpperBound(); + auto step = adaptor.getStep(); - if (isScalarIntegerLike(lowerBound) && - isScalarIntegerLike(upperBound) && - isScalarIntegerLike(step)) { - int64_t lowerBoundValue = getScalarIntegerLikeValue(lowerBound); - int64_t upperBoundValue = getScalarIntegerLikeValue(upperBound); - int64_t stepValue = getScalarIntegerLikeValue(step); + if (!lowerBound || !upperBound || !step) { + return {}; + } - return IntegerRangeAttr::get( - getContext(), getResult().getType(), - lowerBoundValue, upperBoundValue, stepValue); - } + if (isScalarIntegerLike(lowerBound) && isScalarIntegerLike(upperBound) && + isScalarIntegerLike(step)) { + int64_t lowerBoundValue = getScalarIntegerLikeValue(lowerBound); + int64_t upperBoundValue = getScalarIntegerLikeValue(upperBound); + int64_t stepValue = getScalarIntegerLikeValue(step); - if ((isScalarIntegerLike(lowerBound) || isScalarFloatLike(lowerBound)) && - (isScalarIntegerLike(upperBound) || isScalarFloatLike(upperBound)) && - (isScalarIntegerLike(step) || isScalarFloatLike(step))) { - double lowerBoundValue; - double upperBoundValue; - double stepValue; + return IntegerRangeAttr::get(getContext(), getResult().getType(), + lowerBoundValue, upperBoundValue, stepValue); + } - if (isScalarIntegerLike(lowerBound)) { - lowerBoundValue = - static_cast(getScalarIntegerLikeValue(lowerBound)); - } else { - lowerBoundValue = getScalarFloatLikeValue(lowerBound); - } + if ((isScalarIntegerLike(lowerBound) || isScalarFloatLike(lowerBound)) && + (isScalarIntegerLike(upperBound) || isScalarFloatLike(upperBound)) && + (isScalarIntegerLike(step) || isScalarFloatLike(step))) { + double lowerBoundValue; + double upperBoundValue; + double stepValue; - if (isScalarIntegerLike(upperBound)) { - upperBoundValue = - static_cast(getScalarIntegerLikeValue(upperBound)); - } else { - upperBoundValue = getScalarFloatLikeValue(upperBound); - } + if (isScalarIntegerLike(lowerBound)) { + lowerBoundValue = + static_cast(getScalarIntegerLikeValue(lowerBound)); + } else { + lowerBoundValue = getScalarFloatLikeValue(lowerBound); + } - if 
(isScalarIntegerLike(step)) { - stepValue = - static_cast(getScalarIntegerLikeValue(step)); - } else { - stepValue = getScalarFloatLikeValue(step); - } + if (isScalarIntegerLike(upperBound)) { + upperBoundValue = + static_cast(getScalarIntegerLikeValue(upperBound)); + } else { + upperBoundValue = getScalarFloatLikeValue(upperBound); + } - return RealRangeAttr::get( - getContext(), getResult().getType(), - llvm::APFloat(lowerBoundValue), - llvm::APFloat(upperBoundValue), - llvm::APFloat(stepValue)); + if (isScalarIntegerLike(step)) { + stepValue = static_cast(getScalarIntegerLikeValue(step)); + } else { + stepValue = getScalarFloatLikeValue(step); } - return {}; + return RealRangeAttr::get( + getContext(), getResult().getType(), llvm::APFloat(lowerBoundValue), + llvm::APFloat(upperBoundValue), llvm::APFloat(stepValue)); } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // RangeBeginOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult RangeBeginOp::fold(FoldAdaptor adaptor) - { - auto range = adaptor.getRange(); - - if (!range) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult RangeBeginOp::fold(FoldAdaptor adaptor) { + auto range = adaptor.getRange(); - if (auto intRange = range.dyn_cast()) { - mlir::Type inductionType = - intRange.getType().cast().getInductionType(); + if (!range) { + return {}; + } - if (inductionType.isa()) { - return mlir::IntegerAttr::get( - mlir::IndexType::get(getContext()), intRange.getLowerBound()); - } else { - return IntegerAttr::get(getContext(), intRange.getLowerBound()); - } - } + if (auto intRange = range.dyn_cast()) { + mlir::Type inductionType = + intRange.getType().cast().getInductionType(); - if (auto realRange = range.dyn_cast()) { - return RealAttr::get( - getContext(), realRange.getLowerBound().convertToDouble()); + if (inductionType.isa()) { + return mlir::IntegerAttr::get(mlir::IndexType::get(getContext()), + intRange.getLowerBound()); + } else { + return IntegerAttr::get(getContext(), intRange.getLowerBound()); } + } - return {}; + if (auto realRange = range.dyn_cast()) { + return RealAttr::get(getContext(), + realRange.getLowerBound().convertToDouble()); } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // RangeEndOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult RangeEndOp::fold(FoldAdaptor adaptor) - { - auto range = adaptor.getRange(); - - if (!range) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult RangeEndOp::fold(FoldAdaptor adaptor) { + auto range = adaptor.getRange(); - if (auto intRange = range.dyn_cast()) { - mlir::Type inductionType = - intRange.getType().cast().getInductionType(); + if (!range) { + return {}; + } - if (inductionType.isa()) { - return mlir::IntegerAttr::get( - mlir::IndexType::get(getContext()), intRange.getUpperBound()); - } else { - return IntegerAttr::get(getContext(), intRange.getUpperBound()); - } - } + if (auto intRange = range.dyn_cast()) { + mlir::Type inductionType = + intRange.getType().cast().getInductionType(); - if (auto realRange = range.dyn_cast()) { - return RealAttr::get( - getContext(), realRange.getUpperBound().convertToDouble()); + if (inductionType.isa()) { + return mlir::IntegerAttr::get(mlir::IndexType::get(getContext()), + intRange.getUpperBound()); + } else { + return IntegerAttr::get(getContext(), intRange.getUpperBound()); } + } - return {}; + if (auto realRange = range.dyn_cast()) { + return 
RealAttr::get(getContext(),
+                         realRange.getUpperBound().convertToDouble());
   }
+
+  return {};
 }
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// RangeStepOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult RangeStepOp::fold(FoldAdaptor adaptor)
-  {
-    auto range = adaptor.getRange();
-
-    if (!range) {
-      return {};
-    }
+namespace mlir::bmodelica {
+mlir::OpFoldResult RangeStepOp::fold(FoldAdaptor adaptor) {
+  auto range = adaptor.getRange();

-    if (auto intRange = range.dyn_cast<IntegerRangeAttr>()) {
-      mlir::Type inductionType =
-          intRange.getType().cast<RangeType>().getInductionType();
+  if (!range) {
+    return {};
+  }

-      if (inductionType.isa<mlir::IndexType>()) {
-        return mlir::IntegerAttr::get(
-            mlir::IndexType::get(getContext()), intRange.getStep());
-      } else {
-        return IntegerAttr::get(getContext(), intRange.getStep());
-      }
-    }
+  if (auto intRange = range.dyn_cast<IntegerRangeAttr>()) {
+    mlir::Type inductionType =
+        intRange.getType().cast<RangeType>().getInductionType();

-    if (auto realRange = range.dyn_cast<RealRangeAttr>()) {
-      return RealAttr::get(
-          getContext(), realRange.getStep().convertToDouble());
+    if (inductionType.isa<mlir::IndexType>()) {
+      return mlir::IntegerAttr::get(mlir::IndexType::get(getContext()),
+                                    intRange.getStep());
+    } else {
+      return IntegerAttr::get(getContext(), intRange.getStep());
     }
+  }

-    return {};
+  if (auto realRange = range.dyn_cast<RealRangeAttr>()) {
+    return RealAttr::get(getContext(), realRange.getStep().convertToDouble());
   }
+
+  return {};
 }
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// RangeSizeOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult RangeSizeOp::fold(FoldAdaptor adaptor)
-  {
-    auto range = adaptor.getRange();
-
-    if (!range) {
-      return {};
-    }
+namespace mlir::bmodelica {
+mlir::OpFoldResult RangeSizeOp::fold(FoldAdaptor adaptor) {
+  auto range = adaptor.getRange();

-    if (auto intRange = range.dyn_cast<IntegerRangeAttr>()) {
-      int64_t beginValue = intRange.getLowerBound();
-      int64_t endValue = intRange.getUpperBound();
-      int64_t step = intRange.getStep();
-      int64_t result = 1 + (endValue - beginValue) / step;
+  if (!range) {
+    return {};
+  }

-      return mlir::IntegerAttr::get(
-          mlir::IndexType::get(getContext()), result);
-    }
+  if (auto intRange = range.dyn_cast<IntegerRangeAttr>()) {
+    int64_t beginValue = intRange.getLowerBound();
+    int64_t endValue = intRange.getUpperBound();
+    int64_t step = intRange.getStep();
+    int64_t result = 1 + (endValue - beginValue) / step;

-    if (auto realRange = range.dyn_cast<RealRangeAttr>()) {
-      double beginValue = realRange.getLowerBound().convertToDouble();
-      double endValue = realRange.getUpperBound().convertToDouble();
-      double step = realRange.getStep().convertToDouble();
-      double result = 1 + (endValue - beginValue) / step;
+    return mlir::IntegerAttr::get(mlir::IndexType::get(getContext()), result);
+  }

-      return mlir::IntegerAttr::get(
-          mlir::IndexType::get(getContext()),
-          static_cast<int64_t>(result));
-    }
+  if (auto realRange = range.dyn_cast<RealRangeAttr>()) {
+    double beginValue = realRange.getLowerBound().convertToDouble();
+    double endValue = realRange.getUpperBound().convertToDouble();
+    double step = realRange.getStep().convertToDouble();
+    double result = 1 + (endValue - beginValue) / step;

-    return {};
+    return mlir::IntegerAttr::get(mlir::IndexType::get(getContext()),
+                                  static_cast<int64_t>(result));
  }
+
+  return {};
 }
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// Tensor operations

@@ -489,361 +431,332 @@ namespace mlir::bmodelica
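The integer branch of the RangeSizeOp fold above uses the inclusive-range
count, size = 1 + (end - begin) / step. A minimal worked example of that
arithmetic, with bounds made up purely for illustration:

  int64_t begin = 2, end = 11, step = 3;
  // Elements covered: 2, 5, 8, 11 -> size = 1 + (11 - 2) / 3 = 4.
  int64_t size = 1 + (end - begin) / step;

The real-valued branch computes the same quantity in double precision and then
truncates toward zero via the static_cast<int64_t> when materializing the
index attribute.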
//===---------------------------------------------------------------------===// // TensorFromElementsOp -namespace mlir::bmodelica -{ - mlir::LogicalResult TensorFromElementsOp::verify() - { - if (!getResult().getType().hasStaticShape()) { - return emitOpError("the shape must be fixed"); - } - - int64_t tensorFlatSize = getResult().getType().getNumElements(); - size_t numOfValues = getValues().size(); +namespace mlir::bmodelica { +mlir::LogicalResult TensorFromElementsOp::verify() { + if (!getResult().getType().hasStaticShape()) { + return emitOpError("the shape must be fixed"); + } - if (tensorFlatSize != static_cast(numOfValues)) { - return emitOpError( - "incorrect number of values (expected " + - std::to_string(tensorFlatSize) + ", got " + - std::to_string(numOfValues) + ")"); - } + int64_t tensorFlatSize = getResult().getType().getNumElements(); + size_t numOfValues = getValues().size(); - return mlir::success(); + if (tensorFlatSize != static_cast(numOfValues)) { + return emitOpError("incorrect number of values (expected " + + std::to_string(tensorFlatSize) + ", got " + + std::to_string(numOfValues) + ")"); } - mlir::OpFoldResult TensorFromElementsOp::fold(FoldAdaptor adaptor) - { - if (llvm::all_of(adaptor.getOperands(), [](mlir::Attribute attr) { - return attr != nullptr; - })) { - mlir::TensorType tensorType = getResult().getType(); + return mlir::success(); +} - if (!tensorType.hasStaticShape()) { - return {}; - } +mlir::OpFoldResult TensorFromElementsOp::fold(FoldAdaptor adaptor) { + if (llvm::all_of(adaptor.getOperands(), + [](mlir::Attribute attr) { return attr != nullptr; })) { + mlir::TensorType tensorType = getResult().getType(); - mlir::Type elementType = tensorType.getElementType(); + if (!tensorType.hasStaticShape()) { + return {}; + } - if (elementType.isa()) { - llvm::SmallVector casted; + mlir::Type elementType = tensorType.getElementType(); - if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { - return {}; - } + if (elementType.isa()) { + llvm::SmallVector casted; - return DenseBooleanElementsAttr::get(tensorType, casted); + if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { + return {}; } - if (elementType.isa()) { - llvm::SmallVector casted; + return DenseBooleanElementsAttr::get(tensorType, casted); + } - if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { - return {}; - } + if (elementType.isa()) { + llvm::SmallVector casted; - return DenseIntegerElementsAttr::get(tensorType, casted); + if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { + return {}; } - if (elementType.isa()) { - llvm::SmallVector casted; + return DenseIntegerElementsAttr::get(tensorType, casted); + } - if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { - return {}; - } + if (elementType.isa()) { + llvm::SmallVector casted; - return DenseRealElementsAttr::get(tensorType, casted); + if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { + return {}; } - } - return {}; + return DenseRealElementsAttr::get(tensorType, casted); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // TensorViewOp -namespace -{ - struct InferTensorViewResultTypePattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - TensorViewOp op, mlir::PatternRewriter& rewriter) const override - { - mlir::TensorType inferredResultType = TensorViewOp::inferResultType( - op.getSource().getType(), 
op.getSubscriptions()); - - if (inferredResultType != op.getResult().getType()) { - auto newOp = rewriter.create( - op.getLoc(), inferredResultType, op.getSource(), - op.getSubscriptions()); - - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp); - return mlir::success(); - } +namespace { +struct InferTensorViewResultTypePattern + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - return mlir::failure(); - } - }; + mlir::LogicalResult + matchAndRewrite(TensorViewOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::TensorType inferredResultType = TensorViewOp::inferResultType( + op.getSource().getType(), op.getSubscriptions()); - struct MergeTensorViewsPattern : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; + if (inferredResultType != op.getResult().getType()) { + auto newOp = + rewriter.create(op.getLoc(), inferredResultType, + op.getSource(), op.getSubscriptions()); - mlir::LogicalResult matchAndRewrite( - TensorViewOp op, mlir::PatternRewriter& rewriter) const override - { - auto viewOp = op.getSource().getDefiningOp(); + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } - if (!viewOp) { - return mlir::failure(); - } + return mlir::failure(); + } +}; - llvm::SmallVector viewOps; +struct MergeTensorViewsPattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - while (viewOp) { - viewOps.push_back(viewOp); - viewOp = viewOp.getSource().getDefiningOp(); - } + mlir::LogicalResult + matchAndRewrite(TensorViewOp op, + mlir::PatternRewriter &rewriter) const override { + auto viewOp = op.getSource().getDefiningOp(); - assert(!viewOps.empty()); - mlir::Value source = viewOps.back().getSource(); - llvm::SmallVector subscriptions; + if (!viewOp) { + return mlir::failure(); + } - while (!viewOps.empty()) { - TensorViewOp current = viewOps.pop_back_val(); - subscriptions.append(current.getSubscriptions().begin(), - current.getSubscriptions().end()); - } + llvm::SmallVector viewOps; - subscriptions.append(op.getSubscriptions().begin(), - op.getSubscriptions().end()); + while (viewOp) { + viewOps.push_back(viewOp); + viewOp = viewOp.getSource().getDefiningOp(); + } - rewriter.replaceOpWithNewOp(op, source, subscriptions); - return mlir::success(); + assert(!viewOps.empty()); + mlir::Value source = viewOps.back().getSource(); + llvm::SmallVector subscriptions; + + while (!viewOps.empty()) { + TensorViewOp current = viewOps.pop_back_val(); + subscriptions.append(current.getSubscriptions().begin(), + current.getSubscriptions().end()); } - }; -} -namespace mlir::bmodelica -{ - void TensorViewOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - mlir::Value source, - mlir::ValueRange subscriptions) - { - mlir::TensorType resultType = inferResultType( - source.getType().cast(), subscriptions); + subscriptions.append(op.getSubscriptions().begin(), + op.getSubscriptions().end()); - build(builder, state, resultType, source, subscriptions); + rewriter.replaceOpWithNewOp(op, source, subscriptions); + return mlir::success(); } +}; +} // namespace - mlir::LogicalResult TensorViewOp::verify() - { - mlir::TensorType sourceType = getSource().getType(); - mlir::TensorType resultType = getResult().getType(); - - mlir::TensorType expectedResultType = - inferResultType(sourceType, getSubscriptions()); +namespace mlir::bmodelica { +void TensorViewOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value 
source, mlir::ValueRange subscriptions) { + mlir::TensorType resultType = + inferResultType(source.getType().cast(), subscriptions); - if (resultType.getRank() != expectedResultType.getRank()) { - return emitOpError() << "incompatible result rank"; - } + build(builder, state, resultType, source, subscriptions); +} - for (int64_t i = 0, e = resultType.getRank(); i < e; ++i) { - int64_t actualDimSize = resultType.getDimSize(i); - int64_t expectedDimSize = expectedResultType.getDimSize(i); +mlir::LogicalResult TensorViewOp::verify() { + mlir::TensorType sourceType = getSource().getType(); + mlir::TensorType resultType = getResult().getType(); - if (actualDimSize != mlir::ShapedType::kDynamic && - actualDimSize != expectedDimSize) { - return emitOpError() << "incompatible size for dimension " << i - << " (expected " << expectedDimSize << ", got " - << actualDimSize << ")"; - } - } + mlir::TensorType expectedResultType = + inferResultType(sourceType, getSubscriptions()); - return mlir::success(); + if (resultType.getRank() != expectedResultType.getRank()) { + return emitOpError() << "incompatible result rank"; } - void TensorViewOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add< - InferTensorViewResultTypePattern, - MergeTensorViewsPattern>(context); + for (int64_t i = 0, e = resultType.getRank(); i < e; ++i) { + int64_t actualDimSize = resultType.getDimSize(i); + int64_t expectedDimSize = expectedResultType.getDimSize(i); + + if (actualDimSize != mlir::ShapedType::kDynamic && + actualDimSize != expectedDimSize) { + return emitOpError() << "incompatible size for dimension " << i + << " (expected " << expectedDimSize << ", got " + << actualDimSize << ")"; + } } - mlir::TensorType TensorViewOp::inferResultType( - mlir::TensorType source, mlir::ValueRange indices) - { - llvm::SmallVector shape; - size_t numOfSubscriptions = indices.size(); + return mlir::success(); +} + +void TensorViewOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add( + context); +} - for (size_t i = 0; i < numOfSubscriptions; ++i) { - mlir::Value index = indices[i]; +mlir::TensorType TensorViewOp::inferResultType(mlir::TensorType source, + mlir::ValueRange indices) { + llvm::SmallVector shape; + size_t numOfSubscriptions = indices.size(); - if (index.getType().isa()) { - int64_t dimension = mlir::ShapedType::kDynamic; + for (size_t i = 0; i < numOfSubscriptions; ++i) { + mlir::Value index = indices[i]; - if (auto constantOp = index.getDefiningOp()) { - auto indexAttr = constantOp.getValue(); + if (index.getType().isa()) { + int64_t dimension = mlir::ShapedType::kDynamic; - if (auto rangeAttr = mlir::dyn_cast(indexAttr)) { - dimension = rangeAttr.getNumOfElements(); - } - } + if (auto constantOp = index.getDefiningOp()) { + auto indexAttr = constantOp.getValue(); - shape.push_back(dimension); + if (auto rangeAttr = mlir::dyn_cast(indexAttr)) { + dimension = rangeAttr.getNumOfElements(); + } } - } - for (int64_t dimension : - source.getShape().drop_front(numOfSubscriptions)) { shape.push_back(dimension); } + } - return source.clone(shape); + for (int64_t dimension : source.getShape().drop_front(numOfSubscriptions)) { + shape.push_back(dimension); } + + return source.clone(shape); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // TensorExtractOp -namespace -{ - struct MergeTensorViewIntoTensorExtractPattern - : public 
mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - TensorExtractOp op, mlir::PatternRewriter& rewriter) const override - { - auto viewOp = op.getTensor().getDefiningOp(); - - if (!viewOp) { - return mlir::failure(); - } - - llvm::SmallVector viewOps; +namespace { +struct MergeTensorViewIntoTensorExtractPattern + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - while (viewOp) { - viewOps.push_back(viewOp); - viewOp = viewOp.getSource().getDefiningOp(); - } + mlir::LogicalResult + matchAndRewrite(TensorExtractOp op, + mlir::PatternRewriter &rewriter) const override { + auto viewOp = op.getTensor().getDefiningOp(); - assert(!viewOps.empty()); - mlir::Value source = viewOps.back().getSource(); - llvm::SmallVector subscriptions; + if (!viewOp) { + return mlir::failure(); + } - while (!viewOps.empty()) { - TensorViewOp current = viewOps.pop_back_val(); - subscriptions.append(current.getSubscriptions().begin(), - current.getSubscriptions().end()); - } + llvm::SmallVector viewOps; - subscriptions.append(op.getIndices().begin(), op.getIndices().end()); - rewriter.replaceOpWithNewOp(op, source, subscriptions); - return mlir::success(); + while (viewOp) { + viewOps.push_back(viewOp); + viewOp = viewOp.getSource().getDefiningOp(); } - }; -} -namespace mlir::bmodelica -{ - mlir::ParseResult TensorExtractOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto loc = parser.getCurrentLocation(); - mlir::OpAsmParser::UnresolvedOperand array; - mlir::Type tensorType; - llvm::SmallVector indices; - llvm::SmallVector indicesTypes; - - if (parser.parseOperand(array) || - parser.parseOperandList( - indices, mlir::OpAsmParser::Delimiter::Square) || - parser.parseColonType(tensorType) || - parser.resolveOperand(array, tensorType, result.operands)) { - return mlir::failure(); + assert(!viewOps.empty()); + mlir::Value source = viewOps.back().getSource(); + llvm::SmallVector subscriptions; + + while (!viewOps.empty()) { + TensorViewOp current = viewOps.pop_back_val(); + subscriptions.append(current.getSubscriptions().begin(), + current.getSubscriptions().end()); } - indicesTypes.resize( - indices.size(), - mlir::IndexType::get(result.getContext())); + subscriptions.append(op.getIndices().begin(), op.getIndices().end()); + rewriter.replaceOpWithNewOp(op, source, subscriptions); + return mlir::success(); + } +}; +} // namespace + +namespace mlir::bmodelica { +mlir::ParseResult TensorExtractOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto loc = parser.getCurrentLocation(); + mlir::OpAsmParser::UnresolvedOperand array; + mlir::Type tensorType; + llvm::SmallVector indices; + llvm::SmallVector indicesTypes; + + if (parser.parseOperand(array) || + parser.parseOperandList(indices, mlir::OpAsmParser::Delimiter::Square) || + parser.parseColonType(tensorType) || + parser.resolveOperand(array, tensorType, result.operands)) { + return mlir::failure(); + } - size_t i = 0; + indicesTypes.resize(indices.size(), + mlir::IndexType::get(result.getContext())); - while (mlir::succeeded(parser.parseOptionalComma())) { - if (parser.parseType(indicesTypes[i++])) { - return mlir::failure(); - } - } + size_t i = 0; - if (parser.resolveOperands(indices, indicesTypes, loc, result.operands)) { + while (mlir::succeeded(parser.parseOptionalComma())) { + if (parser.parseType(indicesTypes[i++])) { return mlir::failure(); } + } - result.addTypes(tensorType.cast().getElementType()); - return 
mlir::success(); + if (parser.resolveOperands(indices, indicesTypes, loc, result.operands)) { + return mlir::failure(); } - void TensorExtractOp::print(mlir::OpAsmPrinter& printer) - { - printer << " " << getTensor() << "[" << getIndices() << "]"; - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " : " << getTensor().getType(); + result.addTypes(tensorType.cast().getElementType()); + return mlir::success(); +} + +void TensorExtractOp::print(mlir::OpAsmPrinter &printer) { + printer << " " << getTensor() << "[" << getIndices() << "]"; + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : " << getTensor().getType(); - if (!llvm::all_of(getIndices(), [](mlir::Value index) { - return index.getType().isa(); - })) { - for (mlir::Value index : getIndices()) { - printer << ", " << index.getType(); - } + if (!llvm::all_of(getIndices(), [](mlir::Value index) { + return index.getType().isa(); + })) { + for (mlir::Value index : getIndices()) { + printer << ", " << index.getType(); } } +} - mlir::LogicalResult TensorExtractOp::verify() - { - size_t indicesAmount = getIndices().size(); - int64_t rank = getTensor().getType().getRank(); +mlir::LogicalResult TensorExtractOp::verify() { + size_t indicesAmount = getIndices().size(); + int64_t rank = getTensor().getType().getRank(); - if (rank != static_cast(indicesAmount)) { - return emitOpError() - << "incorrect number of indices (expected " << rank - << ", got " << indicesAmount << ")"; - } + if (rank != static_cast(indicesAmount)) { + return emitOpError() << "incorrect number of indices (expected " << rank + << ", got " << indicesAmount << ")"; + } - for (size_t i = 0; i < indicesAmount; ++i) { - if (auto constantOp = getIndices()[i].getDefiningOp()) { - if (auto index = getScalarAttributeValue( - constantOp.getValue())) { - if (*index < 0) { - return emitOpError() << "invalid index (" << *index << ")"; - } + for (size_t i = 0; i < indicesAmount; ++i) { + if (auto constantOp = getIndices()[i].getDefiningOp()) { + if (auto index = + getScalarAttributeValue(constantOp.getValue())) { + if (*index < 0) { + return emitOpError() << "invalid index (" << *index << ")"; + } - if (int64_t dimSize = getTensor().getType().getDimSize(i); - *index >= dimSize) { - return emitOpError() - << "out of bounds access (index = " << *index - << ", dimension = " << dimSize << ")"; - } + if (int64_t dimSize = getTensor().getType().getDimSize(i); + *index >= dimSize) { + return emitOpError() << "out of bounds access (index = " << *index + << ", dimension = " << dimSize << ")"; } } } - - return mlir::success(); } - void TensorExtractOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); - } + return mlir::success(); +} + +void TensorExtractOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add(context); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // Array operations @@ -852,684 +765,593 @@ namespace mlir::bmodelica //===---------------------------------------------------------------------===// // AllocaOp -namespace mlir::bmodelica -{ - mlir::LogicalResult AllocaOp::verify() - { - int64_t dynamicDimensionsAmount = getArrayType().getNumDynamicDims(); - size_t valuesAmount = getDynamicSizes().size(); - - if (dynamicDimensionsAmount != static_cast(valuesAmount)) { - return emitOpError( - "incorrect number of values for dynamic dimensions 
(expected " + - std::to_string(dynamicDimensionsAmount) + ", got " + - std::to_string(valuesAmount) + ")"); - } +namespace mlir::bmodelica { +mlir::LogicalResult AllocaOp::verify() { + int64_t dynamicDimensionsAmount = getArrayType().getNumDynamicDims(); + size_t valuesAmount = getDynamicSizes().size(); - return mlir::success(); + if (dynamicDimensionsAmount != static_cast(valuesAmount)) { + return emitOpError( + "incorrect number of values for dynamic dimensions (expected " + + std::to_string(dynamicDimensionsAmount) + ", got " + + std::to_string(valuesAmount) + ")"); } - void AllocaOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - if (auto arrayType = getResult().getType().dyn_cast()) { - effects.emplace_back( - mlir::MemoryEffects::Allocate::get(), - getResult(), - mlir::SideEffects::AutomaticAllocationScopeResource::get()); - } + return mlir::success(); +} + +void AllocaOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + if (auto arrayType = getResult().getType().dyn_cast()) { + effects.emplace_back( + mlir::MemoryEffects::Allocate::get(), getResult(), + mlir::SideEffects::AutomaticAllocationScopeResource::get()); } } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // AllocOp -namespace mlir::bmodelica -{ - mlir::LogicalResult AllocOp::verify() - { - int64_t dynamicDimensionsAmount = getArrayType().getNumDynamicDims(); - size_t valuesAmount = getDynamicSizes().size(); - - if (dynamicDimensionsAmount != static_cast(valuesAmount)) { - return emitOpError( - "incorrect number of values for dynamic dimensions (expected " + - std::to_string(dynamicDimensionsAmount) + ", got " + - std::to_string(valuesAmount) + ")"); - } +namespace mlir::bmodelica { +mlir::LogicalResult AllocOp::verify() { + int64_t dynamicDimensionsAmount = getArrayType().getNumDynamicDims(); + size_t valuesAmount = getDynamicSizes().size(); - return mlir::success(); + if (dynamicDimensionsAmount != static_cast(valuesAmount)) { + return emitOpError( + "incorrect number of values for dynamic dimensions (expected " + + std::to_string(dynamicDimensionsAmount) + ", got " + + std::to_string(valuesAmount) + ")"); } - void AllocOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - if (auto arrayType = getResult().getType().dyn_cast()) { - effects.emplace_back( - mlir::MemoryEffects::Allocate::get(), - getResult(), - mlir::SideEffects::DefaultResource::get()); - } + return mlir::success(); +} + +void AllocOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + if (auto arrayType = getResult().getType().dyn_cast()) { + effects.emplace_back(mlir::MemoryEffects::Allocate::get(), getResult(), + mlir::SideEffects::DefaultResource::get()); } } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ArrayFromElementsOp -namespace mlir::bmodelica -{ - mlir::LogicalResult ArrayFromElementsOp::verify() - { - if (!getArrayType().hasStaticShape()) { - return emitOpError("the shape must be fixed"); - } +namespace mlir::bmodelica { +mlir::LogicalResult ArrayFromElementsOp::verify() { + if (!getArrayType().hasStaticShape()) { + return emitOpError("the shape must be fixed"); + } - int64_t arrayFlatSize = getArrayType().getNumElements(); - size_t numOfValues = getValues().size(); + int64_t arrayFlatSize = 
getArrayType().getNumElements(); + size_t numOfValues = getValues().size(); - if (arrayFlatSize != static_cast(numOfValues)) { - return emitOpError( - "incorrect number of values (expected " + - std::to_string(arrayFlatSize) + ", got " + - std::to_string(numOfValues) + ")"); - } + if (arrayFlatSize != static_cast(numOfValues)) { + return emitOpError("incorrect number of values (expected " + + std::to_string(arrayFlatSize) + ", got " + + std::to_string(numOfValues) + ")"); + } - return mlir::success(); - } - - mlir::OpFoldResult ArrayFromElementsOp::fold(FoldAdaptor adaptor) - { - if (llvm::all_of(adaptor.getOperands(), [](mlir::Attribute attr) { - return attr != nullptr; - })) { - ArrayType arrayType = getArrayType(); + return mlir::success(); +} - if (!arrayType.hasStaticShape()) { - return {}; - } +mlir::OpFoldResult ArrayFromElementsOp::fold(FoldAdaptor adaptor) { + if (llvm::all_of(adaptor.getOperands(), + [](mlir::Attribute attr) { return attr != nullptr; })) { + ArrayType arrayType = getArrayType(); - mlir::Type elementType = arrayType.getElementType(); + if (!arrayType.hasStaticShape()) { + return {}; + } - if (elementType.isa()) { - llvm::SmallVector casted; + mlir::Type elementType = arrayType.getElementType(); - if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { - return {}; - } + if (elementType.isa()) { + llvm::SmallVector casted; - return DenseBooleanElementsAttr::get(arrayType, casted); + if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { + return {}; } - if (elementType.isa()) { - llvm::SmallVector casted; + return DenseBooleanElementsAttr::get(arrayType, casted); + } - if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { - return {}; - } + if (elementType.isa()) { + llvm::SmallVector casted; - return DenseIntegerElementsAttr::get(arrayType, casted); + if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { + return {}; } - if (elementType.isa()) { - llvm::SmallVector casted; + return DenseIntegerElementsAttr::get(arrayType, casted); + } - if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { - return {}; - } + if (elementType.isa()) { + llvm::SmallVector casted; - return DenseRealElementsAttr::get(arrayType, casted); + if (!getScalarAttributesValues(adaptor.getOperands(), casted)) { + return {}; } - } - return {}; + return DenseRealElementsAttr::get(arrayType, casted); + } } - void ArrayFromElementsOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Allocate::get(), - getResult(), - mlir::SideEffects::DefaultResource::get()); + return {}; +} - effects.emplace_back( - mlir::MemoryEffects::Write::get(), - getResult(), - mlir::SideEffects::DefaultResource::get()); - } +void ArrayFromElementsOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + effects.emplace_back(mlir::MemoryEffects::Allocate::get(), getResult(), + mlir::SideEffects::DefaultResource::get()); + + effects.emplace_back(mlir::MemoryEffects::Write::get(), getResult(), + mlir::SideEffects::DefaultResource::get()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ArrayBroadcastOp -namespace mlir::bmodelica -{ - void ArrayBroadcastOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Allocate::get(), - getResult(), - 
mlir::SideEffects::DefaultResource::get()); +namespace mlir::bmodelica { +void ArrayBroadcastOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + effects.emplace_back(mlir::MemoryEffects::Allocate::get(), getResult(), + mlir::SideEffects::DefaultResource::get()); - effects.emplace_back( - mlir::MemoryEffects::Write::get(), - getResult(), - mlir::SideEffects::DefaultResource::get()); - } + effects.emplace_back(mlir::MemoryEffects::Write::get(), getResult(), + mlir::SideEffects::DefaultResource::get()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // FreeOp -namespace mlir::bmodelica -{ - void FreeOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Free::get(), - getArray(), - mlir::SideEffects::DefaultResource::get()); - } +namespace mlir::bmodelica { +void FreeOp::getEffects(mlir::SmallVectorImpl> &effects) { + effects.emplace_back(mlir::MemoryEffects::Free::get(), getArray(), + mlir::SideEffects::DefaultResource::get()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // DimOp -namespace -{ - struct DimOpStaticDimensionPattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - DimOp op, mlir::PatternRewriter& rewriter) const override - { - auto constantOp = op.getDimension().getDefiningOp(); - - if (!constantOp) { - return mlir::failure(); - } +namespace { +struct DimOpStaticDimensionPattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - ArrayType arrayType = op.getArray().getType(); + mlir::LogicalResult + matchAndRewrite(DimOp op, mlir::PatternRewriter &rewriter) const override { + auto constantOp = op.getDimension().getDefiningOp(); - int64_t dimSize = arrayType.getDimSize( - constantOp.getValue().cast().getInt()); + if (!constantOp) { + return mlir::failure(); + } - if (dimSize == ArrayType::kDynamic) { - return mlir::failure(); - } + ArrayType arrayType = op.getArray().getType(); - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexAttr(dimSize)); + int64_t dimSize = arrayType.getDimSize( + constantOp.getValue().cast().getInt()); - return mlir::success(); + if (dimSize == ArrayType::kDynamic) { + return mlir::failure(); } - }; -} -namespace mlir::bmodelica -{ - void DimOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexAttr(dimSize)); + + return mlir::success(); } +}; +} // namespace + +namespace mlir::bmodelica { +void DimOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + patterns.add(context); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // LoadOp -namespace -{ - struct MergeSubscriptionsIntoLoadPattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - LoadOp op, mlir::PatternRewriter& rewriter) const override - { - auto subscriptionOp = op.getArray().getDefiningOp(); - - if (!subscriptionOp) { - return mlir::failure(); - } - - llvm::SmallVector subscriptionOps; +namespace { +struct MergeSubscriptionsIntoLoadPattern + : public mlir::OpRewritePattern { + 
using mlir::OpRewritePattern::OpRewritePattern; - while (subscriptionOp) { - subscriptionOps.push_back(subscriptionOp); + mlir::LogicalResult + matchAndRewrite(LoadOp op, mlir::PatternRewriter &rewriter) const override { + auto subscriptionOp = op.getArray().getDefiningOp(); - subscriptionOp = - subscriptionOp.getSource().getDefiningOp(); - } + if (!subscriptionOp) { + return mlir::failure(); + } - assert(!subscriptionOps.empty()); - mlir::Value source = subscriptionOps.back().getSource(); - llvm::SmallVector indices; + llvm::SmallVector subscriptionOps; - while (!subscriptionOps.empty()) { - SubscriptionOp current = subscriptionOps.pop_back_val(); - indices.append(current.getIndices().begin(), - current.getIndices().end()); - } + while (subscriptionOp) { + subscriptionOps.push_back(subscriptionOp); - indices.append(op.getIndices().begin(), op.getIndices().end()); - rewriter.replaceOpWithNewOp(op, source, indices); - return mlir::success(); + subscriptionOp = + subscriptionOp.getSource().getDefiningOp(); } - }; -} -namespace mlir::bmodelica -{ - mlir::ParseResult LoadOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto loc = parser.getCurrentLocation(); - mlir::OpAsmParser::UnresolvedOperand array; - mlir::Type arrayType; - llvm::SmallVector indices; - llvm::SmallVector indicesTypes; - - if (parser.parseOperand(array) || - parser.parseOperandList( - indices, mlir::OpAsmParser::Delimiter::Square) || - parser.parseColonType(arrayType) || - parser.resolveOperand(array, arrayType, result.operands)) { - return mlir::failure(); + assert(!subscriptionOps.empty()); + mlir::Value source = subscriptionOps.back().getSource(); + llvm::SmallVector indices; + + while (!subscriptionOps.empty()) { + SubscriptionOp current = subscriptionOps.pop_back_val(); + indices.append(current.getIndices().begin(), current.getIndices().end()); } - indicesTypes.resize( - indices.size(), - mlir::IndexType::get(result.getContext())); + indices.append(op.getIndices().begin(), op.getIndices().end()); + rewriter.replaceOpWithNewOp(op, source, indices); + return mlir::success(); + } +}; +} // namespace + +namespace mlir::bmodelica { +mlir::ParseResult LoadOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto loc = parser.getCurrentLocation(); + mlir::OpAsmParser::UnresolvedOperand array; + mlir::Type arrayType; + llvm::SmallVector indices; + llvm::SmallVector indicesTypes; + + if (parser.parseOperand(array) || + parser.parseOperandList(indices, mlir::OpAsmParser::Delimiter::Square) || + parser.parseColonType(arrayType) || + parser.resolveOperand(array, arrayType, result.operands)) { + return mlir::failure(); + } - size_t i = 0; + indicesTypes.resize(indices.size(), + mlir::IndexType::get(result.getContext())); - while (mlir::succeeded(parser.parseOptionalComma())) { - if (parser.parseType(indicesTypes[i++])) { - return mlir::failure(); - } - } + size_t i = 0; - if (parser.resolveOperands(indices, indicesTypes, loc, result.operands)) { + while (mlir::succeeded(parser.parseOptionalComma())) { + if (parser.parseType(indicesTypes[i++])) { return mlir::failure(); } + } - result.addTypes(arrayType.cast().getElementType()); - return mlir::success(); + if (parser.resolveOperands(indices, indicesTypes, loc, result.operands)) { + return mlir::failure(); } - void LoadOp::print(mlir::OpAsmPrinter& printer) - { - printer << " " << getArray() << "[" << getIndices() << "]"; - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " : " << getArray().getType(); + 
result.addTypes(arrayType.cast().getElementType()); + return mlir::success(); +} - if (!llvm::all_of(getIndices(), [](mlir::Value index) { - return index.getType().isa(); - })) { - for (mlir::Value index : getIndices()) { - printer << ", " << index.getType(); - } +void LoadOp::print(mlir::OpAsmPrinter &printer) { + printer << " " << getArray() << "[" << getIndices() << "]"; + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : " << getArray().getType(); + + if (!llvm::all_of(getIndices(), [](mlir::Value index) { + return index.getType().isa(); + })) { + for (mlir::Value index : getIndices()) { + printer << ", " << index.getType(); } } +} - mlir::LogicalResult LoadOp::verify() - { - size_t indicesAmount = getIndices().size(); - int64_t rank = getArrayType().getRank(); +mlir::LogicalResult LoadOp::verify() { + size_t indicesAmount = getIndices().size(); + int64_t rank = getArrayType().getRank(); - if (rank != static_cast(indicesAmount)) { - return emitOpError() - << "incorrect number of indices (expected " << rank - << ", got " << indicesAmount << ")"; - } + if (rank != static_cast(indicesAmount)) { + return emitOpError() << "incorrect number of indices (expected " << rank + << ", got " << indicesAmount << ")"; + } - for (size_t i = 0; i < indicesAmount; ++i) { - if (auto constantOp = getIndices()[i].getDefiningOp()) { - if (auto index = getScalarAttributeValue( - constantOp.getValue())) { - if (*index < 0) { - return emitOpError() << "invalid index (" << *index << ")"; - } + for (size_t i = 0; i < indicesAmount; ++i) { + if (auto constantOp = getIndices()[i].getDefiningOp()) { + if (auto index = + getScalarAttributeValue(constantOp.getValue())) { + if (*index < 0) { + return emitOpError() << "invalid index (" << *index << ")"; + } - if (int64_t dimSize = getArrayType().getDimSize(i); - *index >= dimSize) { - return emitOpError() - << "out of bounds access (index = " << *index - << ", dimension = " << dimSize << ")"; - } + if (int64_t dimSize = getArrayType().getDimSize(i); *index >= dimSize) { + return emitOpError() << "out of bounds access (index = " << *index + << ", dimension = " << dimSize << ")"; } } } - - return mlir::success(); } - void LoadOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); - } + return mlir::success(); +} - void LoadOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Read::get(), - getArray(), - mlir::SideEffects::DefaultResource::get()); - } +void LoadOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + patterns.add(context); +} + +void LoadOp::getEffects(mlir::SmallVectorImpl> &effects) { + effects.emplace_back(mlir::MemoryEffects::Read::get(), getArray(), + mlir::SideEffects::DefaultResource::get()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // StoreOp -namespace mlir::bmodelica -{ - mlir::ParseResult StoreOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto loc = parser.getCurrentLocation(); - mlir::OpAsmParser::UnresolvedOperand array; - mlir::Type arrayType; - mlir::OpAsmParser::UnresolvedOperand value; - llvm::SmallVector indices; - llvm::SmallVector indicesTypes; - - if (parser.parseOperand(array) || - parser.parseOperandList( - indices, mlir::OpAsmParser::Delimiter::Square) || - 
parser.parseComma() || - parser.parseOperand(value) || - parser.parseColonType(arrayType) || - parser.resolveOperand(value, arrayType.cast().getElementType(), result.operands) || - parser.resolveOperand(array, arrayType, result.operands)) { - return mlir::failure(); - } - - indicesTypes.resize( - indices.size(), - mlir::IndexType::get(result.getContext())); +namespace mlir::bmodelica { +mlir::ParseResult StoreOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto loc = parser.getCurrentLocation(); + mlir::OpAsmParser::UnresolvedOperand array; + mlir::Type arrayType; + mlir::OpAsmParser::UnresolvedOperand value; + llvm::SmallVector indices; + llvm::SmallVector indicesTypes; + + if (parser.parseOperand(array) || + parser.parseOperandList(indices, mlir::OpAsmParser::Delimiter::Square) || + parser.parseComma() || parser.parseOperand(value) || + parser.parseColonType(arrayType) || + parser.resolveOperand(value, + arrayType.cast().getElementType(), + result.operands) || + parser.resolveOperand(array, arrayType, result.operands)) { + return mlir::failure(); + } - size_t i = 0; + indicesTypes.resize(indices.size(), + mlir::IndexType::get(result.getContext())); - while (mlir::succeeded(parser.parseOptionalComma())) { - if (parser.parseType(indicesTypes[i++])) { - return mlir::failure(); - } - } + size_t i = 0; - if (parser.resolveOperands(indices, indicesTypes, loc, result.operands)) { + while (mlir::succeeded(parser.parseOptionalComma())) { + if (parser.parseType(indicesTypes[i++])) { return mlir::failure(); } + } - return mlir::success(); + if (parser.resolveOperands(indices, indicesTypes, loc, result.operands)) { + return mlir::failure(); } - void StoreOp::print(mlir::OpAsmPrinter& printer) - { - printer << " " << getArray() << "[" << getIndices() << "]" - << ", " << getValue(); + return mlir::success(); +} - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " : " << getArray().getType(); +void StoreOp::print(mlir::OpAsmPrinter &printer) { + printer << " " << getArray() << "[" << getIndices() << "]" + << ", " << getValue(); - if (!llvm::all_of(getIndices(), [](mlir::Value index) { - return index.getType().isa(); - })) { - for (mlir::Value index : getIndices()) { - printer << ", " << index.getType(); - } + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : " << getArray().getType(); + + if (!llvm::all_of(getIndices(), [](mlir::Value index) { + return index.getType().isa(); + })) { + for (mlir::Value index : getIndices()) { + printer << ", " << index.getType(); } } +} - mlir::LogicalResult StoreOp::verify() - { - size_t indicesAmount = getIndices().size(); - int64_t rank = getArrayType().getRank(); +mlir::LogicalResult StoreOp::verify() { + size_t indicesAmount = getIndices().size(); + int64_t rank = getArrayType().getRank(); - if (rank != static_cast(indicesAmount)) { - return emitOpError() - << "incorrect number of indices (expected " << rank - << ", got " << indicesAmount << ")"; - } + if (rank != static_cast(indicesAmount)) { + return emitOpError() << "incorrect number of indices (expected " << rank + << ", got " << indicesAmount << ")"; + } - for (size_t i = 0; i < indicesAmount; ++i) { - if (auto constantOp = getIndices()[i].getDefiningOp()) { - if (auto index = getScalarAttributeValue( - mlir::cast(constantOp.getValue()))) { - if (*index < 0) { - return emitOpError() << "invalid index (" << *index << ")"; - } + for (size_t i = 0; i < indicesAmount; ++i) { + if (auto constantOp = getIndices()[i].getDefiningOp()) { + if 
(auto index = getScalarAttributeValue( + mlir::cast(constantOp.getValue()))) { + if (*index < 0) { + return emitOpError() << "invalid index (" << *index << ")"; + } - if (int64_t dimSize = getArrayType().getDimSize(i); - *index >= dimSize) { - return emitOpError() - << "out of bounds access (index = " << *index - << ", dimension = " << dimSize << ")"; - } + if (int64_t dimSize = getArrayType().getDimSize(i); *index >= dimSize) { + return emitOpError() << "out of bounds access (index = " << *index + << ", dimension = " << dimSize << ")"; } } } - - return mlir::success(); } - void StoreOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Write::get(), - getArray(), - mlir::SideEffects::DefaultResource::get()); - } + return mlir::success(); +} + +void StoreOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + effects.emplace_back(mlir::MemoryEffects::Write::get(), getArray(), + mlir::SideEffects::DefaultResource::get()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SubscriptionOp -namespace -{ - struct InferSubscriptionResultTypePattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - SubscriptionOp op, mlir::PatternRewriter& rewriter) const override - { - ArrayType inferredResultType = SubscriptionOp::inferResultType( - op.getSource().getType(), op.getIndices()); - - if (inferredResultType != op.getResultArrayType()) { - auto newOp = rewriter.create( - op.getLoc(), inferredResultType, op.getSource(), op.getIndices()); - - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp); - return mlir::success(); - } +namespace { +struct InferSubscriptionResultTypePattern + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - return mlir::failure(); + mlir::LogicalResult + matchAndRewrite(SubscriptionOp op, + mlir::PatternRewriter &rewriter) const override { + ArrayType inferredResultType = SubscriptionOp::inferResultType( + op.getSource().getType(), op.getIndices()); + + if (inferredResultType != op.getResultArrayType()) { + auto newOp = rewriter.create( + op.getLoc(), inferredResultType, op.getSource(), op.getIndices()); + + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp); + return mlir::success(); } - }; - struct MergeSubscriptionsPattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; + return mlir::failure(); + } +}; - mlir::LogicalResult matchAndRewrite( - SubscriptionOp op, mlir::PatternRewriter& rewriter) const override - { - auto subscriptionOp = op.getSource().getDefiningOp(); +struct MergeSubscriptionsPattern + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - if (!subscriptionOp) { - return mlir::failure(); - } + mlir::LogicalResult + matchAndRewrite(SubscriptionOp op, + mlir::PatternRewriter &rewriter) const override { + auto subscriptionOp = op.getSource().getDefiningOp(); - llvm::SmallVector subscriptionOps; + if (!subscriptionOp) { + return mlir::failure(); + } - while (subscriptionOp) { - subscriptionOps.push_back(subscriptionOp); + llvm::SmallVector subscriptionOps; - subscriptionOp = - subscriptionOp.getSource().getDefiningOp(); - } + while (subscriptionOp) { + subscriptionOps.push_back(subscriptionOp); - assert(!subscriptionOps.empty()); - 
mlir::Value source = subscriptionOps.back().getSource(); - llvm::SmallVector indices; + subscriptionOp = + subscriptionOp.getSource().getDefiningOp(); + } - while (!subscriptionOps.empty()) { - SubscriptionOp current = subscriptionOps.pop_back_val(); - indices.append(current.getIndices().begin(), - current.getIndices().end()); - } + assert(!subscriptionOps.empty()); + mlir::Value source = subscriptionOps.back().getSource(); + llvm::SmallVector indices; - indices.append(op.getIndices().begin(), op.getIndices().end()); - rewriter.replaceOpWithNewOp(op, source, indices); - return mlir::success(); + while (!subscriptionOps.empty()) { + SubscriptionOp current = subscriptionOps.pop_back_val(); + indices.append(current.getIndices().begin(), current.getIndices().end()); } - }; -} -namespace mlir::bmodelica -{ - void SubscriptionOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - mlir::Value source, - mlir::ValueRange indices) - { - build(builder, state, - inferResultType(source.getType().cast(), indices), - source, indices); + indices.append(op.getIndices().begin(), op.getIndices().end()); + rewriter.replaceOpWithNewOp(op, source, indices); + return mlir::success(); } +}; +} // namespace - mlir::LogicalResult SubscriptionOp::verify() - { - ArrayType sourceType = getSourceArrayType(); - ArrayType resultType = getResultArrayType(); - ArrayType expectedResultType = inferResultType(sourceType, getIndices()); +namespace mlir::bmodelica { +void SubscriptionOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, mlir::Value source, + mlir::ValueRange indices) { + build(builder, state, + inferResultType(source.getType().cast(), indices), source, + indices); +} - if (resultType.getRank() != expectedResultType.getRank()) { - return emitOpError() << "incompatible result rank"; - } +mlir::LogicalResult SubscriptionOp::verify() { + ArrayType sourceType = getSourceArrayType(); + ArrayType resultType = getResultArrayType(); + ArrayType expectedResultType = inferResultType(sourceType, getIndices()); - for (int64_t i = 0, e = resultType.getRank(); i < e; ++i) { - int64_t actualDimSize = resultType.getDimSize(i); - int64_t expectedDimSize = expectedResultType.getDimSize(i); + if (resultType.getRank() != expectedResultType.getRank()) { + return emitOpError() << "incompatible result rank"; + } - if (actualDimSize != ArrayType::kDynamic && - actualDimSize != expectedDimSize) { - return emitOpError() << "incompatible size for dimension " << i - << " (expected " << expectedDimSize << ", got " - << actualDimSize << ")"; - } - } + for (int64_t i = 0, e = resultType.getRank(); i < e; ++i) { + int64_t actualDimSize = resultType.getDimSize(i); + int64_t expectedDimSize = expectedResultType.getDimSize(i); - return mlir::success(); + if (actualDimSize != ArrayType::kDynamic && + actualDimSize != expectedDimSize) { + return emitOpError() << "incompatible size for dimension " << i + << " (expected " << expectedDimSize << ", got " + << actualDimSize << ")"; + } } - void SubscriptionOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add< - InferSubscriptionResultTypePattern, - MergeSubscriptionsPattern>(context); - } + return mlir::success(); +} - ArrayType SubscriptionOp::inferResultType( - ArrayType source, mlir::ValueRange indices) - { - llvm::SmallVector shape; - size_t numOfSubscriptions = indices.size(); +void SubscriptionOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + 
patterns.add( + context); +} - for (size_t i = 0; i < numOfSubscriptions; ++i) { - mlir::Value index = indices[i]; +ArrayType SubscriptionOp::inferResultType(ArrayType source, + mlir::ValueRange indices) { + llvm::SmallVector shape; + size_t numOfSubscriptions = indices.size(); - if (index.getType().isa()) { - int64_t dimension = ArrayType::kDynamic; + for (size_t i = 0; i < numOfSubscriptions; ++i) { + mlir::Value index = indices[i]; - if (auto constantOp = index.getDefiningOp()) { - auto indexAttr = constantOp.getValue(); + if (index.getType().isa()) { + int64_t dimension = ArrayType::kDynamic; - if (auto rangeAttr = mlir::dyn_cast(indexAttr)) { - dimension = rangeAttr.getNumOfElements(); - } - } + if (auto constantOp = index.getDefiningOp()) { + auto indexAttr = constantOp.getValue(); - shape.push_back(dimension); + if (auto rangeAttr = mlir::dyn_cast(indexAttr)) { + dimension = rangeAttr.getNumOfElements(); + } } - } - for (int64_t dimension : - source.getShape().drop_front(numOfSubscriptions)) { shape.push_back(dimension); } + } - return source.withShape(shape); + for (int64_t dimension : source.getShape().drop_front(numOfSubscriptions)) { + shape.push_back(dimension); } + + return source.withShape(shape); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ArrayFillOp -namespace mlir::bmodelica -{ - void ArrayFillOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Write::get(), - getArray(), - mlir::SideEffects::DefaultResource::get()); - } +namespace mlir::bmodelica { +void ArrayFillOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + effects.emplace_back(mlir::MemoryEffects::Write::get(), getArray(), + mlir::SideEffects::DefaultResource::get()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ArrayCopyOp -namespace mlir::bmodelica -{ - void ArrayCopyOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Read::get(), - getSource(), - mlir::SideEffects::DefaultResource::get()); +namespace mlir::bmodelica { +void ArrayCopyOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + effects.emplace_back(mlir::MemoryEffects::Read::get(), getSource(), + mlir::SideEffects::DefaultResource::get()); - effects.emplace_back( - mlir::MemoryEffects::Write::get(), - getDestination(), - mlir::SideEffects::DefaultResource::get()); - } + effects.emplace_back(mlir::MemoryEffects::Write::get(), getDestination(), + mlir::SideEffects::DefaultResource::get()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // Variable operations @@ -1538,869 +1360,768 @@ namespace mlir::bmodelica //===---------------------------------------------------------------------===// // VariableOp -namespace mlir::bmodelica -{ - void VariableOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - llvm::StringRef name, - VariableType variableType) - { - llvm::SmallVector constraints( - variableType.getNumDynamicDims(), - builder.getStringAttr(kDimensionConstraintUnbounded)); - - build(builder, state, name, variableType, - builder.getArrayAttr(constraints)); - } - - mlir::ParseResult VariableOp::parse( - 
mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto& builder = parser.getBuilder(); - - // Variable name. - mlir::StringAttr nameAttr; - - if (parser.parseSymbolName( - nameAttr, - mlir::SymbolTable::getSymbolAttrName(), - result.attributes)) { - return mlir::failure(); - } +namespace mlir::bmodelica { +void VariableOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, VariableType variableType) { + llvm::SmallVector constraints( + variableType.getNumDynamicDims(), + builder.getStringAttr(kDimensionConstraintUnbounded)); - // Attributes. - if (parser.parseOptionalAttrDict(result.attributes)) { - return mlir::failure(); - } + build(builder, state, name, variableType, builder.getArrayAttr(constraints)); +} - // Variable type. - mlir::Type type; +mlir::ParseResult VariableOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); - if (parser.parseColonType(type)) { - return mlir::failure(); - } + // Variable name. + mlir::StringAttr nameAttr; - result.attributes.append( - getTypeAttrName(result.name), - mlir::TypeAttr::get(type)); + if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), + result.attributes)) { + return mlir::failure(); + } - // Dimensions constraints. - llvm::SmallVector dimensionsConstraints; + // Attributes. + if (parser.parseOptionalAttrDict(result.attributes)) { + return mlir::failure(); + } - if (mlir::succeeded(parser.parseOptionalLSquare())) { - do { - if (mlir::succeeded( - parser.parseOptionalKeyword(kDimensionConstraintUnbounded))) { - dimensionsConstraints.push_back(kDimensionConstraintUnbounded); - } else { - if (parser.parseKeyword(kDimensionConstraintFixed)) { - return mlir::failure(); - } + // Variable type. + mlir::Type type; - dimensionsConstraints.push_back(kDimensionConstraintFixed); - } - } while (mlir::succeeded(parser.parseOptionalComma())); + if (parser.parseColonType(type)) { + return mlir::failure(); + } - if (parser.parseRSquare()) { - return mlir::failure(); - } - } + result.attributes.append(getTypeAttrName(result.name), + mlir::TypeAttr::get(type)); - result.attributes.append( - getDimensionsConstraintsAttrName(result.name), - builder.getStrArrayAttr(dimensionsConstraints)); + // Dimensions constraints. + llvm::SmallVector dimensionsConstraints; - // Region for the dimensions constraints. - mlir::Region* constraintsRegion = result.addRegion(); + if (mlir::succeeded(parser.parseOptionalLSquare())) { + do { + if (mlir::succeeded( + parser.parseOptionalKeyword(kDimensionConstraintUnbounded))) { + dimensionsConstraints.push_back(kDimensionConstraintUnbounded); + } else { + if (parser.parseKeyword(kDimensionConstraintFixed)) { + return mlir::failure(); + } - mlir::OptionalParseResult constraintsRegionParseResult = - parser.parseOptionalRegion(*constraintsRegion); + dimensionsConstraints.push_back(kDimensionConstraintFixed); + } + } while (mlir::succeeded(parser.parseOptionalComma())); - if (constraintsRegionParseResult.has_value() && - failed(*constraintsRegionParseResult)) { + if (parser.parseRSquare()) { return mlir::failure(); } + } - return mlir::success(); + result.attributes.append(getDimensionsConstraintsAttrName(result.name), + builder.getStrArrayAttr(dimensionsConstraints)); + + // Region for the dimensions constraints. 
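+  // The region is optional; when present, it yields the values that act as
+  // constraints for the dimensions marked 'fixed' (see the verifier below).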
+ mlir::Region *constraintsRegion = result.addRegion(); + + mlir::OptionalParseResult constraintsRegionParseResult = + parser.parseOptionalRegion(*constraintsRegion); + + if (constraintsRegionParseResult.has_value() && + failed(*constraintsRegionParseResult)) { + return mlir::failure(); } - void VariableOp::print(mlir::OpAsmPrinter& printer) - { - printer << " "; - printer.printSymbolName(getSymName()); + return mlir::success(); +} - llvm::SmallVector elidedAttrs; - elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName()); - elidedAttrs.push_back(getTypeAttrName()); - elidedAttrs.push_back(getDimensionsConstraintsAttrName()); +void VariableOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getSymName()); - printer.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); + llvm::SmallVector elidedAttrs; + elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName()); + elidedAttrs.push_back(getTypeAttrName()); + elidedAttrs.push_back(getDimensionsConstraintsAttrName()); - printer << " : " << getType(); + printer.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); - auto dimConstraints = - getDimensionsConstraints().getAsRange(); + printer << " : " << getType(); - if (llvm::any_of(dimConstraints, [](mlir::StringAttr constraint) { - return constraint == kDimensionConstraintFixed; - })) { - printer << " ["; + auto dimConstraints = + getDimensionsConstraints().getAsRange(); - for (const auto& constraint : llvm::enumerate(dimConstraints)) { - if (constraint.index() != 0) { - printer << ", "; - } + if (llvm::any_of(dimConstraints, [](mlir::StringAttr constraint) { + return constraint == kDimensionConstraintFixed; + })) { + printer << " ["; - printer << constraint.value().getValue(); + for (const auto &constraint : llvm::enumerate(dimConstraints)) { + if (constraint.index() != 0) { + printer << ", "; } - printer << "] "; + printer << constraint.value().getValue(); } - if (mlir::Region& region = getConstraintsRegion(); !region.empty()) { - printer.printRegion(region); - } + printer << "] "; } - mlir::LogicalResult VariableOp::verify() - { - // Verify the semantics for fixed dimensions constraints. - size_t numOfFixedDims = getNumOfFixedDimensions(); - mlir::Region& constraintsRegion = getConstraintsRegion(); - size_t numOfConstraints = 0; + if (mlir::Region ®ion = getConstraintsRegion(); !region.empty()) { + printer.printRegion(region); + } +} - if (!constraintsRegion.empty()) { - auto yieldOp = mlir::cast( - constraintsRegion.back().getTerminator()); +mlir::LogicalResult VariableOp::verify() { + // Verify the semantics for fixed dimensions constraints. 
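+  // Each dimension marked 'fixed' must be matched by exactly one value
+  // yielded from the constraints region, and every such value must have
+  // 'index' type.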
+  size_t numOfFixedDims = getNumOfFixedDimensions();
+  mlir::Region &constraintsRegion = getConstraintsRegion();
+  size_t numOfConstraints = 0;

-    if (!constraintsRegion.empty()) {
-      auto yieldOp = mlir::cast<YieldOp>(
-          constraintsRegion.back().getTerminator());
+  if (!constraintsRegion.empty()) {
+    auto yieldOp =
+        mlir::cast<YieldOp>(constraintsRegion.back().getTerminator());

-      numOfConstraints = yieldOp.getValues().size();
-    }
+    numOfConstraints = yieldOp.getValues().size();
+  }

-    if (numOfFixedDims != numOfConstraints) {
-      return emitOpError(
-          "not enough constraints for dynamic dimension constraints have been "
-          "provided (expected " + std::to_string(numOfFixedDims) + ", got " +
-          std::to_string(numOfConstraints) + ")");
-    }
+  if (numOfFixedDims != numOfConstraints) {
+    return emitOpError(
+        "not enough constraints for dynamic dimension constraints have been "
+        "provided (expected " +
+        std::to_string(numOfFixedDims) + ", got " +
+        std::to_string(numOfConstraints) + ")");
+  }

-    if (!constraintsRegion.empty()) {
-      auto yieldOp = mlir::cast<YieldOp>(
-          constraintsRegion.back().getTerminator());
+  if (!constraintsRegion.empty()) {
+    auto yieldOp =
+        mlir::cast<YieldOp>(constraintsRegion.back().getTerminator());

-      // Check that the amount of values is the same of the fixed dimensions.
-      if (yieldOp.getValues().size() != getNumOfFixedDimensions()) {
-        return mlir::failure();
-      }
+    // Check that the amount of values is the same of the fixed dimensions.
+    if (yieldOp.getValues().size() != getNumOfFixedDimensions()) {
+      return mlir::failure();
+    }

-      // Check that all the dimensions have 'index' type.
-      if (llvm::any_of(yieldOp.getValues(), [](mlir::Value value) {
-            return !value.getType().isa<mlir::IndexType>();
-          })) {
-        return emitOpError(
-            "constraints for dynamic dimensions must have 'index' type");
-      }
+    // Check that all the dimensions have 'index' type.
+    if (llvm::any_of(yieldOp.getValues(), [](mlir::Value value) {
+          return !value.getType().isa<mlir::IndexType>();
+        })) {
+      return emitOpError(
+          "constraints for dynamic dimensions must have 'index' type");
+    }
   }

-    return mlir::success();
+  return mlir::success();
+}

-  VariableType VariableOp::getVariableType()
-  {
-    return getType().cast<VariableType>();
-  }
+VariableType VariableOp::getVariableType() {
+  return getType().cast<VariableType>();
+}

-  bool VariableOp::isInput()
-  {
-    return getVariableType().isInput();
-  }
+bool VariableOp::isInput() { return getVariableType().isInput(); }

-  bool VariableOp::isOutput()
-  {
-    return getVariableType().isOutput();
-  }
+bool VariableOp::isOutput() { return getVariableType().isOutput(); }

-  bool VariableOp::isDiscrete()
-  {
-    return getVariableType().isDiscrete();
-  }
+bool VariableOp::isDiscrete() { return getVariableType().isDiscrete(); }

-  bool VariableOp::isParameter()
-  {
-    return getVariableType().isParameter();
-  }
+bool VariableOp::isParameter() { return getVariableType().isParameter(); }

-  bool VariableOp::isConstant()
-  {
-    return getVariableType().isConstant();
-  }
+bool VariableOp::isConstant() { return getVariableType().isConstant(); }

-  bool VariableOp::isReadOnly()
-  {
-    return getVariableType().isReadOnly();
-  }
+bool VariableOp::isReadOnly() { return getVariableType().isReadOnly(); }

-  size_t VariableOp::getNumOfUnboundedDimensions()
-  {
-    return llvm::count_if(
-        getDimensionsConstraints().getAsRange<mlir::StringAttr>(),
-        [](mlir::StringAttr dimensionConstraint) {
-          return dimensionConstraint.getValue() ==
-              kDimensionConstraintUnbounded;
-        });
-  }

-  size_t VariableOp::getNumOfFixedDimensions()
-  {
-    return llvm::count_if(
-        getDimensionsConstraints().getAsRange<mlir::StringAttr>(),
-        [](mlir::StringAttr dimensionConstraint) {
-          return dimensionConstraint.getValue() ==
-              kDimensionConstraintFixed;
-        });
-  }
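// The two counting helpers here reduce to a single llvm::count_if over the
// per-dimension constraint strings. A minimal, self-contained sketch of the
// same idiom follows; the function name and the literal strings are
// hypothetical stand-ins, not MARCO definitions:
//
//   #include "llvm/ADT/ArrayRef.h"
//   #include "llvm/ADT/STLExtras.h"
//   #include "llvm/ADT/StringRef.h"
//
//   size_t countFixed(llvm::ArrayRef<llvm::StringRef> constraints) {
//     // One entry per dynamic dimension of the variable type.
//     return llvm::count_if(
//         constraints, [](llvm::StringRef c) { return c == "fixed"; });
//   }
//
//   // countFixed({"unbounded", "fixed", "fixed"}) == 2: the verifier above
//   // would then require two 'index' values from the constraints region.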
+size_t VariableOp::getNumOfUnboundedDimensions() { + return llvm::count_if( + getDimensionsConstraints().getAsRange(), + [](mlir::StringAttr dimensionConstraint) { + return dimensionConstraint.getValue() == kDimensionConstraintUnbounded; + }); +} - IndexSet VariableOp::getIndices() - { - VariableType variableType = getVariableType(); +size_t VariableOp::getNumOfFixedDimensions() { + return llvm::count_if( + getDimensionsConstraints().getAsRange(), + [](mlir::StringAttr dimensionConstraint) { + return dimensionConstraint.getValue() == kDimensionConstraintFixed; + }); +} - if (variableType.isScalar()) { - return {}; - } +IndexSet VariableOp::getIndices() { + VariableType variableType = getVariableType(); - llvm::SmallVector ranges; + if (variableType.isScalar()) { + return {}; + } - for (int64_t dimension : variableType.getShape()) { - ranges.push_back(Range(0, dimension)); - } + llvm::SmallVector ranges; - return IndexSet(MultidimensionalRange(ranges)); + for (int64_t dimension : variableType.getShape()) { + ranges.push_back(Range(0, dimension)); } + + return IndexSet(MultidimensionalRange(ranges)); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // VariableGetOp -namespace mlir::bmodelica -{ - void VariableGetOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - VariableOp variableOp) - { - auto variableType = variableOp.getVariableType(); - auto variableName = variableOp.getSymName(); - build(builder, state, variableType.unwrap(), variableName); - } - - mlir::LogicalResult VariableGetOp::verifySymbolUses( - mlir::SymbolTableCollection& symbolTableCollection) - { - auto parentClass = getOperation()->getParentOfType(); +namespace mlir::bmodelica { +void VariableGetOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + VariableOp variableOp) { + auto variableType = variableOp.getVariableType(); + auto variableName = variableOp.getSymName(); + build(builder, state, variableType.unwrap(), variableName); +} - if (!parentClass) { - return emitOpError() << "the operation must be used inside a class"; - } +mlir::LogicalResult VariableGetOp::verifySymbolUses( + mlir::SymbolTableCollection &symbolTableCollection) { + auto parentClass = getOperation()->getParentOfType(); - mlir::Operation* symbol = - symbolTableCollection.lookupSymbolIn(parentClass, getVariableAttr()); + if (!parentClass) { + return emitOpError() << "the operation must be used inside a class"; + } - if (!symbol) { - return emitOpError() - << "variable " << getVariable() << " has not been declared"; - } + mlir::Operation *symbol = + symbolTableCollection.lookupSymbolIn(parentClass, getVariableAttr()); - if (!mlir::isa(symbol)) { - return emitOpError() - << "symbol " << getVariable() << " is not a variable"; - } + if (!symbol) { + return emitOpError() << "variable " << getVariable() + << " has not been declared"; + } - auto variableOp = mlir::cast(symbol); - mlir::Type unwrappedType = variableOp.getVariableType().unwrap(); + if (!mlir::isa(symbol)) { + return emitOpError() << "symbol " << getVariable() << " is not a variable"; + } - if (unwrappedType != getResult().getType()) { - return emitOpError() << "result type does not match the variable type"; - } + auto variableOp = mlir::cast(symbol); + mlir::Type unwrappedType = variableOp.getVariableType().unwrap(); - return mlir::success(); + if (unwrappedType != getResult().getType()) { + return emitOpError() << "result type does not match the variable type"; } + + return mlir::success(); } +} 
// namespace mlir::bmodelica //===---------------------------------------------------------------------===// // VariableSetOp -namespace mlir::bmodelica -{ - void VariableSetOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - VariableOp variableOp, - mlir::Value value) - { - auto variableName = variableOp.getSymName(); - build(builder, state, variableName, std::nullopt, value); - } - - void VariableSetOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - VariableOp variableOp, - mlir::ValueRange indices, - mlir::Value value) - { - auto variableName = variableOp.getSymName(); - build(builder, state, variableName, indices, value); - } - - mlir::ParseResult VariableSetOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto loc = parser.getCurrentLocation(); - - mlir::StringAttr variableAttr; - llvm::SmallVector indices; - mlir::OpAsmParser::UnresolvedOperand value; - llvm::SmallVector types; - - if (parser.parseSymbolName(variableAttr)) { - return mlir::failure(); - } +namespace mlir::bmodelica { +void VariableSetOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + VariableOp variableOp, mlir::Value value) { + auto variableName = variableOp.getSymName(); + build(builder, state, variableName, std::nullopt, value); +} - if (variableAttr) { - result.getOrAddProperties().variable = variableAttr; - } +void VariableSetOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + VariableOp variableOp, mlir::ValueRange indices, + mlir::Value value) { + auto variableName = variableOp.getSymName(); + build(builder, state, variableName, indices, value); +} - if (mlir::succeeded(parser.parseOptionalLSquare())) { - do { - if (parser.parseOperand(indices.emplace_back())) { - return mlir::failure(); - } - } while (mlir::succeeded(parser.parseOptionalComma())); +mlir::ParseResult VariableSetOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto loc = parser.getCurrentLocation(); - if (parser.parseRSquare()) { + mlir::StringAttr variableAttr; + llvm::SmallVector indices; + mlir::OpAsmParser::UnresolvedOperand value; + llvm::SmallVector types; + + if (parser.parseSymbolName(variableAttr)) { + return mlir::failure(); + } + + if (variableAttr) { + result.getOrAddProperties().variable = + variableAttr; + } + + if (mlir::succeeded(parser.parseOptionalLSquare())) { + do { + if (parser.parseOperand(indices.emplace_back())) { return mlir::failure(); } - } + } while (mlir::succeeded(parser.parseOptionalComma())); - if (parser.parseComma() || - parser.parseOperand(value) || - parser.parseColonTypeList(types)) { + if (parser.parseRSquare()) { return mlir::failure(); } + } - if (types.size() != indices.size() + 1) { - return mlir::failure(); - } + if (parser.parseComma() || parser.parseOperand(value) || + parser.parseColonTypeList(types)) { + return mlir::failure(); + } - if (!indices.empty()) { - if (parser.resolveOperands( - indices, llvm::ArrayRef(types).drop_back(), loc, - result.operands)) { - return mlir::failure(); - } - } + if (types.size() != indices.size() + 1) { + return mlir::failure(); + } - if (parser.resolveOperand(value, types.back(), result.operands)) { + if (!indices.empty()) { + if (parser.resolveOperands(indices, llvm::ArrayRef(types).drop_back(), loc, + result.operands)) { return mlir::failure(); } + } - return mlir::success(); + if (parser.resolveOperand(value, types.back(), result.operands)) { + return mlir::failure(); } - void VariableSetOp::print(mlir::OpAsmPrinter& printer) - { - printer << 
" "; - printer.printSymbolName(getVariable()); + return mlir::success(); +} - if (auto indices = getIndices(); !indices.empty()) { - printer << "[" << indices << "]"; - } +void VariableSetOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getVariable()); - printer << ", " << getValue() << " : "; + if (auto indices = getIndices(); !indices.empty()) { + printer << "[" << indices << "]"; + } - if (auto indices = getIndices(); !indices.empty()) { - printer << indices.getTypes() << ", "; - } + printer << ", " << getValue() << " : "; - printer << getValue().getType(); + if (auto indices = getIndices(); !indices.empty()) { + printer << indices.getTypes() << ", "; } - mlir::LogicalResult VariableSetOp::verifySymbolUses( - mlir::SymbolTableCollection& symbolTableCollection) - { - auto parentClass = getOperation()->getParentOfType(); + printer << getValue().getType(); +} - if (!parentClass) { - return emitOpError("the operation must be used inside a class"); - } +mlir::LogicalResult VariableSetOp::verifySymbolUses( + mlir::SymbolTableCollection &symbolTableCollection) { + auto parentClass = getOperation()->getParentOfType(); - mlir::Operation* symbol = - symbolTableCollection.lookupSymbolIn(parentClass, getVariableAttr()); + if (!parentClass) { + return emitOpError("the operation must be used inside a class"); + } - if (!symbol) { - return emitOpError( - "variable " + getVariable() + " has not been declared"); - } + mlir::Operation *symbol = + symbolTableCollection.lookupSymbolIn(parentClass, getVariableAttr()); - auto variableOp = mlir::dyn_cast(symbol); + if (!symbol) { + return emitOpError("variable " + getVariable() + " has not been declared"); + } - if (!variableOp) { - return emitOpError("symbol " + getVariable() + " is not a variable"); - } + auto variableOp = mlir::dyn_cast(symbol); - if (variableOp.isInput()) { - return emitOpError("can't set a value for an input variable"); - } + if (!variableOp) { + return emitOpError("symbol " + getVariable() + " is not a variable"); + } - return mlir::success(); + if (variableOp.isInput()) { + return emitOpError("can't set a value for an input variable"); } + + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // VariableComponentSetOp -namespace mlir::bmodelica -{ - mlir::ParseResult VariableComponentSetOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto loc = parser.getCurrentLocation(); - - llvm::SmallVector path; - llvm::SmallVector subscripts; - llvm::SmallVector subscriptsAmounts; - mlir::OpAsmParser::UnresolvedOperand value; - llvm::SmallVector types; +namespace mlir::bmodelica { +mlir::ParseResult VariableComponentSetOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto loc = parser.getCurrentLocation(); - do { - mlir::StringAttr component; + llvm::SmallVector path; + llvm::SmallVector subscripts; + llvm::SmallVector subscriptsAmounts; + mlir::OpAsmParser::UnresolvedOperand value; + llvm::SmallVector types; - if (parser.parseSymbolName(component)) { - return mlir::failure(); - } + do { + mlir::StringAttr component; - path.push_back(mlir::FlatSymbolRefAttr::get(component)); + if (parser.parseSymbolName(component)) { + return mlir::failure(); + } - llvm::SmallVector< - mlir::OpAsmParser::UnresolvedOperand, 3> componentSubscripts; + path.push_back(mlir::FlatSymbolRefAttr::get(component)); - if (mlir::succeeded(parser.parseOptionalLSquare())) { - do { - if 
(parser.parseOperand(componentSubscripts.emplace_back())) { - return mlir::failure(); - } - } while (mlir::succeeded(parser.parseOptionalComma())); + llvm::SmallVector + componentSubscripts; - if (parser.parseRSquare()) { + if (mlir::succeeded(parser.parseOptionalLSquare())) { + do { + if (parser.parseOperand(componentSubscripts.emplace_back())) { return mlir::failure(); } - } + } while (mlir::succeeded(parser.parseOptionalComma())); - subscriptsAmounts.push_back( - static_cast(componentSubscripts.size())); + if (parser.parseRSquare()) { + return mlir::failure(); + } + } - subscripts.append(componentSubscripts); - } while (mlir::succeeded(parser.parseOptionalColon()) && - mlir::succeeded(parser.parseOptionalColon())); + subscriptsAmounts.push_back( + static_cast(componentSubscripts.size())); - result.getOrAddProperties().path = - parser.getBuilder().getArrayAttr(path); + subscripts.append(componentSubscripts); + } while (mlir::succeeded(parser.parseOptionalColon()) && + mlir::succeeded(parser.parseOptionalColon())); - result.getOrAddProperties< - VariableComponentSetOp::Properties>().subscriptionsAmounts = - parser.getBuilder().getI64ArrayAttr(subscriptsAmounts); + result.getOrAddProperties().path = + parser.getBuilder().getArrayAttr(path); - if (parser.parseComma() || - parser.parseOperand(value) || - parser.parseColonTypeList(types)) { - return mlir::failure(); - } + result.getOrAddProperties() + .subscriptionsAmounts = + parser.getBuilder().getI64ArrayAttr(subscriptsAmounts); - if (!subscripts.empty()) { - if (parser.resolveOperands( - subscripts, llvm::ArrayRef(types).drop_back(), loc, - result.operands)) { - return mlir::failure(); - } - } + if (parser.parseComma() || parser.parseOperand(value) || + parser.parseColonTypeList(types)) { + return mlir::failure(); + } - if (parser.resolveOperand(value, types.back(), result.operands)) { + if (!subscripts.empty()) { + if (parser.resolveOperands(subscripts, llvm::ArrayRef(types).drop_back(), + loc, result.operands)) { return mlir::failure(); } + } - return mlir::success(); + if (parser.resolveOperand(value, types.back(), result.operands)) { + return mlir::failure(); } - void VariableComponentSetOp::print(mlir::OpAsmPrinter& printer) - { - size_t pathLength = getPath().size(); - auto subscriptions = getSubscriptions(); - auto subscriptionsAmount = getSubscriptionsAmounts(); - size_t subscriptsPos = 0; + return mlir::success(); +} - printer << " "; +void VariableComponentSetOp::print(mlir::OpAsmPrinter &printer) { + size_t pathLength = getPath().size(); + printer << " "; - for (size_t component = 0; component < pathLength; ++component) { - if (component != 0) { - printer << "::"; - } + for (size_t component = 0; component < pathLength; ++component) { + if (component != 0) { + printer << "::"; + } - printer << getPath()[component]; + printer << getPath()[component]; - if (auto subscripts = getComponentSubscripts(component); !subscripts.empty()) { - printer << "["; - llvm::interleaveComma(subscripts, printer); - printer << "]"; - } + if (auto subscripts = getComponentSubscripts(component); + !subscripts.empty()) { + printer << "["; + llvm::interleaveComma(subscripts, printer); + printer << "]"; } + } - printer << ", " << getValue() << " : "; + printer << ", " << getValue() << " : "; - if (auto subscripts = getSubscriptions(); !subscripts.empty()) { - for (mlir::Value subscript : subscripts) { - printer << subscript.getType() << ", "; - } + if (auto subscripts = getSubscriptions(); !subscripts.empty()) { + for (mlir::Value subscript : subscripts) 
{ + printer << subscript.getType() << ", "; } - - printer << getValue().getType(); } - mlir::ValueRange VariableComponentSetOp::getComponentSubscripts( - size_t componentIndex) - { - auto subscripts = getSubscriptions(); + printer << getValue().getType(); +} - if (subscripts.empty()) { - return std::nullopt; - } +mlir::ValueRange +VariableComponentSetOp::getComponentSubscripts(size_t componentIndex) { + auto subscripts = getSubscriptions(); - auto numOfSubscripts = getSubscriptionsAmounts()[componentIndex] - .cast().getInt(); + if (subscripts.empty()) { + return std::nullopt; + } - if (numOfSubscripts == 0) { - return std::nullopt; - } + auto numOfSubscripts = getSubscriptionsAmounts()[componentIndex] + .cast() + .getInt(); - size_t beginPos = 0; + if (numOfSubscripts == 0) { + return std::nullopt; + } - for (size_t i = 0; i < componentIndex; ++i) { - beginPos += - getSubscriptionsAmounts()[i].cast().getInt(); - } + size_t beginPos = 0; - return subscripts.slice(beginPos, numOfSubscripts); + for (size_t i = 0; i < componentIndex; ++i) { + beginPos += getSubscriptionsAmounts()[i].cast().getInt(); } + + return subscripts.slice(beginPos, numOfSubscripts); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ComponentGetOp -namespace mlir::bmodelica -{ - mlir::LogicalResult ComponentGetOp::verifySymbolUses( - mlir::SymbolTableCollection& symbolTableCollection) - { - mlir::Type variableType = getVariable().getType(); - mlir::Type recordType = variableType; +namespace mlir::bmodelica { +mlir::LogicalResult ComponentGetOp::verifySymbolUses( + mlir::SymbolTableCollection &symbolTableCollection) { + mlir::Type variableType = getVariable().getType(); + mlir::Type recordType = variableType; - if (auto tensorType = recordType.dyn_cast()) { - recordType = tensorType.getElementType(); - } + if (auto tensorType = recordType.dyn_cast()) { + recordType = tensorType.getElementType(); + } - auto moduleOp = getOperation()->getParentOfType(); + auto moduleOp = getOperation()->getParentOfType(); - auto recordOp = mlir::dyn_cast( - recordType.cast() - .getRecordOp(symbolTableCollection, moduleOp)); + auto recordOp = + mlir::dyn_cast(recordType.cast().getRecordOp( + symbolTableCollection, moduleOp)); - if (!recordOp) { - return emitOpError() << "Can't resolve record type"; - } + if (!recordOp) { + return emitOpError() << "Can't resolve record type"; + } - VariableOp componentVariable = nullptr; + VariableOp componentVariable = nullptr; - for (auto variable : recordOp.getVariables()) { - if (variable.getSymName() == getComponentName()) { - componentVariable = variable; - break; - } + for (auto variable : recordOp.getVariables()) { + if (variable.getSymName() == getComponentName()) { + componentVariable = variable; + break; } + } - if (!componentVariable) { - return emitOpError() << "Can't resolve record component"; - } + if (!componentVariable) { + return emitOpError() << "Can't resolve record component"; + } - llvm::SmallVector expectedResultShape; + llvm::SmallVector expectedResultShape; - if (auto variableShapedType = variableType.dyn_cast()) { - auto variableShape = variableShapedType.getShape(); - expectedResultShape.append(variableShape.begin(), variableShape.end()); - } + if (auto variableShapedType = variableType.dyn_cast()) { + auto variableShape = variableShapedType.getShape(); + expectedResultShape.append(variableShape.begin(), variableShape.end()); + } - if (auto componentShapedType = - componentVariable.getType().dyn_cast()) { - auto 
componentShape = componentShapedType.getShape(); - expectedResultShape.append(componentShape.begin(), componentShape.end()); - } + if (auto componentShapedType = + componentVariable.getType().dyn_cast()) { + auto componentShape = componentShapedType.getShape(); + expectedResultShape.append(componentShape.begin(), componentShape.end()); + } - mlir::Type expectedResultType = - componentVariable.getVariableType().unwrap(); + mlir::Type expectedResultType = componentVariable.getVariableType().unwrap(); - if (!expectedResultShape.empty()) { - if (auto expectedResultShapedType = - mlir::dyn_cast(expectedResultType)) { - expectedResultType = - expectedResultShapedType.clone(expectedResultShape); - } else { - expectedResultType = mlir::RankedTensorType::get( - expectedResultShape, expectedResultType); - } + if (!expectedResultShape.empty()) { + if (auto expectedResultShapedType = + mlir::dyn_cast(expectedResultType)) { + expectedResultType = expectedResultShapedType.clone(expectedResultShape); + } else { + expectedResultType = + mlir::RankedTensorType::get(expectedResultShape, expectedResultType); } + } - mlir::Type resultType = getResult().getType(); - - if (resultType != expectedResultType) { - return emitOpError() << "Incompatible result type. Expected " - << expectedResultType << ", got " << resultType; - } + mlir::Type resultType = getResult().getType(); - return mlir::success(); + if (resultType != expectedResultType) { + return emitOpError() << "Incompatible result type. Expected " + << expectedResultType << ", got " << resultType; } + + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // GlobalVariableOp -namespace mlir::bmodelica -{ - void GlobalVariableOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - mlir::StringAttr name, - mlir::TypeAttr type) - { - build(builder, state, name, type, nullptr); - } +namespace mlir::bmodelica { +void GlobalVariableOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, mlir::StringAttr name, + mlir::TypeAttr type) { + build(builder, state, name, type, nullptr); +} - void GlobalVariableOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - llvm::StringRef name, - mlir::Type type) - { - build(builder, state, name, type, nullptr); - } +void GlobalVariableOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, llvm::StringRef name, + mlir::Type type) { + build(builder, state, name, type, nullptr); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // GlobalVariableGetOp -namespace mlir::bmodelica -{ - void GlobalVariableGetOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - GlobalVariableOp globalVariableOp) - { - auto type = globalVariableOp.getType(); - auto name = globalVariableOp.getSymName(); - build(builder, state, type, name); - } - - mlir::LogicalResult GlobalVariableGetOp::verifySymbolUses( - mlir::SymbolTableCollection& symbolTableCollection) - { - auto moduleOp = getOperation()->getParentOfType(); +namespace mlir::bmodelica { +void GlobalVariableGetOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + GlobalVariableOp globalVariableOp) { + auto type = globalVariableOp.getType(); + auto name = globalVariableOp.getSymName(); + build(builder, state, type, name); +} - mlir::Operation* symbol = - symbolTableCollection.lookupSymbolIn(moduleOp, getVariableAttr()); +mlir::LogicalResult 
GlobalVariableGetOp::verifySymbolUses( + mlir::SymbolTableCollection &symbolTableCollection) { + auto moduleOp = getOperation()->getParentOfType(); - if (!symbol) { - return emitOpError() - << "global variable " << getVariable() << " has not been declared"; - } + mlir::Operation *symbol = + symbolTableCollection.lookupSymbolIn(moduleOp, getVariableAttr()); - if (!mlir::isa(symbol)) { - return emitOpError() - << "symbol " << getVariable() << " is not a global variable"; - } + if (!symbol) { + return emitOpError() << "global variable " << getVariable() + << " has not been declared"; + } - auto globalVariableOp = mlir::cast(symbol); + if (!mlir::isa(symbol)) { + return emitOpError() << "symbol " << getVariable() + << " is not a global variable"; + } - if (globalVariableOp.getType() != getResult().getType()) { - return emitOpError() - << "result type does not match the global variable type"; - } + auto globalVariableOp = mlir::cast(symbol); - return mlir::success(); + if (globalVariableOp.getType() != getResult().getType()) { + return emitOpError() + << "result type does not match the global variable type"; } + + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // QualifiedVariableGetOp -namespace mlir::bmodelica -{ - void QualifiedVariableGetOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - VariableOp variableOp) - { - auto variableType = variableOp.getVariableType(); - auto qualifiedRef = getSymbolRefFromRoot(variableOp); - build(builder, state, variableType.unwrap(), qualifiedRef); - } - - mlir::LogicalResult QualifiedVariableGetOp::verifySymbolUses( - mlir::SymbolTableCollection& symbolTableCollection) - { - // TODO - return mlir::success(); - } +namespace mlir::bmodelica { +void QualifiedVariableGetOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + VariableOp variableOp) { + auto variableType = variableOp.getVariableType(); + auto qualifiedRef = getSymbolRefFromRoot(variableOp); + build(builder, state, variableType.unwrap(), qualifiedRef); +} - VariableOp QualifiedVariableGetOp::getVariableOp() - { - mlir::SymbolTableCollection symbolTableCollection; - return getVariableOp(symbolTableCollection); - } +mlir::LogicalResult QualifiedVariableGetOp::verifySymbolUses( + mlir::SymbolTableCollection &symbolTableCollection) { + // TODO + return mlir::success(); +} + +VariableOp QualifiedVariableGetOp::getVariableOp() { + mlir::SymbolTableCollection symbolTableCollection; + return getVariableOp(symbolTableCollection); +} - VariableOp QualifiedVariableGetOp::getVariableOp( - mlir::SymbolTableCollection& symbolTableCollection) - { - mlir::ModuleOp moduleOp = - getOperation()->getParentOfType(); +VariableOp QualifiedVariableGetOp::getVariableOp( + mlir::SymbolTableCollection &symbolTableCollection) { + mlir::ModuleOp moduleOp = getOperation()->getParentOfType(); - mlir::Operation* variable = - resolveSymbol(moduleOp, symbolTableCollection, getVariable()); + mlir::Operation *variable = + resolveSymbol(moduleOp, symbolTableCollection, getVariable()); - return mlir::dyn_cast(variable); - } + return mlir::dyn_cast(variable); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// -//QualifiedVariableSetOp +// QualifiedVariableSetOp -namespace mlir::bmodelica -{ - void QualifiedVariableSetOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - VariableOp variableOp, - mlir::Value value) - { - build(builder, 
state, variableOp, std::nullopt, value); - } - - void QualifiedVariableSetOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - VariableOp variableOp, - mlir::ValueRange indices, - mlir::Value value) - { - auto qualifiedRef = getSymbolRefFromRoot(variableOp); - build(builder, state, qualifiedRef, indices, value); - } - - mlir::ParseResult QualifiedVariableSetOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto loc = parser.getCurrentLocation(); - - mlir::SymbolRefAttr variableAttr; - llvm::SmallVector indices; - mlir::OpAsmParser::UnresolvedOperand value; - llvm::SmallVector types; - - if (parser.parseCustomAttributeWithFallback( - variableAttr, parser.getBuilder().getType())) { - return mlir::failure(); - } +namespace mlir::bmodelica { +void QualifiedVariableSetOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + VariableOp variableOp, mlir::Value value) { + build(builder, state, variableOp, std::nullopt, value); +} - if (variableAttr) { - result.getOrAddProperties< - QualifiedVariableSetOp::Properties>().variable = variableAttr; - } +void QualifiedVariableSetOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + VariableOp variableOp, + mlir::ValueRange indices, + mlir::Value value) { + auto qualifiedRef = getSymbolRefFromRoot(variableOp); + build(builder, state, qualifiedRef, indices, value); +} - if (mlir::succeeded(parser.parseOptionalLSquare())) { - do { - if (parser.parseOperand(indices.emplace_back())) { - return mlir::failure(); - } - } while (mlir::succeeded(parser.parseOptionalComma())); +mlir::ParseResult QualifiedVariableSetOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto loc = parser.getCurrentLocation(); - if (parser.parseRSquare()) { - return mlir::failure(); - } - } + mlir::SymbolRefAttr variableAttr; + llvm::SmallVector indices; + mlir::OpAsmParser::UnresolvedOperand value; + llvm::SmallVector types; - if (parser.parseComma() || - parser.parseOperand(value) || - parser.parseColonTypeList(types)) { - return mlir::failure(); - } + if (parser.parseCustomAttributeWithFallback( + variableAttr, parser.getBuilder().getType())) { + return mlir::failure(); + } - if (types.size() != indices.size() + 1) { - return mlir::failure(); - } + if (variableAttr) { + result.getOrAddProperties().variable = + variableAttr; + } - if (!indices.empty()) { - if (parser.resolveOperands( - indices, llvm::ArrayRef(types).drop_back(), loc, - result.operands)) { + if (mlir::succeeded(parser.parseOptionalLSquare())) { + do { + if (parser.parseOperand(indices.emplace_back())) { return mlir::failure(); } - } + } while (mlir::succeeded(parser.parseOptionalComma())); - if (parser.resolveOperand(value, types.back(), result.operands)) { + if (parser.parseRSquare()) { return mlir::failure(); } + } - return mlir::success(); + if (parser.parseComma() || parser.parseOperand(value) || + parser.parseColonTypeList(types)) { + return mlir::failure(); } - void QualifiedVariableSetOp::print(mlir::OpAsmPrinter& printer) - { - printer << " " << getVariable(); + if (types.size() != indices.size() + 1) { + return mlir::failure(); + } - if (auto indices = getIndices(); !indices.empty()) { - printer << "[" << indices << "]"; + if (!indices.empty()) { + if (parser.resolveOperands(indices, llvm::ArrayRef(types).drop_back(), loc, + result.operands)) { + return mlir::failure(); } + } - printer << ", " << getValue() << " : "; + if (parser.resolveOperand(value, types.back(), result.operands)) { + return mlir::failure(); + } - 
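
Pieced together from the parse and print methods above, the accepted textual form is a qualified symbol reference, optional index operands in square brackets, the stored value, and a colon-separated type list holding the index types followed by the value type. An illustrative, hypothetical instance (op mnemonic elided, operand and type names invented):

  @Model::@x[%i, %j], %v : index, index, f64
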
if (auto indices = getIndices(); !indices.empty()) {
-      printer << indices.getTypes() << ", ";
-    }
+  return mlir::success();
+}

-    printer << getValue().getType();
-  }
+void QualifiedVariableSetOp::print(mlir::OpAsmPrinter &printer) {
+  printer << " " << getVariable();

-  mlir::LogicalResult QualifiedVariableSetOp::verifySymbolUses(
-      mlir::SymbolTableCollection& symbolTableCollection)
-  {
-    // TODO
-    return mlir::success();
+  if (auto indices = getIndices(); !indices.empty()) {
+    printer << "[" << indices << "]";
   }

-  VariableOp QualifiedVariableSetOp::getVariableOp()
-  {
-    mlir::SymbolTableCollection symbolTableCollection;
-    return getVariableOp(symbolTableCollection);
+  printer << ", " << getValue() << " : ";
+
+  if (auto indices = getIndices(); !indices.empty()) {
+    printer << indices.getTypes() << ", ";
   }

-  VariableOp QualifiedVariableSetOp::getVariableOp(
-      mlir::SymbolTableCollection& symbolTableCollection)
-  {
-    mlir::ModuleOp moduleOp =
-        getOperation()->getParentOfType<mlir::ModuleOp>();
+  printer << getValue().getType();
+}
+
+mlir::LogicalResult QualifiedVariableSetOp::verifySymbolUses(
+    mlir::SymbolTableCollection &symbolTableCollection) {
+  // TODO
+  return mlir::success();
+}

-    mlir::Operation* variable =
-        resolveSymbol(moduleOp, symbolTableCollection, getVariable());
+VariableOp QualifiedVariableSetOp::getVariableOp() {
+  mlir::SymbolTableCollection symbolTableCollection;
+  return getVariableOp(symbolTableCollection);
+}

-    return mlir::dyn_cast<VariableOp>(variable);
-  }
+VariableOp QualifiedVariableSetOp::getVariableOp(
+    mlir::SymbolTableCollection &symbolTableCollection) {
+  mlir::ModuleOp moduleOp = getOperation()->getParentOfType<mlir::ModuleOp>();
+
+  mlir::Operation *variable =
+      resolveSymbol(moduleOp, symbolTableCollection, getVariable());
+
+  return mlir::dyn_cast<VariableOp>(variable);
 }
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // Math operations
@@ -2409,3245 +2130,3088 @@ namespace mlir::bmodelica
 //===---------------------------------------------------------------------===//
 // ConstantOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult ConstantOp::fold(FoldAdaptor adaptor)
-  {
-    return getValue().cast<mlir::Attribute>();
-  }
+namespace mlir::bmodelica {
+mlir::OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
+  return getValue().cast<mlir::Attribute>();
 }
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // NegateOp

-namespace mlir::bmodelica
-{
-  mlir::LogicalResult NegateOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type operandType = adaptor.getOperand().getType();
-
-    if (isScalar(operandType)) {
-      returnTypes.push_back(operandType);
-      return mlir::success();
-    }
+namespace mlir::bmodelica {
+mlir::LogicalResult NegateOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type operandType = adaptor.getOperand().getType();
+
+  if (isScalar(operandType)) {
+    returnTypes.push_back(operandType);
+    return mlir::success();
+  }

-    if (auto shapedType = operandType.dyn_cast<mlir::ShapedType>()) {
-
returnTypes.push_back(shapedType); - return mlir::success(); - } + if (auto shapedType = operandType.dyn_cast()) { + returnTypes.push_back(shapedType); + return mlir::success(); + } - return mlir::failure(); + return mlir::failure(); +} + +bool NegateOp::isCompatibleReturnTypes(mlir::TypeRange lhs, + mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool NegateOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } + } - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } + return true; +} - return true; +mlir::OpFoldResult NegateOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); + + if (!operand) { + return {}; } - mlir::OpFoldResult NegateOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); + auto resultType = getResult().getType(); - if (!operand) { - return {}; + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + return getAttr(resultType, -1 * getScalarIntegerLikeValue(operand)); } - auto resultType = getResult().getType(); + if (isScalarFloatLike(operand)) { + return getAttr(resultType, -1 * getScalarFloatLikeValue(operand)); + } + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - return getAttr(resultType, -1 * getScalarIntegerLikeValue(operand)); - } + return {}; +} - if (isScalarFloatLike(operand)) { - return getAttr(resultType, -1 * getScalarFloatLikeValue(operand)); +mlir::LogicalResult +NegateOp::distribute(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value operand = getOperand(); + mlir::Operation *operandOp = operand.getDefiningOp(); + + if (operandOp) { + if (auto negDistributionInt = + mlir::dyn_cast(operandOp)) { + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(results, builder))) { + return mlir::success(); } } - - return {}; } - mlir::LogicalResult NegateOp::distribute( - llvm::SmallVectorImpl& results, mlir::OpBuilder& builder) - { - mlir::Value operand = getOperand(); - mlir::Operation* operandOp = operand.getDefiningOp(); + // The operation can't be propagated because the child doesn't know how to + // distribute the negation to its children. + results.push_back(getResult()); + return mlir::failure(); +} - if (operandOp) { - if (auto negDistributionInt = - mlir::dyn_cast(operandOp)) { - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - results, builder))) { - return mlir::success(); - } +mlir::LogicalResult +NegateOp::distributeNegateOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value operand = getOperand(); + bool operandDistributed = false; + mlir::Operation *operandOp = operand.getDefiningOp(); + + if (operandOp) { + if (auto negDistributionInt = + mlir::dyn_cast(operandOp)) { + llvm::SmallVector childResults; + + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + operand = childResults[0]; + operandDistributed = true; } } + } - // The operation can't be propagated because the child doesn't know how to - // distribute the negation to its children. 
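
The fold above treats integer-like and float-like scalars uniformly as value * -1. A minimal sketch of the same rule on plain C++ values (the variant is illustrative; the real code works on MLIR attributes):

#include <cassert>
#include <variant>

std::variant<long, double> foldNegate(std::variant<long, double> v) {
  if (auto *i = std::get_if<long>(&v))
    return -1 * *i; // integer-like case
  return -1 * std::get<double>(v); // float-like case
}

int main() {
  assert(std::get<long>(foldNegate(7L)) == -7);
  assert(std::get<double>(foldNegate(2.5)) == -2.5);
}
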
- results.push_back(getResult()); - return mlir::failure(); + if (!operandDistributed) { + auto newOperandOp = builder.create(getLoc(), operand); + operand = newOperandOp.getResult(); } - mlir::LogicalResult NegateOp::distributeNegateOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder) - { - mlir::Value operand = getOperand(); - bool operandDistributed = false; - mlir::Operation* operandOp = operand.getDefiningOp(); - - if (operandOp) { - if (auto negDistributionInt = - mlir::dyn_cast(operandOp)) { - llvm::SmallVector childResults; - - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - operand = childResults[0]; - operandDistributed = true; - } - } - } + auto resultOp = builder.create(getLoc(), operand); + results.push_back(resultOp.getResult()); - if (!operandDistributed) { - auto newOperandOp = builder.create(getLoc(), operand); - operand = newOperandOp.getResult(); - } + return mlir::success(); +} - auto resultOp = builder.create(getLoc(), operand); - results.push_back(resultOp.getResult()); +mlir::LogicalResult +NegateOp::distributeMulOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value operand = getOperand(); + bool operandDistributed = false; + mlir::Operation *operandOp = operand.getDefiningOp(); - return mlir::success(); - } + if (operandOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(operandOp)) { + llvm::SmallVector childResults; - mlir::LogicalResult NegateOp::distributeMulOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value operand = getOperand(); - bool operandDistributed = false; - mlir::Operation* operandOp = operand.getDefiningOp(); - - if (operandOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(operandOp)) { - llvm::SmallVector childResults; - - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - operand = childResults[0]; - operandDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + operand = childResults[0]; + operandDistributed = true; } } + } - if (!operandDistributed) { - auto newOperandOp = builder.create(getLoc(), operand, value); - operand = newOperandOp.getResult(); - } + if (!operandDistributed) { + auto newOperandOp = builder.create(getLoc(), operand, value); + operand = newOperandOp.getResult(); + } - auto resultOp = builder.create(getLoc(), operand); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), operand); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult NegateOp::distributeDivOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value operand = getOperand(); - bool operandDistributed = false; - mlir::Operation* operandOp = operand.getDefiningOp(); - - if (operandOp) { - if (auto divDistributionInt = - mlir::dyn_cast(operandOp)) { - llvm::SmallVector childResults; - - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - operand = childResults[0]; - operandDistributed = true; - } +mlir::LogicalResult +NegateOp::distributeDivOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value operand = getOperand(); + bool operandDistributed = false; + 
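
The multiplication push-down just above relies on the algebraic identity (-x) * k == -(x * k), which is exact even in IEEE floating point since it only flips the sign. A quick spot check of that assumption:

#include <cassert>

int main() {
  for (double x : {-2.0, 0.0, 3.5})
    for (double k : {-1.0, 4.0})
      assert((-x) * k == -(x * k));
}
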
mlir::Operation *operandOp = operand.getDefiningOp();
+
+  if (operandOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(operandOp)) {
+      llvm::SmallVector<mlir::Value> childResults;
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        operand = childResults[0];
+        operandDistributed = true;
       }
     }
+  }

-    if (!operandDistributed) {
-      auto newOperandOp = builder.create<DivOp>(getLoc(), operand, value);
-      operand = newOperandOp.getResult();
-    }
+  if (!operandDistributed) {
+    auto newOperandOp = builder.create<DivOp>(getLoc(), operand, value);
+    operand = newOperandOp.getResult();
+  }

-    auto resultOp = builder.create<NegateOp>(getLoc(), operand);
-    results.push_back(resultOp.getResult());
+  auto resultOp = builder.create<NegateOp>(getLoc(), operand);
+  results.push_back(resultOp.getResult());

-    return mlir::success();
-  }
+  return mlir::success();
 }
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // AddOp

-namespace
-{
-  struct AddOpRangeOrderingPattern : public mlir::OpRewritePattern<AddOp>
-  {
-    using mlir::OpRewritePattern<AddOp>::OpRewritePattern;
+namespace {
+struct AddOpRangeOrderingPattern : public mlir::OpRewritePattern<AddOp> {
+  using mlir::OpRewritePattern<AddOp>::OpRewritePattern;

-    mlir::LogicalResult match(AddOp op) const override
-    {
-      mlir::Value lhs = op.getLhs();
-      mlir::Value rhs = op.getRhs();
+  mlir::LogicalResult match(AddOp op) const override {
+    mlir::Value lhs = op.getLhs();
+    mlir::Value rhs = op.getRhs();

-      return mlir::LogicalResult::success(
-          !lhs.getType().isa<RangeType>() &&
-          rhs.getType().isa<RangeType>());
-    }
+    return mlir::LogicalResult::success(!lhs.getType().isa<RangeType>() &&
+                                        rhs.getType().isa<RangeType>());
+  }

-    void rewrite(
-        AddOp op, mlir::PatternRewriter& rewriter) const override
-    {
-      // Swap the operands.
-      rewriter.replaceOpWithNewOp<AddOp>(
-          op, op.getResult().getType(), op.getRhs(), op.getLhs());
-    }
-  };
-}
+  void rewrite(AddOp op, mlir::PatternRewriter &rewriter) const override {
+    // Swap the operands.
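
Reduced to booleans, the pattern's match condition fires exactly when only the right-hand operand is a range, so the canonical form always carries the range on the left; the swap is sound because addition is commutative here. A sketch of that decision logic (names illustrative):

#include <cassert>

bool shouldSwap(bool lhsIsRange, bool rhsIsRange) {
  return !lhsIsRange && rhsIsRange;
}

int main() {
  assert(shouldSwap(false, true));   // scalar + range: reorder
  assert(!shouldSwap(true, false));  // range + scalar: already canonical
  assert(!shouldSwap(false, false)); // scalar + scalar: nothing to do
}
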
+ rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + op.getRhs(), op.getLhs()); + } +}; +} // namespace -namespace mlir::bmodelica -{ - mlir::LogicalResult AddOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type lhsType = adaptor.getLhs().getType(); - mlir::Type rhsType = adaptor.getRhs().getType(); - - auto lhsShapedType = lhsType.dyn_cast(); - auto rhsShapedType = rhsType.dyn_cast(); - - if (lhsShapedType && rhsShapedType) { - if (lhsShapedType.getRank() != rhsShapedType.getRank()) { - return mlir::failure(); - } +namespace mlir::bmodelica { +mlir::LogicalResult AddOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lhsType = adaptor.getLhs().getType(); + mlir::Type rhsType = adaptor.getRhs().getType(); - int64_t rank = lhsShapedType.getRank(); - llvm::SmallVector shape; + auto lhsShapedType = lhsType.dyn_cast(); + auto rhsShapedType = rhsType.dyn_cast(); - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t lhsDimSize = lhsShapedType.getDimSize(dim); - int64_t rhsDimSize = rhsShapedType.getDimSize(dim); + if (lhsShapedType && rhsShapedType) { + if (lhsShapedType.getRank() != rhsShapedType.getRank()) { + return mlir::failure(); + } - if (lhsDimSize != mlir::ShapedType::kDynamic && - rhsDimSize != mlir::ShapedType::kDynamic && - lhsDimSize != rhsDimSize) { - return mlir::failure(); - } + int64_t rank = lhsShapedType.getRank(); + llvm::SmallVector shape; - if (lhsDimSize != mlir::ShapedType::kDynamic) { - shape.push_back(lhsDimSize); - } else { - shape.push_back(rhsDimSize); - } + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t lhsDimSize = lhsShapedType.getDimSize(dim); + int64_t rhsDimSize = rhsShapedType.getDimSize(dim); + + if (lhsDimSize != mlir::ShapedType::kDynamic && + rhsDimSize != mlir::ShapedType::kDynamic && + lhsDimSize != rhsDimSize) { + return mlir::failure(); } - mlir::Type resultElementType = getMostGenericScalarType( - lhsShapedType.getElementType(), rhsShapedType.getElementType()); + if (lhsDimSize != mlir::ShapedType::kDynamic) { + shape.push_back(lhsDimSize); + } else { + shape.push_back(rhsDimSize); + } + } - returnTypes.push_back(mlir::RankedTensorType::get( - shape, resultElementType)); + mlir::Type resultElementType = getMostGenericScalarType( + lhsShapedType.getElementType(), rhsShapedType.getElementType()); - return mlir::success(); - } + returnTypes.push_back( + mlir::RankedTensorType::get(shape, resultElementType)); - if (isScalar(lhsType) && isScalar(rhsType)) { - returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); - return mlir::success(); - } + return mlir::success(); + } - auto lhsRangeType = lhsType.dyn_cast(); - auto rhsRangeType = rhsType.dyn_cast(); + if (isScalar(lhsType) && isScalar(rhsType)) { + returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); + return mlir::success(); + } - if (isScalar(lhsType) && rhsRangeType) { - mlir::Type inductionType = - getMostGenericScalarType(lhsType, rhsRangeType.getInductionType()); + auto lhsRangeType = lhsType.dyn_cast(); + auto rhsRangeType = 
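
The result element type comes from getMostGenericScalarType. Assuming the usual Boolean < Integer < Real ordering of the dialect's scalar types (an assumption; the real function works on MLIR types), it behaves like a max over a small promotion lattice:

#include <cassert>

enum class Scalar { Boolean, Integer, Real }; // assumed ordering

Scalar mostGeneric(Scalar a, Scalar b) { return a < b ? b : a; }

int main() {
  assert(mostGeneric(Scalar::Integer, Scalar::Real) == Scalar::Real);
  assert(mostGeneric(Scalar::Boolean, Scalar::Integer) == Scalar::Integer);
}
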
rhsType.dyn_cast(); - returnTypes.push_back(RangeType::get(context, inductionType)); - return mlir::success(); - } + if (isScalar(lhsType) && rhsRangeType) { + mlir::Type inductionType = + getMostGenericScalarType(lhsType, rhsRangeType.getInductionType()); - if (lhsRangeType && isScalar(rhsType)) { - mlir::Type inductionType = - getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType); + returnTypes.push_back(RangeType::get(context, inductionType)); + return mlir::success(); + } - returnTypes.push_back(RangeType::get(context, inductionType)); - return mlir::success(); - } + if (lhsRangeType && isScalar(rhsType)) { + mlir::Type inductionType = + getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType); - return mlir::failure(); + returnTypes.push_back(RangeType::get(context, inductionType)); + return mlir::success(); + } + + return mlir::failure(); +} + +bool AddOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool AddOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } + } - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } + return true; +} - return true; +mlir::OpFoldResult AddOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + + if (!lhs || !rhs) { + return {}; } - mlir::OpFoldResult AddOp::fold(FoldAdaptor adaptor) - { - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); + auto resultType = getResult().getType(); - if (!lhs || !rhs) { - return {}; + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + return getAttr(resultType, lhsValue + rhsValue); } - auto resultType = getResult().getType(); - - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr(resultType, lhsValue + rhsValue); - } - - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue + rhsValue); - } + if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue + rhsValue); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue + rhsValue); - } + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue + rhsValue); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - return getAttr(resultType, lhsValue + rhsValue); - } + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = 
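
The integer-range branch of the fold shifts both bounds by the scalar and leaves the step untouched. A minimal sketch of that shape, using a plain struct in place of IntegerRangeAttr:

#include <cassert>
#include <cstdint>

struct IntRange { int64_t lb, ub, step; }; // illustrative, not the MLIR attr

IntRange foldAdd(IntRange r, int64_t k) { return {r.lb + k, r.ub + k, r.step}; }

int main() {
  IntRange r = foldAdd({1, 10, 2}, 3); // [1:2:10] + 3
  assert(r.lb == 4 && r.ub == 13 && r.step == 2);
}
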
getScalarFloatLikeValue(lhs); + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + return getAttr(resultType, lhsValue + rhsValue); } + } - if (auto lhsRange = lhs.dyn_cast(); - lhsRange && isScalar(rhs)) { - if (isScalarIntegerLike(rhs)) { - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - int64_t lowerBound = lhsRange.getLowerBound() + rhsValue; - int64_t upperBound = lhsRange.getUpperBound() + rhsValue; - int64_t step = lhsRange.getStep(); + if (auto lhsRange = lhs.dyn_cast(); + lhsRange && isScalar(rhs)) { + if (isScalarIntegerLike(rhs)) { + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + int64_t lowerBound = lhsRange.getLowerBound() + rhsValue; + int64_t upperBound = lhsRange.getUpperBound() + rhsValue; + int64_t step = lhsRange.getStep(); - return IntegerRangeAttr::get( - getContext(), lhsRange.getType(), - lowerBound, upperBound, step); - } + return IntegerRangeAttr::get(getContext(), lhsRange.getType(), lowerBound, + upperBound, step); + } - if (isScalarFloatLike(rhs)) { - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalarFloatLike(rhs)) { + double rhsValue = getScalarFloatLikeValue(rhs); - double lowerBound = - static_cast(lhsRange.getLowerBound()) + rhsValue; + double lowerBound = + static_cast(lhsRange.getLowerBound()) + rhsValue; - double upperBound = - static_cast(lhsRange.getUpperBound()) + rhsValue; + double upperBound = + static_cast(lhsRange.getUpperBound()) + rhsValue; - auto step = static_cast(lhsRange.getStep()); + auto step = static_cast(lhsRange.getStep()); - return RealRangeAttr::get( - getContext(), lowerBound, upperBound, step); - } + return RealRangeAttr::get(getContext(), lowerBound, upperBound, step); } + } - if (auto lhsRange = lhs.dyn_cast(); - lhsRange && isScalar(rhs)) { - if (isScalarIntegerLike(rhs)) { - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + if (auto lhsRange = lhs.dyn_cast(); + lhsRange && isScalar(rhs)) { + if (isScalarIntegerLike(rhs)) { + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - double lowerBound = - lhsRange.getLowerBound().convertToDouble() + rhsValue; + double lowerBound = lhsRange.getLowerBound().convertToDouble() + rhsValue; - double upperBound = - lhsRange.getUpperBound().convertToDouble() + rhsValue; + double upperBound = lhsRange.getUpperBound().convertToDouble() + rhsValue; - double step = lhsRange.getStep().convertToDouble(); + double step = lhsRange.getStep().convertToDouble(); - return RealRangeAttr::get( - lhsRange.getType(), lowerBound, upperBound, step); - } + return RealRangeAttr::get(lhsRange.getType(), lowerBound, upperBound, + step); + } - if (isScalarFloatLike(rhs)) { - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalarFloatLike(rhs)) { + double rhsValue = getScalarFloatLikeValue(rhs); - double lowerBound = - lhsRange.getLowerBound().convertToDouble() + rhsValue; + double lowerBound = lhsRange.getLowerBound().convertToDouble() + rhsValue; - double upperBound = - lhsRange.getUpperBound().convertToDouble() + rhsValue; + double upperBound = lhsRange.getUpperBound().convertToDouble() + rhsValue; - double step = lhsRange.getStep().convertToDouble(); + double step = lhsRange.getStep().convertToDouble(); - return RealRangeAttr::get( - getContext(), lowerBound, upperBound, step); - } + return RealRangeAttr::get(getContext(), lowerBound, upperBound, step); } - - return {}; } - void AddOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); - } + return {}; +} - mlir::LogicalResult 
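
The negate distribution for additions rewrites -(a + b) into (-a) + (-b), negating each operand (recursively where the child supports it) and rebuilding the sum. A spot check of the identity the rewrite assumes:

#include <cassert>

int main() {
  for (int a : {-3, 0, 7})
    for (int b : {-5, 2})
      assert(-(a + b) == (-a) + (-b));
}
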
AddOp::distributeNegateOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +void AddOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + patterns.add(context); +} - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); +mlir::LogicalResult +AddOp::distributeNegateOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - bool lhsDistributed = false; - bool rhsDistributed = false; + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - if (lhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (lhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; + + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult AddOp::distributeMulOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +AddOp::distributeMulOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if 
(mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult AddOp::distributeDivOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +AddOp::distributeDivOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = 
builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // AddEWOp -namespace mlir::bmodelica -{ - mlir::LogicalResult AddEWOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type lhsType = adaptor.getLhs().getType(); - mlir::Type rhsType = adaptor.getRhs().getType(); - - if (isScalar(lhsType) && isScalar(rhsType)) { - returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); - return mlir::success(); - } +namespace mlir::bmodelica { +mlir::LogicalResult AddEWOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lhsType = adaptor.getLhs().getType(); + mlir::Type rhsType = adaptor.getRhs().getType(); + + if (isScalar(lhsType) && isScalar(rhsType)) { + returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); + return mlir::success(); + } - auto lhsShapedType = lhsType.dyn_cast(); - auto rhsShapedType = rhsType.dyn_cast(); + auto lhsShapedType = lhsType.dyn_cast(); + auto rhsShapedType = rhsType.dyn_cast(); - if (isScalar(lhsType) && rhsShapedType) { - mlir::Type resultElementType = - getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); + if (isScalar(lhsType) && rhsShapedType) { + mlir::Type resultElementType = + getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); - returnTypes.push_back(mlir::RankedTensorType::get( - rhsShapedType.getShape(), resultElementType)); + returnTypes.push_back(mlir::RankedTensorType::get(rhsShapedType.getShape(), + resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType && isScalar(rhsType)) { - mlir::Type resultElementType = - getMostGenericScalarType(lhsShapedType.getElementType(), rhsType); + if (lhsShapedType && isScalar(rhsType)) { + mlir::Type resultElementType = + getMostGenericScalarType(lhsShapedType.getElementType(), rhsType); - returnTypes.push_back(mlir::RankedTensorType::get( - lhsShapedType.getShape(), resultElementType)); + returnTypes.push_back(mlir::RankedTensorType::get(lhsShapedType.getShape(), + resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType && rhsShapedType) { - if (lhsShapedType.getRank() != rhsShapedType.getRank()) { - return mlir::failure(); - } + if (lhsShapedType && rhsShapedType) { + if (lhsShapedType.getRank() != rhsShapedType.getRank()) { + 
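
The element-wise inference below merges shapes dimension by dimension: mismatching static sizes are an error, while a dynamic size defers to the other operand. A sketch of the per-dimension rule; kDynamic here is a stand-in for mlir::ShapedType::kDynamic, whose actual value differs:

#include <cassert>
#include <cstdint>
#include <optional>

constexpr int64_t kDynamic = -1; // illustrative sentinel

std::optional<int64_t> mergeDim(int64_t lhs, int64_t rhs) {
  if (lhs != kDynamic && rhs != kDynamic && lhs != rhs)
    return std::nullopt; // incompatible static sizes
  return lhs != kDynamic ? lhs : rhs; // prefer the static size
}

int main() {
  assert(mergeDim(4, 4) == 4);
  assert(mergeDim(4, kDynamic) == 4);
  assert(!mergeDim(4, 5).has_value());
}
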
return mlir::failure(); + } - int64_t rank = lhsShapedType.getRank(); - llvm::SmallVector shape; + int64_t rank = lhsShapedType.getRank(); + llvm::SmallVector shape; - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t lhsDimSize = lhsShapedType.getDimSize(dim); - int64_t rhsDimSize = rhsShapedType.getDimSize(dim); + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t lhsDimSize = lhsShapedType.getDimSize(dim); + int64_t rhsDimSize = rhsShapedType.getDimSize(dim); - if (lhsDimSize != mlir::ShapedType::kDynamic && - rhsDimSize != mlir::ShapedType::kDynamic && - lhsDimSize != rhsDimSize) { - return mlir::failure(); - } + if (lhsDimSize != mlir::ShapedType::kDynamic && + rhsDimSize != mlir::ShapedType::kDynamic && + lhsDimSize != rhsDimSize) { + return mlir::failure(); + } - if (lhsDimSize != mlir::ShapedType::kDynamic) { - shape.push_back(lhsDimSize); - } else { - shape.push_back(rhsDimSize); - } + if (lhsDimSize != mlir::ShapedType::kDynamic) { + shape.push_back(lhsDimSize); + } else { + shape.push_back(rhsDimSize); } + } - mlir::Type resultElementType = getMostGenericScalarType( - lhsShapedType.getElementType(), rhsShapedType.getElementType()); + mlir::Type resultElementType = getMostGenericScalarType( + lhsShapedType.getElementType(), rhsShapedType.getElementType()); - returnTypes.push_back(mlir::RankedTensorType::get( - shape, resultElementType)); + returnTypes.push_back( + mlir::RankedTensorType::get(shape, resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - return mlir::failure(); + return mlir::failure(); +} + +bool AddEWOp::isCompatibleReturnTypes(mlir::TypeRange lhs, + mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool AddEWOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } - - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } - - return true; } - mlir::OpFoldResult AddEWOp::fold(FoldAdaptor adaptor) - { - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); + return true; +} - if (!lhs || !rhs) { - return {}; - } +mlir::OpFoldResult AddEWOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); - auto resultType = getResult().getType(); + if (!lhs || !rhs) { + return {}; + } - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr(resultType, lhsValue + rhsValue); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue + rhsValue); - } + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + return getAttr(resultType, lhsValue + rhsValue); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue + rhsValue); - } + if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + 
double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue + rhsValue); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - return getAttr(resultType, lhsValue + rhsValue); - } + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue + rhsValue); } - return {}; + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + return getAttr(resultType, lhsValue + rhsValue); + } } - mlir::LogicalResult AddEWOp::distributeNegateOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); + return {}; +} + +mlir::LogicalResult +AddEWOp::distributeNegateOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto negDistributionOp = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto negDistributionOp = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(negDistributionOp.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded( + negDistributionOp.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto negDistributionOp = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto negDistributionOp = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(negDistributionOp.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded( + negDistributionOp.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult AddEWOp::distributeMulOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = 
getRhs(); +mlir::LogicalResult +AddEWOp::distributeMulOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult AddEWOp::distributeDivOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +AddEWOp::distributeDivOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if 
(mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SubOp -namespace mlir::bmodelica -{ - mlir::LogicalResult SubOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type lhsType = adaptor.getLhs().getType(); - mlir::Type rhsType = adaptor.getRhs().getType(); - - auto lhsShapedType = lhsType.dyn_cast(); - auto rhsShapedType = rhsType.dyn_cast(); - - if (lhsShapedType && rhsShapedType) { - if (lhsShapedType.getRank() != rhsShapedType.getRank()) { - return mlir::failure(); - } - - int64_t rank = lhsShapedType.getRank(); - llvm::SmallVector shape; +namespace mlir::bmodelica { +mlir::LogicalResult SubOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lhsType = adaptor.getLhs().getType(); + mlir::Type rhsType = adaptor.getRhs().getType(); + + auto lhsShapedType = lhsType.dyn_cast(); + auto rhsShapedType = rhsType.dyn_cast(); + + if (lhsShapedType && rhsShapedType) { + if (lhsShapedType.getRank() != rhsShapedType.getRank()) { + return mlir::failure(); + } - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t lhsDimSize = lhsShapedType.getDimSize(dim); - int64_t rhsDimSize = rhsShapedType.getDimSize(dim); + int64_t rank = lhsShapedType.getRank(); + llvm::SmallVector shape; - if (lhsDimSize != mlir::ShapedType::kDynamic && - rhsDimSize != mlir::ShapedType::kDynamic && - lhsDimSize != rhsDimSize) { - return mlir::failure(); - } + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t lhsDimSize = lhsShapedType.getDimSize(dim); + int64_t rhsDimSize = 
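
Distributing a division over an element-wise sum uses (a + b) / c == a / c + b / c, which holds exactly in real arithmetic and up to rounding in floating point. A quick numeric spot check with a tolerance:

#include <cassert>
#include <cmath>

int main() {
  for (double a : {1.0, -4.0})
    for (double b : {2.5, 0.0})
      for (double c : {2.0, -8.0})
        assert(std::abs((a + b) / c - (a / c + b / c)) < 1e-12);
}
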
rhsShapedType.getDimSize(dim); - if (lhsDimSize != mlir::ShapedType::kDynamic) { - shape.push_back(lhsDimSize); - } else { - shape.push_back(rhsDimSize); - } + if (lhsDimSize != mlir::ShapedType::kDynamic && + rhsDimSize != mlir::ShapedType::kDynamic && + lhsDimSize != rhsDimSize) { + return mlir::failure(); } - mlir::Type resultElementType = getMostGenericScalarType( - lhsShapedType.getElementType(), rhsShapedType.getElementType()); - - returnTypes.push_back(mlir::RankedTensorType::get( - shape, resultElementType)); - - return mlir::success(); + if (lhsDimSize != mlir::ShapedType::kDynamic) { + shape.push_back(lhsDimSize); + } else { + shape.push_back(rhsDimSize); + } } - if (isScalar(lhsType) && isScalar(rhsType)) { - returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); - return mlir::success(); - } + mlir::Type resultElementType = getMostGenericScalarType( + lhsShapedType.getElementType(), rhsShapedType.getElementType()); - auto lhsRangeType = lhsType.dyn_cast(); - auto rhsRangeType = rhsType.dyn_cast(); + returnTypes.push_back( + mlir::RankedTensorType::get(shape, resultElementType)); - if (isScalar(lhsType) && rhsRangeType) { - mlir::Type inductionType = - getMostGenericScalarType(lhsType, rhsRangeType.getInductionType()); + return mlir::success(); + } - returnTypes.push_back(RangeType::get(context, inductionType)); - return mlir::success(); - } + if (isScalar(lhsType) && isScalar(rhsType)) { + returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); + return mlir::success(); + } - if (lhsRangeType && isScalar(rhsType)) { - mlir::Type inductionType = - getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType); + auto lhsRangeType = lhsType.dyn_cast(); + auto rhsRangeType = rhsType.dyn_cast(); - returnTypes.push_back(RangeType::get(context, inductionType)); - return mlir::success(); - } + if (isScalar(lhsType) && rhsRangeType) { + mlir::Type inductionType = + getMostGenericScalarType(lhsType, rhsRangeType.getInductionType()); - return mlir::failure(); + returnTypes.push_back(RangeType::get(context, inductionType)); + return mlir::success(); } - bool SubOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + if (lhsRangeType && isScalar(rhsType)) { + mlir::Type inductionType = + getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType); + + returnTypes.push_back(RangeType::get(context, inductionType)); + return mlir::success(); + } + + return mlir::failure(); +} + +bool SubOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } + } - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } + return true; +} - return true; +mlir::OpFoldResult SubOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + + if (!lhs || !rhs) { + return {}; } - mlir::OpFoldResult SubOp::fold(FoldAdaptor adaptor) - { - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); + auto resultType = getResult().getType(); - if (!lhs || !rhs) { - return {}; + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + return getAttr(resultType, lhsValue - rhsValue); } - auto 
resultType = getResult().getType(); - - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr(resultType, lhsValue - rhsValue); - } - - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue - rhsValue); - } + if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue - rhsValue); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue - rhsValue); - } + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue - rhsValue); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - return getAttr(resultType, lhsValue - rhsValue); - } + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + return getAttr(resultType, lhsValue - rhsValue); } + } - if (auto lhsRange = lhs.dyn_cast(); - lhsRange && isScalar(rhs)) { - if (isScalarIntegerLike(rhs)) { - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - int64_t lowerBound = lhsRange.getLowerBound() - rhsValue; - int64_t upperBound = lhsRange.getUpperBound() - rhsValue; - int64_t step = lhsRange.getStep(); + if (auto lhsRange = lhs.dyn_cast(); + lhsRange && isScalar(rhs)) { + if (isScalarIntegerLike(rhs)) { + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + int64_t lowerBound = lhsRange.getLowerBound() - rhsValue; + int64_t upperBound = lhsRange.getUpperBound() - rhsValue; + int64_t step = lhsRange.getStep(); - return IntegerRangeAttr::get( - getContext(), lhsRange.getType(), - lowerBound, upperBound, step); - } + return IntegerRangeAttr::get(getContext(), lhsRange.getType(), lowerBound, + upperBound, step); + } - if (isScalarFloatLike(rhs)) { - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalarFloatLike(rhs)) { + double rhsValue = getScalarFloatLikeValue(rhs); - double lowerBound = - static_cast(lhsRange.getLowerBound()) - rhsValue; + double lowerBound = + static_cast(lhsRange.getLowerBound()) - rhsValue; - double upperBound = - static_cast(lhsRange.getUpperBound()) - rhsValue; + double upperBound = + static_cast(lhsRange.getUpperBound()) - rhsValue; - auto step = static_cast(lhsRange.getStep()); + auto step = static_cast(lhsRange.getStep()); - return RealRangeAttr::get( - getContext(), lowerBound, upperBound, step); - } + return RealRangeAttr::get(getContext(), lowerBound, upperBound, step); } + } - if (auto lhsRange = lhs.dyn_cast(); - lhsRange && isScalar(rhs)) { - if (isScalarIntegerLike(rhs)) { - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + if (auto lhsRange = lhs.dyn_cast(); + lhsRange && isScalar(rhs)) { + if (isScalarIntegerLike(rhs)) { + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - double lowerBound = - 
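
The real-range case of the subtraction fold mirrors the addition op with the sign flipped: [lb, ub] step s minus a scalar k gives [lb - k, ub - k] step s. A minimal sketch with an illustrative struct in place of RealRangeAttr:

#include <cassert>

struct RealRange { double lb, ub, step; }; // illustrative

RealRange foldSub(RealRange r, double k) {
  return {r.lb - k, r.ub - k, r.step};
}

int main() {
  RealRange r = foldSub({1.0, 4.0, 0.5}, 2.0);
  assert(r.lb == -1.0 && r.ub == 2.0 && r.step == 0.5);
}
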
lhsRange.getLowerBound().convertToDouble() - rhsValue; + double lowerBound = lhsRange.getLowerBound().convertToDouble() - rhsValue; - double upperBound = - lhsRange.getUpperBound().convertToDouble() - rhsValue; + double upperBound = lhsRange.getUpperBound().convertToDouble() - rhsValue; - double step = lhsRange.getStep().convertToDouble(); + double step = lhsRange.getStep().convertToDouble(); - return RealRangeAttr::get( - lhsRange.getType(), lowerBound, upperBound, step); - } + return RealRangeAttr::get(lhsRange.getType(), lowerBound, upperBound, + step); + } - if (isScalarFloatLike(rhs)) { - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalarFloatLike(rhs)) { + double rhsValue = getScalarFloatLikeValue(rhs); - double lowerBound = - lhsRange.getLowerBound().convertToDouble() - rhsValue; + double lowerBound = lhsRange.getLowerBound().convertToDouble() - rhsValue; - double upperBound = - lhsRange.getUpperBound().convertToDouble() - rhsValue; + double upperBound = lhsRange.getUpperBound().convertToDouble() - rhsValue; - double step = lhsRange.getStep().convertToDouble(); + double step = lhsRange.getStep().convertToDouble(); - return RealRangeAttr::get( - getContext(), lowerBound, upperBound, step); - } + return RealRangeAttr::get(getContext(), lowerBound, upperBound, step); } - - return {}; } - mlir::LogicalResult SubOp::distributeNegateOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); + return {}; +} - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); +mlir::LogicalResult +SubOp::distributeNegateOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - bool lhsDistributed = false; - bool rhsDistributed = false; + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - if (lhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (lhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; + + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs); + 
rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult SubOp::distributeMulOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +SubOp::distributeMulOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult SubOp::distributeDivOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +SubOp::distributeDivOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; 
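// Editorial aside (illustrative, not part of this change):
// SubOp::distributeDivOp rewrites (lhs - rhs) / value into
// (lhs / value) - (rhs / value). Operands whose defining operation implements
// the division-distribution interface absorb the division recursively; the
// remaining operands are wrapped in an explicit DivOp below. Assuming the
// usual bmodelica assembly syntax, the net effect is:
//
//   %0 = bmodelica.div %lhs, %value
//   %1 = bmodelica.div %rhs, %value
//   %2 = bmodelica.sub %0, %1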
- if (lhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SubEWOp -namespace mlir::bmodelica -{ - mlir::LogicalResult SubEWOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type lhsType = adaptor.getLhs().getType(); - mlir::Type rhsType = adaptor.getRhs().getType(); - - if (isScalar(lhsType) && isScalar(rhsType)) { - returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); - return mlir::success(); - } +namespace mlir::bmodelica { +mlir::LogicalResult SubEWOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lhsType = adaptor.getLhs().getType(); + mlir::Type rhsType = adaptor.getRhs().getType(); + + if (isScalar(lhsType) && isScalar(rhsType)) { + returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); + return mlir::success(); + } - auto lhsShapedType = lhsType.dyn_cast(); - auto rhsShapedType = rhsType.dyn_cast(); + auto lhsShapedType = lhsType.dyn_cast(); + auto rhsShapedType = rhsType.dyn_cast(); - if (isScalar(lhsType) && rhsShapedType) { - mlir::Type resultElementType = - getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); + if 
(isScalar(lhsType) && rhsShapedType) { + mlir::Type resultElementType = + getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); - returnTypes.push_back(mlir::RankedTensorType::get( - rhsShapedType.getShape(), resultElementType)); + returnTypes.push_back(mlir::RankedTensorType::get(rhsShapedType.getShape(), + resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType && isScalar(rhsType)) { - mlir::Type resultElementType = - getMostGenericScalarType(lhsShapedType.getElementType(), rhsType); + if (lhsShapedType && isScalar(rhsType)) { + mlir::Type resultElementType = + getMostGenericScalarType(lhsShapedType.getElementType(), rhsType); - returnTypes.push_back(mlir::RankedTensorType::get( - lhsShapedType.getShape(), resultElementType)); + returnTypes.push_back(mlir::RankedTensorType::get(lhsShapedType.getShape(), + resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType && rhsShapedType) { - if (lhsShapedType.getRank() != rhsShapedType.getRank()) { - return mlir::failure(); - } + if (lhsShapedType && rhsShapedType) { + if (lhsShapedType.getRank() != rhsShapedType.getRank()) { + return mlir::failure(); + } - int64_t rank = lhsShapedType.getRank(); - llvm::SmallVector shape; + int64_t rank = lhsShapedType.getRank(); + llvm::SmallVector shape; - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t lhsDimSize = lhsShapedType.getDimSize(dim); - int64_t rhsDimSize = rhsShapedType.getDimSize(dim); + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t lhsDimSize = lhsShapedType.getDimSize(dim); + int64_t rhsDimSize = rhsShapedType.getDimSize(dim); - if (lhsDimSize != mlir::ShapedType::kDynamic && - rhsDimSize != mlir::ShapedType::kDynamic && - lhsDimSize != rhsDimSize) { - return mlir::failure(); - } + if (lhsDimSize != mlir::ShapedType::kDynamic && + rhsDimSize != mlir::ShapedType::kDynamic && + lhsDimSize != rhsDimSize) { + return mlir::failure(); + } - if (lhsDimSize != mlir::ShapedType::kDynamic) { - shape.push_back(lhsDimSize); - } else { - shape.push_back(rhsDimSize); - } + if (lhsDimSize != mlir::ShapedType::kDynamic) { + shape.push_back(lhsDimSize); + } else { + shape.push_back(rhsDimSize); } + } - mlir::Type resultElementType = getMostGenericScalarType( - lhsShapedType.getElementType(), rhsShapedType.getElementType()); + mlir::Type resultElementType = getMostGenericScalarType( + lhsShapedType.getElementType(), rhsShapedType.getElementType()); - returnTypes.push_back(mlir::RankedTensorType::get( - shape, resultElementType)); + returnTypes.push_back( + mlir::RankedTensorType::get(shape, resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - return mlir::failure(); + return mlir::failure(); +} + +bool SubEWOp::isCompatibleReturnTypes(mlir::TypeRange lhs, + mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool SubEWOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } - - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } - - return true; } - mlir::OpFoldResult SubEWOp::fold(FoldAdaptor adaptor) - { - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); + return true; +} - if (!lhs || !rhs) { - return {}; - } +mlir::OpFoldResult SubEWOp::fold(FoldAdaptor adaptor) { + auto lhs = 
adaptor.getLhs(); + auto rhs = adaptor.getRhs(); - auto resultType = getResult().getType(); + if (!lhs || !rhs) { + return {}; + } - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr(resultType, lhsValue - rhsValue); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue - rhsValue); - } + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + return getAttr(resultType, lhsValue - rhsValue); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue - rhsValue); - } + if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue - rhsValue); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - return getAttr(resultType, lhsValue - rhsValue); - } + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue - rhsValue); } - return {}; + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + return getAttr(resultType, lhsValue - rhsValue); + } } - mlir::LogicalResult SubEWOp::distributeNegateOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); + return {}; +} + +mlir::LogicalResult +SubEWOp::distributeNegateOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if 
(mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult SubEWOp::distributeMulOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +SubEWOp::distributeMulOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + 
results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult SubEWOp::distributeDivOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +SubEWOp::distributeDivOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - bool lhsDistributed = false; - bool rhsDistributed = false; + bool lhsDistributed = false; + bool rhsDistributed = false; - if (lhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; - lhsDistributed = true; - } + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; + lhsDistributed = true; } } + } - if (rhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; - rhsDistributed = true; - } + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; + rhsDistributed = true; } } + } - if (!lhsDistributed) { - auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); - lhs = newLhsOp.getResult(); - } + if (!lhsDistributed) { + auto newLhsOp = builder.create(lhs.getLoc(), lhs, value); + lhs = newLhsOp.getResult(); + } - if (!rhsDistributed) { - auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); - rhs = newRhsOp.getResult(); - } + if (!rhsDistributed) { + auto newRhsOp = builder.create(rhs.getLoc(), rhs, value); + rhs = newRhsOp.getResult(); + } - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // MulOp -namespace mlir::bmodelica -{ - mlir::LogicalResult MulOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type lhsType = adaptor.getLhs().getType(); - mlir::Type rhsType = adaptor.getRhs().getType(); - - if (isScalar(lhsType) && isScalar(rhsType)) { - returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); - return mlir::success(); - } +namespace mlir::bmodelica { +mlir::LogicalResult MulOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + 
mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lhsType = adaptor.getLhs().getType(); + mlir::Type rhsType = adaptor.getRhs().getType(); + + if (isScalar(lhsType) && isScalar(rhsType)) { + returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); + return mlir::success(); + } - auto lhsShapedType = lhsType.dyn_cast(); - auto rhsShapedType = rhsType.dyn_cast(); + auto lhsShapedType = lhsType.dyn_cast(); + auto rhsShapedType = rhsType.dyn_cast(); - if (isScalar(lhsType) && rhsShapedType) { - mlir::Type resultElementType = - getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); + if (isScalar(lhsType) && rhsShapedType) { + mlir::Type resultElementType = + getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); - returnTypes.push_back(mlir::RankedTensorType::get( - rhsShapedType.getShape(), resultElementType)); + returnTypes.push_back(mlir::RankedTensorType::get(rhsShapedType.getShape(), + resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType && rhsShapedType) { - mlir::Type resultElementType = getMostGenericScalarType( - lhsShapedType.getElementType(), rhsShapedType.getElementType()); + if (lhsShapedType && rhsShapedType) { + mlir::Type resultElementType = getMostGenericScalarType( + lhsShapedType.getElementType(), rhsShapedType.getElementType()); - if (lhsShapedType.getRank() == 1 && rhsShapedType.getRank() == 1) { - returnTypes.push_back(resultElementType); - return mlir::success(); - } + if (lhsShapedType.getRank() == 1 && rhsShapedType.getRank() == 1) { + returnTypes.push_back(resultElementType); + return mlir::success(); + } - if (lhsShapedType.getRank() == 1 && rhsShapedType.getRank() == 2) { - returnTypes.push_back(mlir::RankedTensorType::get( - rhsShapedType.getShape()[1], resultElementType)); + if (lhsShapedType.getRank() == 1 && rhsShapedType.getRank() == 2) { + returnTypes.push_back(mlir::RankedTensorType::get( + rhsShapedType.getShape()[1], resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType.getRank() == 2 && rhsShapedType.getRank() == 1) { - returnTypes.push_back(mlir::RankedTensorType::get( - lhsShapedType.getShape()[0], resultElementType)); + if (lhsShapedType.getRank() == 2 && rhsShapedType.getRank() == 1) { + returnTypes.push_back(mlir::RankedTensorType::get( + lhsShapedType.getShape()[0], resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType.getRank() == 2 && rhsShapedType.getRank() == 2) { - llvm::SmallVector shape; - shape.push_back(lhsShapedType.getShape()[0]); - shape.push_back(rhsShapedType.getShape()[1]); + if (lhsShapedType.getRank() == 2 && rhsShapedType.getRank() == 2) { + llvm::SmallVector shape; + shape.push_back(lhsShapedType.getShape()[0]); + shape.push_back(rhsShapedType.getShape()[1]); - returnTypes.push_back(mlir::RankedTensorType::get( - shape, resultElementType)); + returnTypes.push_back( + mlir::RankedTensorType::get(shape, resultElementType)); - return mlir::success(); - } + return mlir::success(); } + } - auto lhsRangeType = lhsType.dyn_cast(); - auto rhsRangeType = rhsType.dyn_cast(); + auto lhsRangeType = lhsType.dyn_cast(); + auto rhsRangeType = rhsType.dyn_cast(); - if (isScalar(lhsType) && rhsRangeType) { - mlir::Type inductionType = - getMostGenericScalarType(lhsType, 
rhsRangeType.getInductionType());
+  if (isScalar(lhsType) && rhsRangeType) {
+    mlir::Type inductionType =
+        getMostGenericScalarType(lhsType, rhsRangeType.getInductionType());

-      returnTypes.push_back(RangeType::get(context, inductionType));
-      return mlir::success();
-    }
+    returnTypes.push_back(RangeType::get(context, inductionType));
+    return mlir::success();
+  }

-    if (lhsRangeType && isScalar(rhsType)) {
-      mlir::Type inductionType =
-          getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType);
+  if (lhsRangeType && isScalar(rhsType)) {
+    mlir::Type inductionType =
+        getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType);

-      returnTypes.push_back(RangeType::get(context, inductionType));
-      return mlir::success();
-    }
+    returnTypes.push_back(RangeType::get(context, inductionType));
+    return mlir::success();
+  }

-    return mlir::failure();
+  return mlir::failure();
+}
+
+bool MulOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
   }

-  bool MulOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
       return false;
     }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
   }

-  mlir::OpFoldResult MulOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
+  return true;
+}

-    auto resultType = getResult().getType();
+mlir::OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();

-    if (lhs && isScalar(lhs) && getScalarAttributeValue(lhs) == 0) {
-      if (!resultType.isa<RangeType>()) {
-        return getAttr(resultType, static_cast<int64_t>(0));
-      }
-    }
+  auto resultType = getResult().getType();

-    if (rhs && isScalar(rhs) && getScalarAttributeValue(rhs) == 0) {
-      if (!resultType.isa<RangeType>()) {
-        return getAttr(resultType, static_cast<int64_t>(0));
-      }
+  if (lhs && isScalar(lhs) && getScalarAttributeValue(lhs) == 0) {
+    if (!resultType.isa<RangeType>()) {
+      return getAttr(resultType, static_cast<int64_t>(0));
     }
+  }

-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        int64_t lhsValue = getScalarIntegerLikeValue(lhs);
-        int64_t rhsValue = getScalarIntegerLikeValue(rhs);
-        return getAttr(resultType, lhsValue * rhsValue);
-      }
+  if (rhs && isScalar(rhs) && getScalarAttributeValue(rhs) == 0) {
+    if (!resultType.isa<RangeType>()) {
+      return getAttr(resultType, static_cast<int64_t>(0));
+    }
+  }

-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, lhsValue * rhsValue);
-      }
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      int64_t lhsValue = getScalarIntegerLikeValue(lhs);
+      int64_t rhsValue = getScalarIntegerLikeValue(rhs);
+      return getAttr(resultType, lhsValue * rhsValue);
+    }

-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, lhsValue * rhsValue);
-      }
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, lhsValue * rhsValue);
+    }
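// Editorial aside (illustrative, not part of this change): the mixed-type
// cases below promote the Integer operand with static_cast<double> before
// folding, so e.g. 2 * 1.5 folds as 2.0 * 1.5 == 3.0 rather than being
// truncated by integer arithmetic.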
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, lhsValue * rhsValue);
-      }
+    if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, lhsValue * rhsValue);
     }

-    return {};
+    if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, lhsValue * rhsValue);
+    }
   }

-  mlir::LogicalResult MulOp::distribute(
-      llvm::SmallVectorImpl<mlir::Value>& results, mlir::OpBuilder& builder)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
+  return {};
+}

-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
+mlir::LogicalResult
+MulOp::distribute(llvm::SmallVectorImpl<mlir::Value> &results,
+                  mlir::OpBuilder &builder) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();

-    if (lhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
-        mlir::Value toDistribute = rhs;
-        results.clear();
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();

-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                results, builder, toDistribute))) {
-          return mlir::success();
-        }
+  if (lhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
+      mlir::Value toDistribute = rhs;
+      results.clear();
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(results, builder,
+                                                             toDistribute))) {
+        return mlir::success();
       }
     }
+  }

-    if (rhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
-        mlir::Value toDistribute = lhs;
-        results.clear();
+  if (rhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
+      mlir::Value toDistribute = lhs;
+      results.clear();

-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                results, builder, toDistribute))) {
-          return mlir::success();
-        }
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(results, builder,
+                                                             toDistribute))) {
+        return mlir::success();
       }
     }
-
-    // The operation can't be propagated because none of the children
-    // know how to distribute the multiplication to their children.
-    results.push_back(getResult());
-    return mlir::failure();
   }

-  mlir::LogicalResult MulOp::distributeNegateOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
+  // The operation can't be propagated because none of the children
+  // know how to distribute the multiplication to their children.
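// Editorial aside (illustrative, not part of this change): this is the base
// case of the distribution recursion. If, for example, both operands are
// plain block arguments, there is nothing to distribute into, so the
// original result is handed back and the rewrite reports mlir::failure().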
+ results.push_back(getResult()); + return mlir::failure(); +} + +mlir::LogicalResult +MulOp::distributeNegateOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - if (lhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - lhs = childResults[0]; + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + lhs = childResults[0]; - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); - return mlir::success(); - } + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); + return mlir::success(); } } + } - if (rhsOp) { - if (auto negDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto negDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(negDistributionInt.distributeNegateOp( - childResults, builder)) - && childResults.size() == 1) { - rhs = childResults[0]; + if (mlir::succeeded( + negDistributionInt.distributeNegateOp(childResults, builder)) && + childResults.size() == 1) { + rhs = childResults[0]; - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); - return mlir::success(); - } + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); + return mlir::success(); } } + } - auto lhsNewOp = builder.create(getLoc(), lhs); - lhs = lhsNewOp.getResult(); + auto lhsNewOp = builder.create(getLoc(), lhs); + lhs = lhsNewOp.getResult(); - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult MulOp::distributeMulOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +MulOp::distributeMulOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - if (lhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); - 
return mlir::success(); - } + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); + return mlir::success(); } } + } - if (rhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto mulDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(mulDistributionInt.distributeMulOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; + if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); - return mlir::success(); - } + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); + return mlir::success(); } } + } - auto lhsNewOp = builder.create(getLoc(), lhs, value); - lhs = lhsNewOp.getResult(); + auto lhsNewOp = builder.create(getLoc(), lhs, value); + lhs = lhsNewOp.getResult(); - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); - return mlir::success(); - } + return mlir::success(); +} - mlir::LogicalResult MulOp::distributeDivOp( - llvm::SmallVectorImpl& results, - mlir::OpBuilder& builder, - mlir::Value value) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); +mlir::LogicalResult +MulOp::distributeDivOp(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder, mlir::Value value) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + mlir::Operation *rhsOp = rhs.getDefiningOp(); - if (lhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(lhsOp)) { - llvm::SmallVector childResults; + if (lhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(lhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - lhs = childResults[0]; + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + lhs = childResults[0]; - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); - return mlir::success(); - } + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); + return mlir::success(); } } + } - if (rhsOp) { - if (auto divDistributionInt = - mlir::dyn_cast(rhsOp)) { - llvm::SmallVector childResults; + if (rhsOp) { + if (auto divDistributionInt = + mlir::dyn_cast(rhsOp)) { + llvm::SmallVector childResults; - if (mlir::succeeded(divDistributionInt.distributeDivOp( - childResults, builder, value)) - && childResults.size() == 1) { - rhs = childResults[0]; + if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults, + builder, value)) && + childResults.size() == 1) { + rhs = childResults[0]; - auto resultOp = builder.create(getLoc(), lhs, rhs); - results.push_back(resultOp.getResult()); - return mlir::success(); - } + auto resultOp = builder.create(getLoc(), lhs, rhs); + results.push_back(resultOp.getResult()); + return mlir::success(); } } + } - auto lhsNewOp = builder.create(getLoc(), lhs, value); - lhs = lhsNewOp.getResult(); 
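// Editorial aside (illustrative, not part of this change): in this fallback
// the division is applied to the left-hand operand only, which is sound
// because (lhs * rhs) / value == (lhs / value) * rhs.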
-
-    auto resultOp = builder.create<MulOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
+  auto lhsNewOp = builder.create<DivOp>(getLoc(), lhs, value);
+  lhs = lhsNewOp.getResult();

-    return mlir::success();
-  }
+  auto resultOp = builder.create<MulOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());

-  bool MulOp::isScalarProduct()
-  {
-    return !getLhs().getType().isa<mlir::TensorType>() &&
-        getRhs().getType().isa<mlir::TensorType>();
-  }
+  return mlir::success();
+}

-  bool MulOp::isCrossProduct()
-  {
-    auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
-    auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();
+bool MulOp::isScalarProduct() {
+  return !getLhs().getType().isa<mlir::TensorType>() &&
+         getRhs().getType().isa<mlir::TensorType>();
+}

-    return lhsTensorType && rhsTensorType &&
-        lhsTensorType.getRank() == 1 &&
-        rhsTensorType.getRank() == 1;
-  }
+bool MulOp::isCrossProduct() {
+  auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
+  auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();

-  bool MulOp::isVectorMatrixProduct()
-  {
-    auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
-    auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();
+  return lhsTensorType && rhsTensorType && lhsTensorType.getRank() == 1 &&
+         rhsTensorType.getRank() == 1;
+}

-    if (!lhsTensorType || !rhsTensorType) {
-      return false;
-    }
+bool MulOp::isVectorMatrixProduct() {
+  auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
+  auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();

-    return lhsTensorType.getRank() == 1 && rhsTensorType.getRank() == 2;
+  if (!lhsTensorType || !rhsTensorType) {
+    return false;
   }

-  bool MulOp::isMatrixVectorProduct()
-  {
-    auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
-    auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();
+  return lhsTensorType.getRank() == 1 && rhsTensorType.getRank() == 2;
+}

-    if (!lhsTensorType || !rhsTensorType) {
-      return false;
-    }
+bool MulOp::isMatrixVectorProduct() {
+  auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
+  auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();

-    return lhsTensorType.getRank() == 2 && rhsTensorType.getRank() == 1;
+  if (!lhsTensorType || !rhsTensorType) {
+    return false;
   }

-  bool MulOp::isMatrixProduct()
-  {
-    auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
-    auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();
+  return lhsTensorType.getRank() == 2 && rhsTensorType.getRank() == 1;
+}

-    if (!lhsTensorType || !rhsTensorType) {
-      return false;
-    }
+bool MulOp::isMatrixProduct() {
+  auto lhsTensorType = getLhs().getType().dyn_cast<mlir::TensorType>();
+  auto rhsTensorType = getRhs().getType().dyn_cast<mlir::TensorType>();

-    return lhsTensorType.getRank() == 2 && rhsTensorType.getRank() == 2;
+  if (!lhsTensorType || !rhsTensorType) {
+    return false;
   }
+
+  return lhsTensorType.getRank() == 2 && rhsTensorType.getRank() == 2;
 }
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // MulEWOp

-namespace mlir::bmodelica
-{
-  mlir::LogicalResult MulEWOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type lhsType = adaptor.getLhs().getType();
-    mlir::Type rhsType = adaptor.getRhs().getType();
-
-    if (isScalar(lhsType) && isScalar(rhsType)) {
-      returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType));
-      return mlir::success();
-    }
+namespace mlir::bmodelica {
+mlir::LogicalResult 
MulEWOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lhsType = adaptor.getLhs().getType(); + mlir::Type rhsType = adaptor.getRhs().getType(); + + if (isScalar(lhsType) && isScalar(rhsType)) { + returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType)); + return mlir::success(); + } - auto lhsShapedType = lhsType.dyn_cast(); - auto rhsShapedType = rhsType.dyn_cast(); + auto lhsShapedType = lhsType.dyn_cast(); + auto rhsShapedType = rhsType.dyn_cast(); - if (isScalar(lhsType) && rhsShapedType) { - mlir::Type resultElementType = - getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); + if (isScalar(lhsType) && rhsShapedType) { + mlir::Type resultElementType = + getMostGenericScalarType(lhsType, rhsShapedType.getElementType()); - returnTypes.push_back(mlir::RankedTensorType::get( - rhsShapedType.getShape(), resultElementType)); + returnTypes.push_back(mlir::RankedTensorType::get(rhsShapedType.getShape(), + resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType && isScalar(rhsType)) { - mlir::Type resultElementType = - getMostGenericScalarType(lhsShapedType.getElementType(), rhsType); + if (lhsShapedType && isScalar(rhsType)) { + mlir::Type resultElementType = + getMostGenericScalarType(lhsShapedType.getElementType(), rhsType); - returnTypes.push_back(mlir::RankedTensorType::get( - lhsShapedType.getShape(), resultElementType)); + returnTypes.push_back(mlir::RankedTensorType::get(lhsShapedType.getShape(), + resultElementType)); - return mlir::success(); - } + return mlir::success(); + } - if (lhsShapedType && rhsShapedType) { - if (lhsShapedType.getRank() != rhsShapedType.getRank()) { - return mlir::failure(); - } + if (lhsShapedType && rhsShapedType) { + if (lhsShapedType.getRank() != rhsShapedType.getRank()) { + return mlir::failure(); + } - int64_t rank = lhsShapedType.getRank(); - llvm::SmallVector shape; + int64_t rank = lhsShapedType.getRank(); + llvm::SmallVector shape; - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t lhsDimSize = lhsShapedType.getDimSize(dim); - int64_t rhsDimSize = rhsShapedType.getDimSize(dim); + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t lhsDimSize = lhsShapedType.getDimSize(dim); + int64_t rhsDimSize = rhsShapedType.getDimSize(dim); - if (lhsDimSize != mlir::ShapedType::kDynamic && - rhsDimSize != mlir::ShapedType::kDynamic && - lhsDimSize != rhsDimSize) { - return mlir::failure(); - } + if (lhsDimSize != mlir::ShapedType::kDynamic && + rhsDimSize != mlir::ShapedType::kDynamic && + lhsDimSize != rhsDimSize) { + return mlir::failure(); + } - if (lhsDimSize != mlir::ShapedType::kDynamic) { - shape.push_back(lhsDimSize); - } else { - shape.push_back(rhsDimSize); - } + if (lhsDimSize != mlir::ShapedType::kDynamic) { + shape.push_back(lhsDimSize); + } else { + shape.push_back(rhsDimSize); } + } - mlir::Type resultElementType = getMostGenericScalarType( - lhsShapedType.getElementType(), rhsShapedType.getElementType()); + mlir::Type resultElementType = getMostGenericScalarType( + lhsShapedType.getElementType(), rhsShapedType.getElementType()); - returnTypes.push_back(mlir::RankedTensorType::get( - shape, resultElementType)); + returnTypes.push_back( + mlir::RankedTensorType::get(shape, resultElementType)); - return 
mlir::success(); - } + return mlir::success(); + } - return mlir::failure(); + return mlir::failure(); +} + +bool MulEWOp::isCompatibleReturnTypes(mlir::TypeRange lhs, + mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool MulEWOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } - - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } - - return true; } - mlir::OpFoldResult MulEWOp::fold(FoldAdaptor adaptor) - { - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); + return true; +} - if (!lhs || !rhs) { - return {}; - } +mlir::OpFoldResult MulEWOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); - auto resultType = getResult().getType(); + if (!lhs || !rhs) { + return {}; + } - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr(resultType, lhsValue * rhsValue); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue * rhsValue); - } + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + return getAttr(resultType, lhsValue * rhsValue); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, lhsValue * rhsValue); - } + if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue * rhsValue); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - return getAttr(resultType, lhsValue * rhsValue); - } + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, lhsValue * rhsValue); } - return {}; + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + return getAttr(resultType, lhsValue * rhsValue); + } } - mlir::LogicalResult MulEWOp::distribute( - llvm::SmallVectorImpl& results, mlir::OpBuilder& builder) - { - mlir::Value lhs = getLhs(); - mlir::Value rhs = getRhs(); + return {}; +} - mlir::Operation* lhsOp = lhs.getDefiningOp(); - mlir::Operation* rhsOp = rhs.getDefiningOp(); +mlir::LogicalResult +MulEWOp::distribute(llvm::SmallVectorImpl &results, + mlir::OpBuilder &builder) { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); - if (lhsOp) { - if (auto mulDistributionInt = - mlir::dyn_cast(lhsOp)) { - mlir::Value toDistribute = rhs; - results.clear(); + mlir::Operation *lhsOp = lhs.getDefiningOp(); + 
mlir::Operation *rhsOp = rhs.getDefiningOp();
+
-      if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-              results, builder, toDistribute))) {
-        return mlir::success();
-      }
-    }
-  }
+  if (lhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
+      mlir::Value toDistribute = rhs;
+      results.clear();
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(results, builder,
+                                                             toDistribute))) {
+        return mlir::success();
+      }
+    }
+  }
 
-    if (rhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
-        mlir::Value toDistribute = lhs;
-        results.clear();
-
-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                results, builder, toDistribute))) {
-          return mlir::success();
-        }
-      }
-    }
-
-    // The operation can't be propagated because none of the children
-    // know how to distribute the multiplication to their children.
-    results.push_back(getResult());
-    return mlir::failure();
-  }
+  if (rhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
+      mlir::Value toDistribute = lhs;
+      results.clear();
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(results, builder,
+                                                             toDistribute))) {
+        return mlir::success();
+      }
+    }
+  }
+
+  // The operation can't be propagated because none of the children
+  // know how to distribute the multiplication to their children.
+  results.push_back(getResult());
+  return mlir::failure();
+}
 
-  mlir::LogicalResult MulEWOp::distributeNegateOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto negDistributionInt =
-              mlir::dyn_cast<NegateOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(negDistributionInt.distributeNegateOp(
-                childResults, builder))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto negDistributionInt =
-              mlir::dyn_cast<NegateOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(negDistributionInt.distributeNegateOp(
-                childResults, builder))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<NegateOp>(getLoc(), lhs);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
+mlir::LogicalResult
+MulEWOp::distributeNegateOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                            mlir::OpBuilder &builder) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto negDistributionInt =
+            mlir::dyn_cast<NegateOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(
+              negDistributionInt.distributeNegateOp(childResults, builder)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto negDistributionInt =
+            mlir::dyn_cast<NegateOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(
+              negDistributionInt.distributeNegateOp(childResults, builder)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<NegateOp>(getLoc(), lhs);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
 
-  mlir::LogicalResult MulEWOp::distributeMulOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder,
-      mlir::Value value)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<MulEWOp>(getLoc(), lhs, value);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
+mlir::LogicalResult
+MulEWOp::distributeMulOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                         mlir::OpBuilder &builder, mlir::Value value) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<MulEWOp>(getLoc(), lhs, value);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
 
-  mlir::LogicalResult MulEWOp::distributeDivOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder,
-      mlir::Value value)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto divDistributionInt =
-              mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto divDistributionInt =
-              mlir::dyn_cast<DivOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<DivEWOp>(getLoc(), lhs, value);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
-}
+mlir::LogicalResult
+MulEWOp::distributeDivOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                         mlir::OpBuilder &builder, mlir::Value value) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<DivEWOp>(getLoc(), lhs, value);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<MulEWOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
+} // namespace mlir::bmodelica
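The three distribute* hooks above all rewrite one factor of the product using the same algebra: v * (a * b) == (v * a) * b, (a * b) / v == (a / v) * b, and -(a * b) == (-a) * b. A minimal stand-alone sketch of that rewrite on a toy expression tree (plain C++; the Expr type and helpers are illustrative assumptions, not MARCO's API):

#include <cassert>
#include <memory>

struct Expr {
  char op;                        // '*', '/', '-', or 'k' for a constant
  double value = 0.0;             // used when op == 'k'
  std::shared_ptr<Expr> lhs, rhs; // children for the binary operations

  double eval() const {
    switch (op) {
    case 'k': return value;
    case '*': return lhs->eval() * rhs->eval();
    case '/': return lhs->eval() / rhs->eval();
    case '-': return -lhs->eval();
    default: assert(false); return 0.0;
    }
  }
};

static std::shared_ptr<Expr> constant(double v) {
  return std::make_shared<Expr>(Expr{'k', v, nullptr, nullptr});
}

// Mirrors the spirit of distributeMulOp: scale only the left factor by v,
// leaving the value of the whole product unchanged.
static std::shared_ptr<Expr> distributeMul(std::shared_ptr<Expr> mul,
                                           double v) {
  assert(mul->op == '*');
  auto scaledLhs =
      std::make_shared<Expr>(Expr{'*', 0.0, constant(v), mul->lhs});
  return std::make_shared<Expr>(Expr{'*', 0.0, scaledLhs, mul->rhs});
}

int main() {
  auto product =
      std::make_shared<Expr>(Expr{'*', 0.0, constant(3), constant(4)});
  assert(distributeMul(product, 2)->eval() == 2 * (3 * 4));
}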
 
 //===---------------------------------------------------------------------===//
 // DivOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult DivOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type lhsType = adaptor.getLhs().getType();
-    mlir::Type rhsType = adaptor.getRhs().getType();
-
-    if (isScalar(lhsType) && isScalar(rhsType)) {
-      returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType));
-      return mlir::success();
-    }
-
-    auto lhsShapedType = lhsType.dyn_cast<mlir::ShapedType>();
-
-    if (lhsShapedType && isScalar(rhsType)) {
-      mlir::Type resultElementType =
-          getMostGenericScalarType(lhsShapedType.getElementType(), rhsType);
-
-      returnTypes.push_back(mlir::RankedTensorType::get(
-          lhsShapedType.getShape(), resultElementType));
-
-      return mlir::success();
-    }
-
-    auto lhsRangeType = lhsType.dyn_cast<RangeType>();
-    auto rhsRangeType = rhsType.dyn_cast<RangeType>();
-
-    if (isScalar(lhsType) && rhsRangeType) {
-      mlir::Type inductionType =
-          getMostGenericScalarType(lhsType, rhsRangeType.getInductionType());
-
-      returnTypes.push_back(RangeType::get(context, inductionType));
-      return mlir::success();
-    }
-
-    if (lhsRangeType && isScalar(rhsType)) {
-      mlir::Type inductionType =
-          getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType);
-
-      returnTypes.push_back(RangeType::get(context, inductionType));
-      return mlir::success();
-    }
-
-    return mlir::failure();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult DivOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type lhsType = adaptor.getLhs().getType();
+  mlir::Type rhsType = adaptor.getRhs().getType();
+
+  if (isScalar(lhsType) && isScalar(rhsType)) {
+    returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType));
+    return mlir::success();
+  }
+
+  auto lhsShapedType = lhsType.dyn_cast<mlir::ShapedType>();
+
+  if (lhsShapedType && isScalar(rhsType)) {
+    mlir::Type resultElementType =
+        getMostGenericScalarType(lhsShapedType.getElementType(), rhsType);
+
+    returnTypes.push_back(mlir::RankedTensorType::get(lhsShapedType.getShape(),
+                                                      resultElementType));
+
+    return mlir::success();
+  }
+
+  auto lhsRangeType = lhsType.dyn_cast<RangeType>();
+  auto rhsRangeType = rhsType.dyn_cast<RangeType>();
+
+  if (isScalar(lhsType) && rhsRangeType) {
+    mlir::Type inductionType =
+        getMostGenericScalarType(lhsType, rhsRangeType.getInductionType());
+
+    returnTypes.push_back(RangeType::get(context, inductionType));
+    return mlir::success();
+  }
+
+  if (lhsRangeType && isScalar(rhsType)) {
+    mlir::Type inductionType =
+        getMostGenericScalarType(lhsRangeType.getInductionType(), rhsType);
+
+    returnTypes.push_back(RangeType::get(context, inductionType));
+    return mlir::success();
+  }
+
+  return mlir::failure();
+}
 
-  bool DivOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool DivOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult DivOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-
-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-    }
-
-    return {};
-  }
+mlir::OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+
+  if (!lhs || !rhs) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+
+    if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+  }
+
+  return {};
+}
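DivOp::fold above enumerates the four integer/float operand combinations and promotes to double whenever either side is float-like; notably the integer/integer case is also divided in double, matching Modelica's `/` producing a Real. A tiny stand-alone model of that promotion logic (the Scalar alias is an illustrative assumption, not the dialect's attribute types):

#include <cassert>
#include <cstdint>
#include <variant>

using Scalar = std::variant<std::int64_t, double>;

// Fold a constant division the way DivOp::fold does: every operand is
// promoted to double before dividing, so the result is always a Real.
static double foldDiv(const Scalar &lhs, const Scalar &rhs) {
  auto asDouble = [](const Scalar &s) {
    return std::holds_alternative<std::int64_t>(s)
               ? static_cast<double>(std::get<std::int64_t>(s))
               : std::get<double>(s);
  };
  return asDouble(lhs) / asDouble(rhs);
}

int main() {
  assert(foldDiv(Scalar{std::int64_t{3}}, Scalar{std::int64_t{2}}) == 1.5);
  assert(foldDiv(Scalar{3.0}, Scalar{std::int64_t{2}}) == 1.5);
}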
 
-  mlir::LogicalResult DivOp::distribute(
-      llvm::SmallVectorImpl<mlir::Value>& results, mlir::OpBuilder& builder)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto divDistributionInt =
-              mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
-        mlir::Value toDistribute = getRhs();
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                results, builder, toDistribute))) {
-          return mlir::success();
-        }
-      }
-    }
-
-    // The operation can't be propagated because the dividend does not know
-    // how to distribute the division to their children.
-    results.push_back(getResult());
-    return mlir::success();
-  }
+mlir::LogicalResult
+DivOp::distribute(llvm::SmallVectorImpl<mlir::Value> &results,
+                  mlir::OpBuilder &builder) {
+  mlir::Value lhs = getLhs();
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
+      mlir::Value toDistribute = getRhs();
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(results, builder,
+                                                             toDistribute))) {
+        return mlir::success();
+      }
+    }
+  }
+
+  // The operation can't be propagated because the dividend does not know
+  // how to distribute the division to its children.
+  results.push_back(getResult());
+  return mlir::success();
+}
 
-  mlir::LogicalResult DivOp::distributeNegateOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto negDistributionInt =
-              mlir::dyn_cast<NegateOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(negDistributionInt.distributeNegateOp(
-                childResults, builder))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto negDistributionInt =
-              mlir::dyn_cast<NegateOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(negDistributionInt.distributeNegateOp(
-                childResults, builder))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<NegateOp>(getLoc(), lhs);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
+mlir::LogicalResult
+DivOp::distributeNegateOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                          mlir::OpBuilder &builder) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto negDistributionInt =
+            mlir::dyn_cast<NegateOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(
+              negDistributionInt.distributeNegateOp(childResults, builder)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto negDistributionInt =
+            mlir::dyn_cast<NegateOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(
+              negDistributionInt.distributeNegateOp(childResults, builder)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<NegateOp>(getLoc(), lhs);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
 
-  mlir::LogicalResult DivOp::distributeMulOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder,
-      mlir::Value value)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto divDistributionInt =
-              mlir::dyn_cast<DivOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<MulOp>(getLoc(), lhs, value);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
+mlir::LogicalResult
+DivOp::distributeMulOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                       mlir::OpBuilder &builder, mlir::Value value) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<MulOp>(getLoc(), lhs, value);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
 
-  mlir::LogicalResult DivOp::distributeDivOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder,
-      mlir::Value value)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto divDistributionInt = mlir::dyn_cast<DivOpDistributionInterface>(
-              lhs.getDefiningOp())) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto mulDistributionInt = mlir::dyn_cast<MulOpDistributionInterface>(
-              rhs.getDefiningOp())) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<DivOp>(getLoc(), lhs, value);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
-}
+mlir::LogicalResult
+DivOp::distributeDivOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                       mlir::OpBuilder &builder, mlir::Value value) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(lhs.getDefiningOp())) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(rhs.getDefiningOp())) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<DivOp>(getLoc(), lhs, value);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<DivOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
+} // namespace mlir::bmodelica
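DivOp's hooks lean on the corresponding identities for quotients: scaling a quotient multiplies the numerator or, equivalently, divides the denominator, (a/b)*v == (a*v)/b == a/(b/v); dividing it divides the numerator or multiplies the denominator, (a/b)/v == (a/v)/b == a/(b*v). That is why the rhs branches above call distributeDivOp and distributeMulOp with the roles swapped. A quick numeric sanity check of the identities:

#include <cassert>
#include <cmath>

int main() {
  const double a = 9.0, b = 4.0, v = 2.0;
  // Multiplying a quotient: scale the numerator, or divide the denominator.
  assert(std::fabs((a / b) * v - (a * v) / b) < 1e-12);
  assert(std::fabs((a / b) * v - a / (b / v)) < 1e-12);
  // Dividing a quotient: divide the numerator, or scale the denominator.
  assert(std::fabs((a / b) / v - (a / v) / b) < 1e-12);
  assert(std::fabs((a / b) / v - a / (b * v)) < 1e-12);
}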
 
 //===---------------------------------------------------------------------===//
 // DivEWOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult DivEWOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type lhsType = adaptor.getLhs().getType();
-    mlir::Type rhsType = adaptor.getRhs().getType();
-
-    if (isScalar(lhsType) && isScalar(rhsType)) {
-      returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType));
-      return mlir::success();
-    }
-
-    auto lhsShapedType = lhsType.dyn_cast<mlir::ShapedType>();
-    auto rhsShapedType = rhsType.dyn_cast<mlir::ShapedType>();
-
-    if (isScalar(lhsType) && rhsShapedType) {
-      mlir::Type resultElementType =
-          getMostGenericScalarType(lhsType, rhsShapedType.getElementType());
-
-      returnTypes.push_back(mlir::RankedTensorType::get(
-          rhsShapedType.getShape(), resultElementType));
-
-      return mlir::success();
-    }
-
-    if (lhsShapedType && isScalar(rhsType)) {
-      mlir::Type resultElementType =
-          getMostGenericScalarType(lhsShapedType.getElementType(), rhsType);
-
-      returnTypes.push_back(mlir::RankedTensorType::get(
-          lhsShapedType.getShape(), resultElementType));
-
-      return mlir::success();
-    }
-
-    if (lhsShapedType && rhsShapedType) {
-      if (lhsShapedType.getRank() != rhsShapedType.getRank()) {
-        return mlir::failure();
-      }
-
-      int64_t rank = lhsShapedType.getRank();
-      llvm::SmallVector<int64_t> shape;
-
-      for (int64_t dim = 0; dim < rank; ++dim) {
-        int64_t lhsDimSize = lhsShapedType.getDimSize(dim);
-        int64_t rhsDimSize = rhsShapedType.getDimSize(dim);
-
-        if (lhsDimSize != mlir::ShapedType::kDynamic &&
-            rhsDimSize != mlir::ShapedType::kDynamic &&
-            lhsDimSize != rhsDimSize) {
-          return mlir::failure();
-        }
-
-        if (lhsDimSize != mlir::ShapedType::kDynamic) {
-          shape.push_back(lhsDimSize);
-        } else {
-          shape.push_back(rhsDimSize);
-        }
-      }
-
-      mlir::Type resultElementType = getMostGenericScalarType(
-          lhsShapedType.getElementType(), rhsShapedType.getElementType());
-
-      returnTypes.push_back(mlir::RankedTensorType::get(
-          shape, resultElementType));
-
-      return mlir::success();
-    }
-
-    return mlir::failure();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult DivEWOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type lhsType = adaptor.getLhs().getType();
+  mlir::Type rhsType = adaptor.getRhs().getType();
+
+  if (isScalar(lhsType) && isScalar(rhsType)) {
+    returnTypes.push_back(getMostGenericScalarType(lhsType, rhsType));
+    return mlir::success();
+  }
+
+  auto lhsShapedType = lhsType.dyn_cast<mlir::ShapedType>();
+  auto rhsShapedType = rhsType.dyn_cast<mlir::ShapedType>();
+
+  if (isScalar(lhsType) && rhsShapedType) {
+    mlir::Type resultElementType =
+        getMostGenericScalarType(lhsType, rhsShapedType.getElementType());
+
+    returnTypes.push_back(mlir::RankedTensorType::get(rhsShapedType.getShape(),
+                                                      resultElementType));
+
+    return mlir::success();
+  }
+
+  if (lhsShapedType && isScalar(rhsType)) {
+    mlir::Type resultElementType =
+        getMostGenericScalarType(lhsShapedType.getElementType(), rhsType);
+
+    returnTypes.push_back(mlir::RankedTensorType::get(lhsShapedType.getShape(),
+                                                      resultElementType));
+
+    return mlir::success();
+  }
+
+  if (lhsShapedType && rhsShapedType) {
+    if (lhsShapedType.getRank() != rhsShapedType.getRank()) {
+      return mlir::failure();
+    }
+
+    int64_t rank = lhsShapedType.getRank();
+    llvm::SmallVector<int64_t> shape;
+
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      int64_t lhsDimSize = lhsShapedType.getDimSize(dim);
+      int64_t rhsDimSize = rhsShapedType.getDimSize(dim);
+
+      if (lhsDimSize != mlir::ShapedType::kDynamic &&
+          rhsDimSize != mlir::ShapedType::kDynamic &&
+          lhsDimSize != rhsDimSize) {
+        return mlir::failure();
+      }
+
+      if (lhsDimSize != mlir::ShapedType::kDynamic) {
+        shape.push_back(lhsDimSize);
+      } else {
+        shape.push_back(rhsDimSize);
+      }
+    }
+
+    mlir::Type resultElementType = getMostGenericScalarType(
+        lhsShapedType.getElementType(), rhsShapedType.getElementType());
+
+    returnTypes.push_back(
+        mlir::RankedTensorType::get(shape, resultElementType));
+
+    return mlir::success();
+  }
+
+  return mlir::failure();
+}
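The element-wise variants infer the result shape dimension by dimension: two static extents must match, and a static extent wins over a dynamic one. A stand-alone sketch of that merge, using -1 in place of mlir::ShapedType::kDynamic (the function name is an illustrative assumption):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

constexpr std::int64_t kDynamic = -1;

// Merge two shapes of equal rank; returns std::nullopt when the operands are
// incompatible, mirroring the mlir::failure() paths above.
static std::optional<std::vector<std::int64_t>>
mergeShapes(const std::vector<std::int64_t> &lhs,
            const std::vector<std::int64_t> &rhs) {
  if (lhs.size() != rhs.size())
    return std::nullopt; // rank mismatch

  std::vector<std::int64_t> shape;
  for (std::size_t dim = 0; dim < lhs.size(); ++dim) {
    if (lhs[dim] != kDynamic && rhs[dim] != kDynamic && lhs[dim] != rhs[dim])
      return std::nullopt; // conflicting static extents
    shape.push_back(lhs[dim] != kDynamic ? lhs[dim] : rhs[dim]);
  }
  return shape;
}

int main() {
  assert((mergeShapes({2, kDynamic}, {kDynamic, 3}) ==
          std::vector<std::int64_t>{2, 3}));
  assert(!mergeShapes({2, 3}, {2, 4})); // static conflict is rejected
}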
 
-  bool DivEWOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool DivEWOp::isCompatibleReturnTypes(mlir::TypeRange lhs,
+                                      mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult DivEWOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-
-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, lhsValue / rhsValue);
-      }
-    }
-
-    return {};
-  }
+mlir::OpFoldResult DivEWOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+
+  if (!lhs || !rhs) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+
+    if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, lhsValue / rhsValue);
+    }
+  }
+
+  return {};
+}
 
-  mlir::LogicalResult DivEWOp::distribute(
-      llvm::SmallVectorImpl<mlir::Value>& results, mlir::OpBuilder& builder)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto divDistributionInt =
-              mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
-        mlir::Value toDistribute = getRhs();
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                results, builder, toDistribute))) {
-          return mlir::success();
-        }
-      }
-    }
-
-    // The operation can't be propagated because the dividend does not know
-    // how to distribute the division to their children.
-    results.push_back(getResult());
-    return mlir::failure();
-  }
+mlir::LogicalResult
+DivEWOp::distribute(llvm::SmallVectorImpl<mlir::Value> &results,
+                    mlir::OpBuilder &builder) {
+  mlir::Value lhs = getLhs();
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
+      mlir::Value toDistribute = getRhs();
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(results, builder,
+                                                             toDistribute))) {
+        return mlir::success();
+      }
+    }
+  }
+
+  // The operation can't be propagated because the dividend does not know
+  // how to distribute the division to its children.
+  results.push_back(getResult());
+  return mlir::failure();
+}
 
-  mlir::LogicalResult DivEWOp::distributeNegateOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto negDistributionInt =
-              mlir::dyn_cast<NegateOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(negDistributionInt.distributeNegateOp(
-                childResults, builder))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto negDistributionInt =
-              mlir::dyn_cast<NegateOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(negDistributionInt.distributeNegateOp(
-                childResults, builder))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<NegateOp>(getLoc(), lhs);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
+mlir::LogicalResult
+DivEWOp::distributeNegateOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                            mlir::OpBuilder &builder) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto negDistributionInt =
+            mlir::dyn_cast<NegateOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(
+              negDistributionInt.distributeNegateOp(childResults, builder)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto negDistributionInt =
+            mlir::dyn_cast<NegateOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(
+              negDistributionInt.distributeNegateOp(childResults, builder)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<NegateOp>(getLoc(), lhs);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
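When neither operand of the quotient can absorb the negation, distributeNegateOp falls back to negating the numerator explicitly, relying on -(a/b) == (-a)/b; the Mul variants use the analogous -(a*b) == (-a)*b. A one-line numeric check of the identity:

#include <cassert>

int main() {
  const double a = 7.0, b = 2.0;
  // Negating a quotient equals negating its numerator.
  assert(-(a / b) == (-a) / b);
}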
 
-  mlir::LogicalResult DivEWOp::distributeMulOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder,
-      mlir::Value value)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto divDistributionInt =
-              mlir::dyn_cast<DivOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<MulEWOp>(getLoc(), lhs, value);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
+mlir::LogicalResult
+DivEWOp::distributeMulOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                         mlir::OpBuilder &builder, mlir::Value value) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<MulEWOp>(getLoc(), lhs, value);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
 
-  mlir::LogicalResult DivEWOp::distributeDivOp(
-      llvm::SmallVectorImpl<mlir::Value>& results,
-      mlir::OpBuilder& builder,
-      mlir::Value value)
-  {
-    mlir::Value lhs = getLhs();
-    mlir::Value rhs = getRhs();
-
-    mlir::Operation* lhsOp = lhs.getDefiningOp();
-    mlir::Operation* rhsOp = rhs.getDefiningOp();
-
-    if (lhsOp) {
-      if (auto divDistributionInt =
-              mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(divDistributionInt.distributeDivOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          lhs = childResults[0];
-
-          auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    if (rhsOp) {
-      if (auto mulDistributionInt =
-              mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
-        llvm::SmallVector<mlir::Value, 1> childResults;
-
-        if (mlir::succeeded(mulDistributionInt.distributeMulOp(
-                childResults, builder, value))
-            && childResults.size() == 1) {
-          rhs = childResults[0];
-
-          auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-          results.push_back(resultOp.getResult());
-          return mlir::success();
-        }
-      }
-    }
-
-    auto lhsNewOp = builder.create<DivEWOp>(getLoc(), lhs, value);
-    lhs = lhsNewOp.getResult();
-
-    auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
-    results.push_back(resultOp.getResult());
-
-    return mlir::success();
-  }
-}
+mlir::LogicalResult
+DivEWOp::distributeDivOp(llvm::SmallVectorImpl<mlir::Value> &results,
+                         mlir::OpBuilder &builder, mlir::Value value) {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+
+  mlir::Operation *lhsOp = lhs.getDefiningOp();
+  mlir::Operation *rhsOp = rhs.getDefiningOp();
+
+  if (lhsOp) {
+    if (auto divDistributionInt =
+            mlir::dyn_cast<DivOpDistributionInterface>(lhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(divDistributionInt.distributeDivOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        lhs = childResults[0];
+
+        auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  if (rhsOp) {
+    if (auto mulDistributionInt =
+            mlir::dyn_cast<MulOpDistributionInterface>(rhsOp)) {
+      llvm::SmallVector<mlir::Value, 1> childResults;
+
+      if (mlir::succeeded(mulDistributionInt.distributeMulOp(childResults,
+                                                             builder, value)) &&
+          childResults.size() == 1) {
+        rhs = childResults[0];
+
+        auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+        results.push_back(resultOp.getResult());
+        return mlir::success();
+      }
+    }
+  }
+
+  auto lhsNewOp = builder.create<DivEWOp>(getLoc(), lhs, value);
+  lhs = lhsNewOp.getResult();
+
+  auto resultOp = builder.create<DivEWOp>(getLoc(), lhs, rhs);
+  results.push_back(resultOp.getResult());
+
+  return mlir::success();
+}
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // PowOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult PowOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type baseType = adaptor.getBase().getType();
-    mlir::Type exponentType = adaptor.getExponent().getType();
-
-    if (isScalar(baseType)) {
-      if (exponentType.isa<RealType>()) {
-        returnTypes.push_back(exponentType);
-        return mlir::success();
-      }
-
-      returnTypes.push_back(baseType);
-      return mlir::success();
-    }
-
-    if (auto baseShapedType = baseType.dyn_cast<mlir::ShapedType>()) {
-      if (!isScalar(exponentType)) {
-        return mlir::failure();
-      }
-
-      returnTypes.push_back(baseShapedType);
-      return mlir::success();
-    }
-
-    return mlir::failure();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult PowOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type baseType = adaptor.getBase().getType();
+  mlir::Type exponentType = adaptor.getExponent().getType();
+
+  if (isScalar(baseType)) {
+    if (exponentType.isa<RealType>()) {
+      returnTypes.push_back(exponentType);
+      return mlir::success();
+    }
+
+    returnTypes.push_back(baseType);
+    return mlir::success();
+  }
+
+  if (auto baseShapedType = baseType.dyn_cast<mlir::ShapedType>()) {
+    if (!isScalar(exponentType)) {
+      return mlir::failure();
+    }
+
+    returnTypes.push_back(baseShapedType);
+    return mlir::success();
+  }
+
+  return mlir::failure();
+}
 
-  bool PowOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool PowOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult PowOp::fold(FoldAdaptor adaptor)
-  {
-    auto base = adaptor.getBase();
-    auto exponent = adaptor.getExponent();
-
-    if (!base || !exponent) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(base) && isScalar(exponent)) {
-      if (isScalarIntegerLike(base) && isScalarIntegerLike(exponent)) {
-        auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
-
-        auto exponentValue =
-            static_cast<double>(getScalarIntegerLikeValue(exponent));
-
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-
-      if (isScalarFloatLike(base) && isScalarFloatLike(exponent)) {
-        double baseValue = getScalarFloatLikeValue(base);
-        double exponentValue = getScalarFloatLikeValue(exponent);
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-
-      if (isScalarIntegerLike(base) && isScalarFloatLike(exponent)) {
-        auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
-        double exponentValue = getScalarFloatLikeValue(exponent);
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-
-      if (isScalarFloatLike(base) && isScalarIntegerLike(exponent)) {
-        double baseValue = getScalarFloatLikeValue(base);
-
-        auto exponentValue =
-            static_cast<double>(getScalarIntegerLikeValue(exponent));
-
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-    }
-
-    return {};
-  }
-}
+mlir::OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
+  auto base = adaptor.getBase();
+  auto exponent = adaptor.getExponent();
+
+  if (!base || !exponent) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(base) && isScalar(exponent)) {
+    if (isScalarIntegerLike(base) && isScalarIntegerLike(exponent)) {
+      auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
+
+      auto exponentValue =
+          static_cast<double>(getScalarIntegerLikeValue(exponent));
+
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+
+    if (isScalarFloatLike(base) && isScalarFloatLike(exponent)) {
+      double baseValue = getScalarFloatLikeValue(base);
+      double exponentValue = getScalarFloatLikeValue(exponent);
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+
+    if (isScalarIntegerLike(base) && isScalarFloatLike(exponent)) {
+      auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
+      double exponentValue = getScalarFloatLikeValue(exponent);
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+
+    if (isScalarFloatLike(base) && isScalarIntegerLike(exponent)) {
+      double baseValue = getScalarFloatLikeValue(base);
+
+      auto exponentValue =
+          static_cast<double>(getScalarIntegerLikeValue(exponent));
+
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
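PowOp::fold promotes both operands to double and folds through std::pow, so even Integer ^ Integer folds to a floating-point value. A stand-alone model of that promotion (the Scalar alias mirrors the sketch after DivOp::fold and is an illustrative assumption):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <variant>

using Scalar = std::variant<std::int64_t, double>;

// Fold a constant exponentiation with the same promotion rule as PowOp::fold.
static double foldPow(const Scalar &base, const Scalar &exponent) {
  auto asDouble = [](const Scalar &s) {
    return std::holds_alternative<std::int64_t>(s)
               ? static_cast<double>(std::get<std::int64_t>(s))
               : std::get<double>(s);
  };
  return std::pow(asDouble(base), asDouble(exponent));
}

int main() {
  assert(foldPow(Scalar{std::int64_t{2}}, Scalar{std::int64_t{10}}) == 1024.0);
  assert(std::fabs(foldPow(Scalar{2.0}, Scalar{0.5}) - std::sqrt(2.0)) <
         1e-12);
}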
 
 //===---------------------------------------------------------------------===//
 // PowEWOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult PowEWOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type baseType = adaptor.getBase().getType();
-    mlir::Type exponentType = adaptor.getExponent().getType();
-
-    auto inferResultType =
-        [](mlir::Type baseType, mlir::Type exponentType) -> mlir::Type {
-      if (exponentType.isa<RealType>()) {
-        return exponentType;
-      }
-
-      return baseType;
-    };
-
-    if (isScalar(baseType) && isScalar(exponentType)) {
-      returnTypes.push_back(inferResultType(baseType, exponentType));
-      return mlir::success();
-    }
-
-    auto baseShapedType = baseType.dyn_cast<mlir::ShapedType>();
-    auto exponentShapedType = exponentType.dyn_cast<mlir::ShapedType>();
-
-    if (isScalar(baseType) && exponentShapedType) {
-      returnTypes.push_back(mlir::RankedTensorType::get(
-          exponentShapedType.getShape(),
-          inferResultType(baseType, exponentShapedType.getElementType())));
-
-      return mlir::success();
-    }
-
-    if (baseShapedType && isScalar(exponentType)) {
-      returnTypes.push_back(mlir::RankedTensorType::get(
-          baseShapedType.getShape(),
-          inferResultType(baseShapedType.getElementType(), exponentType)));
-
-      return mlir::success();
-    }
-
-    if (baseShapedType && exponentShapedType) {
-      if (baseShapedType.getRank() != exponentShapedType.getRank()) {
-        return mlir::failure();
-      }
-
-      int64_t rank = baseShapedType.getRank();
-      llvm::SmallVector<int64_t> shape;
-
-      for (int64_t dim = 0; dim < rank; ++dim) {
-        int64_t lhsDimSize = baseShapedType.getDimSize(dim);
-        int64_t rhsDimSize = exponentShapedType.getDimSize(dim);
-
-        if (lhsDimSize != mlir::ShapedType::kDynamic &&
-            rhsDimSize != mlir::ShapedType::kDynamic &&
-            lhsDimSize != rhsDimSize) {
-          return mlir::failure();
-        }
-
-        if (lhsDimSize != mlir::ShapedType::kDynamic) {
-          shape.push_back(lhsDimSize);
-        } else {
-          shape.push_back(rhsDimSize);
-        }
-      }
-
-      mlir::Type resultElementType = inferResultType(
-          baseShapedType.getElementType(),
-          exponentShapedType.getElementType());
-
-      returnTypes.push_back(mlir::RankedTensorType::get(
-          shape, resultElementType));
-
-      return mlir::success();
-    }
-
-    return mlir::failure();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult PowEWOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type baseType = adaptor.getBase().getType();
+  mlir::Type exponentType = adaptor.getExponent().getType();
+
+  auto inferResultType = [](mlir::Type baseType,
+                            mlir::Type exponentType) -> mlir::Type {
+    if (exponentType.isa<RealType>()) {
+      return exponentType;
+    }
+
+    return baseType;
+  };
+
+  if (isScalar(baseType) && isScalar(exponentType)) {
+    returnTypes.push_back(inferResultType(baseType, exponentType));
+    return mlir::success();
+  }
+
+  auto baseShapedType = baseType.dyn_cast<mlir::ShapedType>();
+  auto exponentShapedType = exponentType.dyn_cast<mlir::ShapedType>();
+
+  if (isScalar(baseType) && exponentShapedType) {
+    returnTypes.push_back(mlir::RankedTensorType::get(
+        exponentShapedType.getShape(),
+        inferResultType(baseType, exponentShapedType.getElementType())));
+
+    return mlir::success();
+  }
+
+  if (baseShapedType && isScalar(exponentType)) {
+    returnTypes.push_back(mlir::RankedTensorType::get(
+        baseShapedType.getShape(),
+        inferResultType(baseShapedType.getElementType(), exponentType)));
+
+    return mlir::success();
+  }
+
+  if (baseShapedType && exponentShapedType) {
+    if (baseShapedType.getRank() != exponentShapedType.getRank()) {
+      return mlir::failure();
+    }
+
+    int64_t rank = baseShapedType.getRank();
+    llvm::SmallVector<int64_t> shape;
+
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      int64_t lhsDimSize = baseShapedType.getDimSize(dim);
+      int64_t rhsDimSize = exponentShapedType.getDimSize(dim);
+
+      if (lhsDimSize != mlir::ShapedType::kDynamic &&
+          rhsDimSize != mlir::ShapedType::kDynamic &&
+          lhsDimSize != rhsDimSize) {
+        return mlir::failure();
+      }
+
+      if (lhsDimSize != mlir::ShapedType::kDynamic) {
+        shape.push_back(lhsDimSize);
+      } else {
+        shape.push_back(rhsDimSize);
+      }
+    }
+
+    mlir::Type resultElementType = inferResultType(
+        baseShapedType.getElementType(), exponentShapedType.getElementType());
+
+    returnTypes.push_back(
+        mlir::RankedTensorType::get(shape, resultElementType));
+
+    return mlir::success();
+  }
+
+  return mlir::failure();
+}
 
-  bool PowEWOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool PowEWOp::isCompatibleReturnTypes(mlir::TypeRange lhs,
+                                      mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult PowEWOp::fold(FoldAdaptor adaptor)
-  {
-    auto base = adaptor.getBase();
-    auto exponent = adaptor.getExponent();
-
-    if (!base || !exponent) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(base) && isScalar(exponent)) {
-      if (isScalarIntegerLike(base) && isScalarIntegerLike(exponent)) {
-        auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
-
-        auto exponentValue =
-            static_cast<double>(getScalarIntegerLikeValue(exponent));
-
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-
-      if (isScalarFloatLike(base) && isScalarFloatLike(exponent)) {
-        double baseValue = getScalarFloatLikeValue(base);
-        double exponentValue = getScalarFloatLikeValue(exponent);
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-
-      if (isScalarIntegerLike(base) && isScalarFloatLike(exponent)) {
-        auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
-        double exponentValue = getScalarFloatLikeValue(exponent);
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-
-      if (isScalarFloatLike(base) && isScalarIntegerLike(exponent)) {
-        double baseValue = getScalarFloatLikeValue(base);
-
-        auto exponentValue =
-            static_cast<double>(getScalarIntegerLikeValue(exponent));
-
-        return getAttr(resultType, std::pow(baseValue, exponentValue));
-      }
-    }
-
-    return {};
-  }
-}
+mlir::OpFoldResult PowEWOp::fold(FoldAdaptor adaptor) {
+  auto base = adaptor.getBase();
+  auto exponent = adaptor.getExponent();
+
+  if (!base || !exponent) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(base) && isScalar(exponent)) {
+    if (isScalarIntegerLike(base) && isScalarIntegerLike(exponent)) {
+      auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
+
+      auto exponentValue =
+          static_cast<double>(getScalarIntegerLikeValue(exponent));
+
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+
+    if (isScalarFloatLike(base) && isScalarFloatLike(exponent)) {
+      double baseValue = getScalarFloatLikeValue(base);
+      double exponentValue = getScalarFloatLikeValue(exponent);
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+
+    if (isScalarIntegerLike(base) && isScalarFloatLike(exponent)) {
+      auto baseValue = static_cast<double>(getScalarIntegerLikeValue(base));
+      double exponentValue = getScalarFloatLikeValue(exponent);
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+
+    if (isScalarFloatLike(base) && isScalarIntegerLike(exponent)) {
+      double baseValue = getScalarFloatLikeValue(base);
+
+      auto exponentValue =
+          static_cast<double>(getScalarIntegerLikeValue(exponent));
+
+      return getAttr(resultType, std::pow(baseValue, exponentValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // Comparison operations
@@ -5656,452 +5220,405 @@ namespace mlir::bmodelica
 //===---------------------------------------------------------------------===//
 // EqOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult EqOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    returnTypes.push_back(BooleanType::get(context));
-    return mlir::success();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult EqOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  returnTypes.push_back(BooleanType::get(context));
+  return mlir::success();
+}
 
-  bool EqOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool EqOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult EqOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        int64_t lhsValue = getScalarIntegerLikeValue(lhs);
-        int64_t rhsValue = getScalarIntegerLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
-      }
-
-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
-      }
-    }
-
-    return {};
-  }
-}
+mlir::OpFoldResult EqOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+
+  if (!lhs || !rhs) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      int64_t lhsValue = getScalarIntegerLikeValue(lhs);
+      int64_t rhsValue = getScalarIntegerLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
+    }
+
+    if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, static_cast<int64_t>(lhsValue == rhsValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
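Each comparison fold follows the same recipe: compare integers directly when both sides are integer-like, otherwise promote the mixed operand to double, then wrap the boolean outcome as a Boolean-typed attribute via static_cast<int64_t>. The remaining comparison ops below differ only in the operator used. A compact model of the shared pattern (the Scalar alias is an illustrative assumption):

#include <cassert>
#include <cstdint>
#include <functional>
#include <variant>

using Scalar = std::variant<std::int64_t, double>;

// Fold a constant comparison; the result models a Boolean attribute (0 or 1).
static std::int64_t
foldCompare(const Scalar &lhs, const Scalar &rhs,
            const std::function<bool(double, double)> &cmp) {
  auto asDouble = [](const Scalar &s) {
    return std::holds_alternative<std::int64_t>(s)
               ? static_cast<double>(std::get<std::int64_t>(s))
               : std::get<double>(s);
  };
  return static_cast<std::int64_t>(cmp(asDouble(lhs), asDouble(rhs)));
}

int main() {
  assert(foldCompare(Scalar{std::int64_t{3}}, Scalar{3.0},
                     std::equal_to<>()) == 1);
  assert(foldCompare(Scalar{2.0}, Scalar{std::int64_t{3}},
                     std::greater<>()) == 0);
}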
 
 //===---------------------------------------------------------------------===//
 // NotEqOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult NotEqOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    returnTypes.push_back(BooleanType::get(context));
-    return mlir::success();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult NotEqOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  returnTypes.push_back(BooleanType::get(context));
+  return mlir::success();
+}
 
-  bool NotEqOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool NotEqOp::isCompatibleReturnTypes(mlir::TypeRange lhs,
+                                      mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult NotEqOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        int64_t lhsValue = getScalarIntegerLikeValue(lhs);
-        int64_t rhsValue = getScalarIntegerLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
-      }
-
-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
-      }
-    }
-
-    return {};
-  }
-}
+mlir::OpFoldResult NotEqOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+
+  if (!lhs || !rhs) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      int64_t lhsValue = getScalarIntegerLikeValue(lhs);
+      int64_t rhsValue = getScalarIntegerLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
+    }
+
+    if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, static_cast<int64_t>(lhsValue != rhsValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // GtOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult GtOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    returnTypes.push_back(BooleanType::get(context));
-    return mlir::success();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult GtOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  returnTypes.push_back(BooleanType::get(context));
+  return mlir::success();
+}
 
-  bool GtOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool GtOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult GtOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        int64_t lhsValue = getScalarIntegerLikeValue(lhs);
-        int64_t rhsValue = getScalarIntegerLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
-      }
-
-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
-      }
-    }
-
-    return {};
-  }
-}
+mlir::OpFoldResult GtOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+
+  if (!lhs || !rhs) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      int64_t lhsValue = getScalarIntegerLikeValue(lhs);
+      int64_t rhsValue = getScalarIntegerLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
+    }
+
+    if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, static_cast<int64_t>(lhsValue > rhsValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // GteOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult GteOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    returnTypes.push_back(BooleanType::get(context));
-    return mlir::success();
-  }
+namespace mlir::bmodelica {
+mlir::LogicalResult GteOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  returnTypes.push_back(BooleanType::get(context));
+  return mlir::success();
+}
 
-  bool GteOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
+bool GteOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
 
-  mlir::OpFoldResult GteOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        int64_t lhsValue = getScalarIntegerLikeValue(lhs);
-        int64_t rhsValue = getScalarIntegerLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
-      }
-
-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
-      }
-
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
-        return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
-      }
-    }
-
-    return {};
-  }
-}
+mlir::OpFoldResult GteOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+
+  if (!lhs || !rhs) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      int64_t lhsValue = getScalarIntegerLikeValue(lhs);
+      int64_t rhsValue = getScalarIntegerLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
+    }
+
+    if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
+      auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
+    }
+
+    if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      auto rhsValue = static_cast<double>(getScalarIntegerLikeValue(rhs));
+      return getAttr(resultType, static_cast<int64_t>(lhsValue >= rhsValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // LtOp
 
-namespace mlir::bmodelica
-{
-  mlir::LogicalResult LtOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    returnTypes.push_back(BooleanType::get(context));
-    return mlir::success();
+namespace mlir::bmodelica {
+mlir::LogicalResult LtOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  returnTypes.push_back(BooleanType::get(context));
+  return mlir::success();
+}
+
+bool LtOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
   }
-  bool LtOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
       return false;
     }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
   }
-  mlir::OpFoldResult LtOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(lhs) && isScalar(rhs)) {
-      if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
-        int64_t lhsValue = getScalarIntegerLikeValue(lhs);
-        int64_t rhsValue = getScalarIntegerLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue < rhsValue));
-      }
+  return true;
+}
+
+mlir::OpFoldResult LtOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+
+  if (!lhs || !rhs) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(lhs) && isScalar(rhs)) {
+    if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) {
+      int64_t lhsValue = getScalarIntegerLikeValue(lhs);
+      int64_t rhsValue = getScalarIntegerLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue < rhsValue));
+    }
-      if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
-        double lhsValue = getScalarFloatLikeValue(lhs);
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue < rhsValue));
-      }
+    if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) {
+      double lhsValue = getScalarFloatLikeValue(lhs);
+      double rhsValue = getScalarFloatLikeValue(rhs);
+      return getAttr(resultType, static_cast<int64_t>(lhsValue < rhsValue));
+    }
-      if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) {
-        auto lhsValue = static_cast<double>(getScalarIntegerLikeValue(lhs));
-        double rhsValue = getScalarFloatLikeValue(rhs);
-        return getAttr(resultType, static_cast<int64_t>(lhsValue < rhsValue));
-      }
-      if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs))
{ - double lhsValue = getScalarFloatLikeValue(lhs); - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - return getAttr(resultType, static_cast(lhsValue < rhsValue)); - } + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, static_cast(lhsValue < rhsValue)); } - return {}; + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + return getAttr(resultType, static_cast(lhsValue < rhsValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // LteOp -namespace mlir::bmodelica -{ - mlir::LogicalResult LteOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - returnTypes.push_back(BooleanType::get(context)); - return mlir::success(); +namespace mlir::bmodelica { +mlir::LogicalResult LteOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + returnTypes.push_back(BooleanType::get(context)); + return mlir::success(); +} + +bool LteOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool LteOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } - - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } - - return true; } - mlir::OpFoldResult LteOp::fold(FoldAdaptor adaptor) - { - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); + return true; +} - if (!lhs || !rhs) { - return {}; - } +mlir::OpFoldResult LteOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); - auto resultType = getResult().getType(); + if (!lhs || !rhs) { + return {}; + } - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr(resultType, static_cast(lhsValue <= rhsValue)); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, static_cast(lhsValue <= rhsValue)); - } + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + return getAttr(resultType, static_cast(lhsValue <= rhsValue)); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); - double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr(resultType, static_cast(lhsValue <= rhsValue)); - } + if 
(isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, static_cast(lhsValue <= rhsValue)); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); - return getAttr(resultType, static_cast(lhsValue <= rhsValue)); - } + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + auto lhsValue = static_cast(getScalarIntegerLikeValue(lhs)); + double rhsValue = getScalarFloatLikeValue(rhs); + return getAttr(resultType, static_cast(lhsValue <= rhsValue)); } - return {}; + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + auto rhsValue = static_cast(getScalarIntegerLikeValue(rhs)); + return getAttr(resultType, static_cast(lhsValue <= rhsValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // Logic operations @@ -6110,516 +5627,467 @@ namespace mlir::bmodelica //===---------------------------------------------------------------------===// // NotOp -namespace mlir::bmodelica -{ - mlir::LogicalResult NotOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type operandType = adaptor.getOperand().getType(); - - if (isScalar(operandType)) { - returnTypes.push_back(BooleanType::get(context)); - return mlir::success(); - } +namespace mlir::bmodelica { +mlir::LogicalResult NotOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type operandType = adaptor.getOperand().getType(); - if (auto shapedType = operandType.dyn_cast()) { - returnTypes.push_back(mlir::RankedTensorType::get( - shapedType.getShape(), BooleanType::get(context))); + if (isScalar(operandType)) { + returnTypes.push_back(BooleanType::get(context)); + return mlir::success(); + } - return mlir::success(); - } + if (auto shapedType = operandType.dyn_cast()) { + returnTypes.push_back(mlir::RankedTensorType::get( + shapedType.getShape(), BooleanType::get(context))); - return mlir::failure(); + return mlir::success(); } - bool NotOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { - return false; - } + return mlir::failure(); +} - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } +bool NotOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; + } - return true; + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { + return false; + } } - mlir::OpFoldResult NotOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); + return true; +} - if (!operand) { - return {}; - } +mlir::OpFoldResult NotOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - auto resultType = 
getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - int64_t operandValue = getScalarIntegerLikeValue(operand); - return getAttr(resultType, static_cast(operandValue == 0)); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, static_cast(operandValue == 0)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + int64_t operandValue = getScalarIntegerLikeValue(operand); + return getAttr(resultType, static_cast(operandValue == 0)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, static_cast(operandValue == 0)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // AndOp -namespace mlir::bmodelica -{ - mlir::LogicalResult AndOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - mlir::Type lhsType = adaptor.getLhs().getType(); - mlir::Type rhsType = adaptor.getRhs().getType(); - - if (isScalar(lhsType) && isScalar(rhsType)) { - returnTypes.push_back(BooleanType::get(context)); - return mlir::success(); - } +namespace mlir::bmodelica { +mlir::LogicalResult AndOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + mlir::Type lhsType = adaptor.getLhs().getType(); + mlir::Type rhsType = adaptor.getRhs().getType(); + + if (isScalar(lhsType) && isScalar(rhsType)) { + returnTypes.push_back(BooleanType::get(context)); + return mlir::success(); + } - auto lhsShapedType = lhsType.dyn_cast(); - auto rhsShapedType = rhsType.dyn_cast(); + auto lhsShapedType = lhsType.dyn_cast(); + auto rhsShapedType = rhsType.dyn_cast(); - if (lhsShapedType && rhsShapedType) { - if (lhsShapedType.getRank() != rhsShapedType.getRank()) { - return mlir::failure(); - } + if (lhsShapedType && rhsShapedType) { + if (lhsShapedType.getRank() != rhsShapedType.getRank()) { + return mlir::failure(); + } - int64_t rank = lhsShapedType.getRank(); - llvm::SmallVector shape; + int64_t rank = lhsShapedType.getRank(); + llvm::SmallVector shape; - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t lhsDimSize = lhsShapedType.getDimSize(dim); - int64_t rhsDimSize = rhsShapedType.getDimSize(dim); + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t lhsDimSize = lhsShapedType.getDimSize(dim); + int64_t rhsDimSize = rhsShapedType.getDimSize(dim); - if (lhsDimSize != mlir::ShapedType::kDynamic && - rhsDimSize != mlir::ShapedType::kDynamic && - lhsDimSize != rhsDimSize) { - return mlir::failure(); - } + if (lhsDimSize != mlir::ShapedType::kDynamic && + rhsDimSize != mlir::ShapedType::kDynamic && + lhsDimSize != rhsDimSize) { + return mlir::failure(); + } - if (lhsDimSize != mlir::ShapedType::kDynamic) { - shape.push_back(lhsDimSize); - } else { - shape.push_back(rhsDimSize); - } + if (lhsDimSize != mlir::ShapedType::kDynamic) { + shape.push_back(lhsDimSize); + } else { 
+ shape.push_back(rhsDimSize); } + } - returnTypes.push_back(mlir::RankedTensorType::get( - shape, BooleanType::get(context))); + returnTypes.push_back( + mlir::RankedTensorType::get(shape, BooleanType::get(context))); - return mlir::success(); - } + return mlir::success(); + } - return mlir::failure(); + return mlir::failure(); +} + +bool AndOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool AndOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } - - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } - - return true; } - mlir::OpFoldResult AndOp::fold(FoldAdaptor adaptor) - { - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - - if (!lhs || !rhs) { - return {}; - } + return true; +} - auto resultType = getResult().getType(); +mlir::OpFoldResult AndOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); + if (!lhs || !rhs) { + return {}; + } - return getAttr( - resultType, - static_cast(lhsValue != 0 && rhsValue != 0)); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr( - resultType, - static_cast(lhsValue != 0 && rhsValue != 0)); - } + return getAttr(resultType, + static_cast(lhsValue != 0 && rhsValue != 0)); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr( - resultType, - static_cast(lhsValue != 0 && rhsValue != 0)); - } + return getAttr(resultType, + static_cast(lhsValue != 0 && rhsValue != 0)); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr( - resultType, - static_cast(lhsValue != 0 && rhsValue != 0)); - } + return getAttr(resultType, + static_cast(lhsValue != 0 && rhsValue != 0)); } - return {}; + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + + return getAttr(resultType, + static_cast(lhsValue != 0 && rhsValue != 0)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // OrOp -namespace mlir::bmodelica -{ - mlir::LogicalResult OrOp::inferReturnTypes( - 
mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type lhsType = adaptor.getLhs().getType();
-    mlir::Type rhsType = adaptor.getRhs().getType();
-
-    if (isScalar(lhsType) && isScalar(rhsType)) {
-      returnTypes.push_back(BooleanType::get(context));
-      return mlir::success();
-    }
+namespace mlir::bmodelica {
+mlir::LogicalResult OrOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type lhsType = adaptor.getLhs().getType();
+  mlir::Type rhsType = adaptor.getRhs().getType();
+
+  if (isScalar(lhsType) && isScalar(rhsType)) {
+    returnTypes.push_back(BooleanType::get(context));
+    return mlir::success();
+  }

-    auto lhsShapedType = lhsType.dyn_cast<mlir::ShapedType>();
-    auto rhsShapedType = rhsType.dyn_cast<mlir::ShapedType>();
+  auto lhsShapedType = lhsType.dyn_cast<mlir::ShapedType>();
+  auto rhsShapedType = rhsType.dyn_cast<mlir::ShapedType>();

-    if (lhsShapedType && rhsShapedType) {
-      if (lhsShapedType.getRank() != rhsShapedType.getRank()) {
-        return mlir::failure();
-      }
+  if (lhsShapedType && rhsShapedType) {
+    if (lhsShapedType.getRank() != rhsShapedType.getRank()) {
+      return mlir::failure();
+    }

-      int64_t rank = lhsShapedType.getRank();
-      llvm::SmallVector<int64_t> shape;
+    int64_t rank = lhsShapedType.getRank();
+    llvm::SmallVector<int64_t> shape;

-      for (int64_t dim = 0; dim < rank; ++dim) {
-        int64_t lhsDimSize = lhsShapedType.getDimSize(dim);
-        int64_t rhsDimSize = rhsShapedType.getDimSize(dim);
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      int64_t lhsDimSize = lhsShapedType.getDimSize(dim);
+      int64_t rhsDimSize = rhsShapedType.getDimSize(dim);

-        if (lhsDimSize != mlir::ShapedType::kDynamic &&
-            rhsDimSize != mlir::ShapedType::kDynamic &&
-            lhsDimSize != rhsDimSize) {
-          return mlir::failure();
-        }
+      if (lhsDimSize != mlir::ShapedType::kDynamic &&
+          rhsDimSize != mlir::ShapedType::kDynamic &&
+          lhsDimSize != rhsDimSize) {
+        return mlir::failure();
+      }

-        if (lhsDimSize != mlir::ShapedType::kDynamic) {
-          shape.push_back(lhsDimSize);
-        } else {
-          shape.push_back(rhsDimSize);
-        }
+      if (lhsDimSize != mlir::ShapedType::kDynamic) {
+        shape.push_back(lhsDimSize);
+      } else {
+        shape.push_back(rhsDimSize);
       }
+    }

-      returnTypes.push_back(mlir::RankedTensorType::get(
-          shape, BooleanType::get(context)));
+    returnTypes.push_back(
+        mlir::RankedTensorType::get(shape, BooleanType::get(context)));

-      return mlir::success();
-    }
+    return mlir::success();
+  }

-    return mlir::failure();
+  return mlir::failure();
+}

+bool OrOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
   }

-  bool OrOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
       return false;
     }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
   }

-  mlir::OpFoldResult OrOp::fold(FoldAdaptor adaptor)
-  {
-    auto lhs = adaptor.getLhs();
-    auto rhs = adaptor.getRhs();
-
-    if (!lhs || !rhs) {
-      return {};
-    }
+  return true;
+}

-
auto resultType = getResult().getType(); +mlir::OpFoldResult OrOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); - if (isScalar(lhs) && isScalar(rhs)) { - if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); + if (!lhs || !rhs) { + return {}; + } - return getAttr( - resultType, - static_cast(lhsValue != 0 || rhsValue != 0)); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalar(lhs) && isScalar(rhs)) { + if (isScalarIntegerLike(lhs) && isScalarIntegerLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); - return getAttr( - resultType, - static_cast(lhsValue != 0 || rhsValue != 0)); - } + return getAttr(resultType, + static_cast(lhsValue != 0 || rhsValue != 0)); + } - if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { - int64_t lhsValue = getScalarIntegerLikeValue(lhs); - double rhsValue = getScalarFloatLikeValue(rhs); + if (isScalarFloatLike(lhs) && isScalarFloatLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr( - resultType, - static_cast(lhsValue != 0 || rhsValue != 0)); - } + return getAttr(resultType, + static_cast(lhsValue != 0 || rhsValue != 0)); + } - if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { - double lhsValue = getScalarFloatLikeValue(lhs); - int64_t rhsValue = getScalarIntegerLikeValue(rhs); + if (isScalarIntegerLike(lhs) && isScalarFloatLike(rhs)) { + int64_t lhsValue = getScalarIntegerLikeValue(lhs); + double rhsValue = getScalarFloatLikeValue(rhs); - return getAttr( - resultType, - static_cast(lhsValue != 0 || rhsValue != 0)); - } + return getAttr(resultType, + static_cast(lhsValue != 0 || rhsValue != 0)); } - return {}; + if (isScalarFloatLike(lhs) && isScalarIntegerLike(rhs)) { + double lhsValue = getScalarFloatLikeValue(lhs); + int64_t rhsValue = getScalarIntegerLikeValue(rhs); + + return getAttr(resultType, + static_cast(lhsValue != 0 || rhsValue != 0)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SelectOp -namespace mlir::bmodelica -{ - mlir::LogicalResult SelectOp::inferReturnTypes( - mlir::MLIRContext* context, - std::optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::OpaqueProperties properties, - mlir::RegionRange regions, - llvm::SmallVectorImpl& returnTypes) - { - Adaptor adaptor(operands, attributes, properties, regions); - - if (adaptor.getTrueValues().size() != adaptor.getFalseValues().size()) { - return mlir::failure(); - } +namespace mlir::bmodelica { +mlir::LogicalResult SelectOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + mlir::ValueRange operands, mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, mlir::RegionRange regions, + llvm::SmallVectorImpl &returnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); - for (auto [trueValue, falseValue] : llvm::zip( - adaptor.getTrueValues(), adaptor.getFalseValues())) { - mlir::Type trueValueType = trueValue.getType(); - mlir::Type falseValueType = falseValue.getType(); + if (adaptor.getTrueValues().size() != 
adaptor.getFalseValues().size()) { + return mlir::failure(); + } - if (trueValueType == falseValueType) { - returnTypes.push_back(trueValueType); - } else if (isScalar(trueValueType) && isScalar(falseValueType)) { - returnTypes.push_back( - getMostGenericScalarType(trueValueType, falseValueType)); - } else { - auto trueValueShapedType = - trueValueType.dyn_cast(); + for (auto [trueValue, falseValue] : + llvm::zip(adaptor.getTrueValues(), adaptor.getFalseValues())) { + mlir::Type trueValueType = trueValue.getType(); + mlir::Type falseValueType = falseValue.getType(); - auto falseValueShapedType = - falseValueType.dyn_cast(); + if (trueValueType == falseValueType) { + returnTypes.push_back(trueValueType); + } else if (isScalar(trueValueType) && isScalar(falseValueType)) { + returnTypes.push_back( + getMostGenericScalarType(trueValueType, falseValueType)); + } else { + auto trueValueShapedType = trueValueType.dyn_cast(); - if (trueValueShapedType && falseValueShapedType) { - if (trueValueShapedType.getRank() != falseValueShapedType.getRank()) { - return mlir::failure(); - } + auto falseValueShapedType = falseValueType.dyn_cast(); - int64_t rank = trueValueShapedType.getRank(); - llvm::SmallVector shape; + if (trueValueShapedType && falseValueShapedType) { + if (trueValueShapedType.getRank() != falseValueShapedType.getRank()) { + return mlir::failure(); + } - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t lhsDimSize = trueValueShapedType.getDimSize(dim); - int64_t rhsDimSize = falseValueShapedType.getDimSize(dim); + int64_t rank = trueValueShapedType.getRank(); + llvm::SmallVector shape; - if (lhsDimSize != mlir::ShapedType::kDynamic && - rhsDimSize != mlir::ShapedType::kDynamic && - lhsDimSize != rhsDimSize) { - return mlir::failure(); - } + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t lhsDimSize = trueValueShapedType.getDimSize(dim); + int64_t rhsDimSize = falseValueShapedType.getDimSize(dim); - if (lhsDimSize != mlir::ShapedType::kDynamic) { - shape.push_back(lhsDimSize); - } else { - shape.push_back(rhsDimSize); - } + if (lhsDimSize != mlir::ShapedType::kDynamic && + rhsDimSize != mlir::ShapedType::kDynamic && + lhsDimSize != rhsDimSize) { + return mlir::failure(); } - mlir::Type resultElementType = getMostGenericScalarType( - trueValueShapedType.getElementType(), - falseValueShapedType.getElementType()); - - returnTypes.push_back(mlir::RankedTensorType::get( - shape, resultElementType)); - } else { - return mlir::failure(); + if (lhsDimSize != mlir::ShapedType::kDynamic) { + shape.push_back(lhsDimSize); + } else { + shape.push_back(rhsDimSize); + } } + + mlir::Type resultElementType = + getMostGenericScalarType(trueValueShapedType.getElementType(), + falseValueShapedType.getElementType()); + + returnTypes.push_back( + mlir::RankedTensorType::get(shape, resultElementType)); + } else { + return mlir::failure(); } } + } - return mlir::success(); + return mlir::success(); +} + +bool SelectOp::isCompatibleReturnTypes(mlir::TypeRange lhs, + mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - bool SelectOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { return false; } - - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } - - return true; } - mlir::ParseResult SelectOp::parse( - mlir::OpAsmParser& parser, 
mlir::OperationState& result)
-  {
-    mlir::OpAsmParser::UnresolvedOperand condition;
-    mlir::Type conditionType;
+  return true;
+}

-    llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> trueValues;
-    llvm::SmallVector<mlir::Type> trueValuesTypes;
+mlir::ParseResult SelectOp::parse(mlir::OpAsmParser &parser,
+                                  mlir::OperationState &result) {
+  mlir::OpAsmParser::UnresolvedOperand condition;
+  mlir::Type conditionType;

-    llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> falseValues;
-    llvm::SmallVector<mlir::Type> falseValuesTypes;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> trueValues;
+  llvm::SmallVector<mlir::Type> trueValuesTypes;

-    llvm::SmallVector<mlir::Type> resultTypes;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> falseValues;
+  llvm::SmallVector<mlir::Type> falseValuesTypes;

-    if (parser.parseLParen() ||
-        parser.parseOperand(condition) ||
-        parser.parseColonType(conditionType) ||
-        parser.parseRParen() ||
-        parser.resolveOperand(condition, conditionType, result.operands)) {
-      return mlir::failure();
-    }
+  llvm::SmallVector<mlir::Type> resultTypes;

-    if (parser.parseComma()) {
-      return mlir::failure();
-    }
+  if (parser.parseLParen() || parser.parseOperand(condition) ||
+      parser.parseColonType(conditionType) || parser.parseRParen() ||
+      parser.resolveOperand(condition, conditionType, result.operands)) {
+    return mlir::failure();
+  }

-    auto trueValuesLoc = parser.getCurrentLocation();
+  if (parser.parseComma()) {
+    return mlir::failure();
+  }

-    if (parser.parseLParen() ||
-        parser.parseOperandList(trueValues) ||
-        parser.parseColonTypeList(trueValuesTypes) ||
-        parser.parseRParen() ||
-        parser.resolveOperands(
-            trueValues, trueValuesTypes, trueValuesLoc, result.operands)) {
-      return mlir::failure();
-    }
+  auto trueValuesLoc = parser.getCurrentLocation();

-    if (parser.parseComma()) {
-      return mlir::failure();
-    }
+  if (parser.parseLParen() || parser.parseOperandList(trueValues) ||
+      parser.parseColonTypeList(trueValuesTypes) || parser.parseRParen() ||
+      parser.resolveOperands(trueValues, trueValuesTypes, trueValuesLoc,
+                             result.operands)) {
+    return mlir::failure();
+  }

-    auto falseValuesLoc = parser.getCurrentLocation();
+  if (parser.parseComma()) {
+    return mlir::failure();
+  }

-    if (parser.parseLParen() ||
-        parser.parseOperandList(falseValues) ||
-        parser.parseColonTypeList(falseValuesTypes) ||
-        parser.parseRParen() ||
-        parser.resolveOperands(
-            falseValues, falseValuesTypes, falseValuesLoc, result.operands)) {
-      return mlir::failure();
-    }
+  auto falseValuesLoc = parser.getCurrentLocation();

-    if (parser.parseArrowTypeList(resultTypes)) {
-      return mlir::failure();
-    }
+  if (parser.parseLParen() || parser.parseOperandList(falseValues) ||
+      parser.parseColonTypeList(falseValuesTypes) || parser.parseRParen() ||
+      parser.resolveOperands(falseValues, falseValuesTypes, falseValuesLoc,
+                             result.operands)) {
+    return mlir::failure();
+  }

-    result.addTypes(resultTypes);
-    return mlir::success();
+  if (parser.parseArrowTypeList(resultTypes)) {
+    return mlir::failure();
   }

-  void SelectOp::print(mlir::OpAsmPrinter& printer)
-  {
-    printer << " ";
-    printer << "(" << getCondition() << " : "
-            << getCondition().getType() << ")";
+  result.addTypes(resultTypes);
+  return mlir::success();
+}

-    printer << ", ";
+void SelectOp::print(mlir::OpAsmPrinter &printer) {
+  printer << " ";
+  printer << "(" << getCondition() << " : " << getCondition().getType() << ")";

-    printer << "(" << getTrueValues() << " : "
-            << getTrueValues().getTypes() << ")";
+  printer << ", ";

-    printer << ", ";
+  printer << "(" << getTrueValues() << " : " << getTrueValues().getTypes()
+          << ")";

-    printer << "(" << getFalseValues() << " : "
-            << getFalseValues().getTypes() << ")";
+  printer << ", ";

+  printer
<< "(" << getFalseValues() << " : " << getFalseValues().getTypes() + << ")"; - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " -> "; + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " -> "; - auto resultTypes = getResultTypes(); + auto resultTypes = getResultTypes(); - if (resultTypes.size() == 1) { - printer << resultTypes; - } else { - printer << "(" << resultTypes << ")"; - } + if (resultTypes.size() == 1) { + printer << resultTypes; + } else { + printer << "(" << resultTypes << ")"; } } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // Built-in operations @@ -6628,1380 +6096,1275 @@ namespace mlir::bmodelica //===---------------------------------------------------------------------===// // AbsOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult AbsOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); - - if (!operand) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult AbsOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - int64_t operandValue = getScalarIntegerLikeValue(operand); - return getAttr(resultType, std::abs(operandValue)); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::abs(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + int64_t operandValue = getScalarIntegerLikeValue(operand); + return getAttr(resultType, std::abs(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::abs(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // AcosOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult AcosOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); +namespace mlir::bmodelica { +mlir::OpFoldResult AcosOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - if (!operand) { - return {}; - } + if (!operand) { + return {}; + } - auto resultType = getResult().getType(); + auto resultType = getResult().getType(); - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - return getAttr(resultType, std::acos(operandValue)); - } - - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::acos(operandValue)); - } + return getAttr(resultType, std::acos(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::acos(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // AsinOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult AsinOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); +namespace mlir::bmodelica { +mlir::OpFoldResult AsinOp::fold(FoldAdaptor adaptor) { + auto operand = 
adaptor.getOperand();

-    if (!operand) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
+  if (!operand) {
+    return {};
+  }

-    if (isScalar(operand)) {
-      if (isScalarIntegerLike(operand)) {
-        auto operandValue =
-            static_cast<double>(getScalarIntegerLikeValue(operand));
+  auto resultType = getResult().getType();

-        return getAttr(resultType, std::asin(operandValue));
-      }
+  if (isScalar(operand)) {
+    if (isScalarIntegerLike(operand)) {
+      auto operandValue =
+          static_cast<double>(getScalarIntegerLikeValue(operand));

-      if (isScalarFloatLike(operand)) {
-        double operandValue = getScalarFloatLikeValue(operand);
-        return getAttr(resultType, std::asin(operandValue));
-      }
+      return getAttr(resultType, std::asin(operandValue));
    }

-    return {};
+    if (isScalarFloatLike(operand)) {
+      double operandValue = getScalarFloatLikeValue(operand);
+      return getAttr(resultType, std::asin(operandValue));
+    }
  }
+
+  return {};
}
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // AtanOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult AtanOp::fold(FoldAdaptor adaptor)
-  {
-    auto operand = adaptor.getOperand();
-
-    if (!operand) {
-      return {};
-    }
+namespace mlir::bmodelica {
+mlir::OpFoldResult AtanOp::fold(FoldAdaptor adaptor) {
+  auto operand = adaptor.getOperand();

-    auto resultType = getResult().getType();
+  if (!operand) {
+    return {};
+  }

-    if (isScalar(operand)) {
-      if (isScalarIntegerLike(operand)) {
-        auto operandValue =
-            static_cast<double>(getScalarIntegerLikeValue(operand));
+  auto resultType = getResult().getType();

-        return getAttr(resultType, std::atan(operandValue));
-      }
+  if (isScalar(operand)) {
+    if (isScalarIntegerLike(operand)) {
+      auto operandValue =
+          static_cast<double>(getScalarIntegerLikeValue(operand));

-      if (isScalarFloatLike(operand)) {
-        double operandValue = getScalarFloatLikeValue(operand);
-        return getAttr(resultType, std::atan(operandValue));
-      }
+      return getAttr(resultType, std::atan(operandValue));
    }

-    return {};
+    if (isScalarFloatLike(operand)) {
+      double operandValue = getScalarFloatLikeValue(operand);
+      return getAttr(resultType, std::atan(operandValue));
+    }
  }
+
+  return {};
}
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // CeilOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult CeilOp::fold(FoldAdaptor adaptor)
-  {
-    auto operand = adaptor.getOperand();
+namespace mlir::bmodelica {
+mlir::OpFoldResult CeilOp::fold(FoldAdaptor adaptor) {
+  auto operand = adaptor.getOperand();

-    if (!operand) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
+  if (!operand) {
+    return {};
+  }

-    if (isScalar(operand)) {
-      if (isScalarIntegerLike(operand)) {
-        int64_t operandValue = getScalarIntegerLikeValue(operand);
-        return getAttr(resultType, operandValue);
-      }
+  auto resultType = getResult().getType();

-      if (isScalarFloatLike(operand)) {
-        double operandValue = getScalarFloatLikeValue(operand);
-        return getAttr(resultType, std::ceil(operandValue));
-      }
+  if (isScalar(operand)) {
+    if (isScalarIntegerLike(operand)) {
+      int64_t operandValue = getScalarIntegerLikeValue(operand);
+      return getAttr(resultType, operandValue);
    }

-    return {};
+    if (isScalarFloatLike(operand)) {
+      double operandValue = getScalarFloatLikeValue(operand);
+      return getAttr(resultType, std::ceil(operandValue));
+    }
  }
+
+  return {};
}
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // CosOp

-namespace mlir::bmodelica
-{
-
mlir::OpFoldResult CosOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); +namespace mlir::bmodelica { +mlir::OpFoldResult CosOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - if (!operand) { - return {}; - } - - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + auto resultType = getResult().getType(); - return getAttr(resultType, std::cos(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::cos(operandValue)); - } + return getAttr(resultType, std::cos(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::cos(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // CoshOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult CoshOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); +namespace mlir::bmodelica { +mlir::OpFoldResult CoshOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - if (!operand) { - return {}; - } + if (!operand) { + return {}; + } - auto resultType = getResult().getType(); + auto resultType = getResult().getType(); - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - return getAttr(resultType, std::cosh(operandValue)); - } - - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::cosh(operandValue)); - } + return getAttr(resultType, std::cosh(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::cosh(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // DivTruncOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult DivTruncOp::fold(FoldAdaptor adaptor) - { - auto x = adaptor.getX(); - auto y = adaptor.getY(); - - if (!x || !y) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult DivTruncOp::fold(FoldAdaptor adaptor) { + auto x = adaptor.getX(); + auto y = adaptor.getY(); - auto resultType = getResult().getType(); + if (!x || !y) { + return {}; + } - if (isScalar(x) && isScalar(y)) { - if (isScalarIntegerLike(x) && isScalarIntegerLike(y)) { - auto xValue = static_cast(getScalarIntegerLikeValue(x)); - auto yValue = static_cast(getScalarIntegerLikeValue(y)); - return getAttr(resultType, xValue / yValue); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(x) && isScalarFloatLike(y)) { - double xValue = getScalarFloatLikeValue(x); - double yValue = getScalarFloatLikeValue(y); - return getAttr(resultType, std::trunc(xValue / yValue)); - } + if (isScalar(x) && isScalar(y)) { + if (isScalarIntegerLike(x) && isScalarIntegerLike(y)) { + auto xValue = 
static_cast(getScalarIntegerLikeValue(x)); + auto yValue = static_cast(getScalarIntegerLikeValue(y)); + return getAttr(resultType, xValue / yValue); + } - if (isScalarIntegerLike(x) && isScalarFloatLike(y)) { - auto xValue = static_cast(getScalarIntegerLikeValue(x)); - double yValue = getScalarFloatLikeValue(y); - return getAttr(resultType, std::trunc(xValue / yValue)); - } + if (isScalarFloatLike(x) && isScalarFloatLike(y)) { + double xValue = getScalarFloatLikeValue(x); + double yValue = getScalarFloatLikeValue(y); + return getAttr(resultType, std::trunc(xValue / yValue)); + } - if (isScalarFloatLike(x) && isScalarIntegerLike(y)) { - double xValue = getScalarFloatLikeValue(x); - auto yValue = static_cast(getScalarIntegerLikeValue(y)); - return getAttr(resultType, std::trunc(xValue / yValue)); - } + if (isScalarIntegerLike(x) && isScalarFloatLike(y)) { + auto xValue = static_cast(getScalarIntegerLikeValue(x)); + double yValue = getScalarFloatLikeValue(y); + return getAttr(resultType, std::trunc(xValue / yValue)); } - return {}; + if (isScalarFloatLike(x) && isScalarIntegerLike(y)) { + double xValue = getScalarFloatLikeValue(x); + auto yValue = static_cast(getScalarIntegerLikeValue(y)); + return getAttr(resultType, std::trunc(xValue / yValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ExpOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult ExpOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getExponent(); - - if (!operand) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getExponent(); - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + auto resultType = getResult().getType(); - return getAttr(resultType, std::exp(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::exp(operandValue)); - } + return getAttr(resultType, std::exp(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::exp(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // FloorOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult FloorOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); - - if (!operand) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult FloorOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - int64_t operandValue = getScalarIntegerLikeValue(operand); - return getAttr(resultType, operandValue); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::floor(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + int64_t operandValue = getScalarIntegerLikeValue(operand); + return 
getAttr(resultType, operandValue); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::floor(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // IntegerOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult IntegerOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); +namespace mlir::bmodelica { +mlir::OpFoldResult IntegerOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - if (!operand) { - return {}; - } - - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - int64_t operandValue = getScalarIntegerLikeValue(operand); - return getAttr(resultType, operandValue); - } + auto resultType = getResult().getType(); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::floor(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + int64_t operandValue = getScalarIntegerLikeValue(operand); + return getAttr(resultType, operandValue); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::floor(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // LogOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult LogOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); - - if (!operand) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult LogOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + auto resultType = getResult().getType(); - return getAttr(resultType, std::log(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::log(operandValue)); - } + return getAttr(resultType, std::log(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::log(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // Log10Op -namespace mlir::bmodelica -{ - mlir::OpFoldResult Log10Op::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); +namespace mlir::bmodelica { +mlir::OpFoldResult Log10Op::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - if (!operand) { - return {}; - } + if (!operand) { + return {}; + } - auto resultType = getResult().getType(); + auto resultType = getResult().getType(); - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - return getAttr(resultType, 
std::log10(operandValue));
-      }
-
-      if (isScalarFloatLike(operand)) {
-        double operandValue = getScalarFloatLikeValue(operand);
-        return getAttr(resultType, std::log10(operandValue));
-      }
+      return getAttr(resultType, std::log10(operandValue));
    }

-    return {};
+    if (isScalarFloatLike(operand)) {
+      double operandValue = getScalarFloatLikeValue(operand);
+      return getAttr(resultType, std::log10(operandValue));
+    }
  }
+
+  return {};
}
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // MaxOp

-namespace mlir::bmodelica
-{
-  mlir::LogicalResult MaxOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type firstType = adaptor.getFirst().getType();
-
-    if (adaptor.getSecond()) {
-      mlir::Type secondType = adaptor.getSecond().getType();
-
-      if (isScalar(firstType) && isScalar(secondType)) {
-        returnTypes.push_back(getMostGenericScalarType(firstType, secondType));
-        return mlir::success();
-      }
+namespace mlir::bmodelica {
+mlir::LogicalResult MaxOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type firstType = adaptor.getFirst().getType();

-      return mlir::failure();
-    }
+  if (adaptor.getSecond()) {
+    mlir::Type secondType = adaptor.getSecond().getType();

-    if (auto firstShapedType = firstType.dyn_cast<mlir::ShapedType>()) {
-      returnTypes.push_back(firstShapedType.getElementType());
+    if (isScalar(firstType) && isScalar(secondType)) {
+      returnTypes.push_back(getMostGenericScalarType(firstType, secondType));
      return mlir::success();
    }

    return mlir::failure();
  }

-  bool MaxOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
+  if (auto firstShapedType = firstType.dyn_cast<mlir::ShapedType>()) {
+    returnTypes.push_back(firstShapedType.getElementType());
+    return mlir::success();
+  }

-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
+  return mlir::failure();
+}

-    return true;
+bool MaxOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
  }

-  mlir::ParseResult MaxOp::parse(
-      mlir::OpAsmParser& parser, mlir::OperationState& result)
-  {
-    mlir::OpAsmParser::UnresolvedOperand first;
-    mlir::Type firstType;
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }

-    mlir::OpAsmParser::UnresolvedOperand second;
-    mlir::Type secondType;
+  return true;
+}

-    size_t numOperands = 1;
+mlir::ParseResult MaxOp::parse(mlir::OpAsmParser &parser,
+                               mlir::OperationState &result) {
+  mlir::OpAsmParser::UnresolvedOperand first;
+  mlir::Type firstType;

-    if (parser.parseOperand(first)) {
-      return mlir::failure();
-    }
+  mlir::OpAsmParser::UnresolvedOperand second;
+  mlir::Type secondType;

-    if (mlir::succeeded(parser.parseOptionalComma())) {
-      numOperands = 2;
+  size_t numOperands = 1;

-      if (parser.parseOperand(second)) {
-        return mlir::failure();
-      }
-    }
+  if (parser.parseOperand(first)) {
+    return
mlir::failure(); + } - if (parser.parseOptionalAttrDict(result.attributes)) { - return mlir::failure(); - } + if (mlir::succeeded(parser.parseOptionalComma())) { + numOperands = 2; - if (parser.parseColon()) { + if (parser.parseOperand(second)) { return mlir::failure(); } + } - if (numOperands == 1) { - if (parser.parseType(firstType) || - parser.resolveOperand(first, firstType, result.operands)) { - return mlir::failure(); - } - } else { - if (parser.parseLParen() || - parser.parseType(firstType) || - parser.resolveOperand(first, firstType, result.operands) || - parser.parseComma() || - parser.parseType(secondType) || - parser.resolveOperand(second, secondType, result.operands) || - parser.parseRParen()) { - return mlir::failure(); - } - } + if (parser.parseOptionalAttrDict(result.attributes)) { + return mlir::failure(); + } - mlir::Type resultType; + if (parser.parseColon()) { + return mlir::failure(); + } - if (parser.parseArrow() || - parser.parseType(resultType)) { + if (numOperands == 1) { + if (parser.parseType(firstType) || + parser.resolveOperand(first, firstType, result.operands)) { + return mlir::failure(); + } + } else { + if (parser.parseLParen() || parser.parseType(firstType) || + parser.resolveOperand(first, firstType, result.operands) || + parser.parseComma() || parser.parseType(secondType) || + parser.resolveOperand(second, secondType, result.operands) || + parser.parseRParen()) { return mlir::failure(); } + } - result.addTypes(resultType); + mlir::Type resultType; - return mlir::success(); + if (parser.parseArrow() || parser.parseType(resultType)) { + return mlir::failure(); } - void MaxOp::print(mlir::OpAsmPrinter& printer) - { - printer << getFirst(); + result.addTypes(resultType); - if (getOperation()->getNumOperands() == 2) { - printer << ", " << getSecond(); - } + return mlir::success(); +} - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " : "; +void MaxOp::print(mlir::OpAsmPrinter &printer) { + printer << getFirst(); - if (getOperation()->getNumOperands() == 1) { - printer << getFirst().getType(); - } else { - printer << "(" << getFirst().getType() << ", " - << getSecond().getType() << ")"; - } + if (getOperation()->getNumOperands() == 2) { + printer << ", " << getSecond(); + } - printer << " -> " << getResult().getType(); + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : "; + + if (getOperation()->getNumOperands() == 1) { + printer << getFirst().getType(); + } else { + printer << "(" << getFirst().getType() << ", " << getSecond().getType() + << ")"; } - mlir::OpFoldResult MaxOp::fold(FoldAdaptor adaptor) - { - if (adaptor.getOperands().size() == 2) { - auto first = adaptor.getFirst(); - auto second = adaptor.getSecond(); + printer << " -> " << getResult().getType(); +} + +mlir::OpFoldResult MaxOp::fold(FoldAdaptor adaptor) { + if (adaptor.getOperands().size() == 2) { + auto first = adaptor.getFirst(); + auto second = adaptor.getSecond(); - if (!first || !second) { - return {}; - } + if (!first || !second) { + return {}; + } - auto resultType = getResult().getType(); + auto resultType = getResult().getType(); - if (isScalar(first) && isScalar(second)) { - if (isScalarIntegerLike(first) && isScalarIntegerLike(second)) { - int64_t firstValue = getScalarIntegerLikeValue(first); - int64_t secondValue = getScalarIntegerLikeValue(second); - return getAttr(resultType, std::max(firstValue, secondValue)); - } + if (isScalar(first) && isScalar(second)) { + if (isScalarIntegerLike(first) && isScalarIntegerLike(second)) { + 
int64_t firstValue = getScalarIntegerLikeValue(first);
+        int64_t secondValue = getScalarIntegerLikeValue(second);
+        return getAttr(resultType, std::max(firstValue, secondValue));
+      }

-        if (isScalarFloatLike(first) && isScalarFloatLike(second)) {
-          double firstValue = getScalarFloatLikeValue(first);
-          double secondValue = getScalarFloatLikeValue(second);
-          return getAttr(resultType, std::max(firstValue, secondValue));
-        }
+      if (isScalarFloatLike(first) && isScalarFloatLike(second)) {
+        double firstValue = getScalarFloatLikeValue(first);
+        double secondValue = getScalarFloatLikeValue(second);
+        return getAttr(resultType, std::max(firstValue, secondValue));
+      }

-        if (isScalarIntegerLike(first) && isScalarFloatLike(second)) {
-          auto firstValue =
-              static_cast<double>(getScalarIntegerLikeValue(first));
-
-          double secondValue = getScalarFloatLikeValue(second);
-
-          if (firstValue >= secondValue) {
-            return getAttr(resultType, firstValue);
-          } else {
-            return getAttr(resultType, secondValue);
-          }
-        }
+      if (isScalarIntegerLike(first) && isScalarFloatLike(second)) {
+        auto firstValue = static_cast<double>(getScalarIntegerLikeValue(first));
+
+        double secondValue = getScalarFloatLikeValue(second);
+
+        if (firstValue >= secondValue) {
+          return getAttr(resultType, firstValue);
+        } else {
+          return getAttr(resultType, secondValue);
+        }
+      }

-        if (isScalarFloatLike(first) && isScalarIntegerLike(second)) {
-          double firstValue = getScalarFloatLikeValue(first);
-
-          auto secondValue =
-              static_cast<double>(getScalarIntegerLikeValue(second));
-
-          if (firstValue >= secondValue) {
-            return getAttr(resultType, firstValue);
-          } else {
-            return getAttr(resultType, secondValue);
-          }
-        }
-      }
-
-      return {};
-    }
-  }
+      if (isScalarFloatLike(first) && isScalarIntegerLike(second)) {
+        double firstValue = getScalarFloatLikeValue(first);
+
+        auto secondValue =
+            static_cast<double>(getScalarIntegerLikeValue(second));
+
+        if (firstValue >= secondValue) {
+          return getAttr(resultType, firstValue);
+        } else {
+          return getAttr(resultType, secondValue);
+        }
+      }
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// MinOp

-namespace mlir::bmodelica
-{
-  mlir::LogicalResult MinOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-    mlir::Type firstType = adaptor.getFirst().getType();
-
-    if (adaptor.getSecond()) {
-      mlir::Type secondType = adaptor.getSecond().getType();
-
-      if (isScalar(firstType) && isScalar(secondType)) {
-        returnTypes.push_back(getMostGenericScalarType(firstType, secondType));
-        return mlir::success();
-      }
-
-      return mlir::failure();
-    }
+namespace mlir::bmodelica {
+mlir::LogicalResult MinOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  mlir::Type firstType = adaptor.getFirst().getType();
+
+  if (adaptor.getSecond()) {
+    mlir::Type secondType = adaptor.getSecond().getType();
+
+    if (isScalar(firstType) && isScalar(secondType)) {
+      returnTypes.push_back(getMostGenericScalarType(firstType,
secondType)); return mlir::success(); } return mlir::failure(); } - bool MinOp::isCompatibleReturnTypes( - mlir::TypeRange lhs, mlir::TypeRange rhs) - { - if (lhs.size() != rhs.size()) { - return false; - } + if (auto firstShapedType = firstType.dyn_cast()) { + returnTypes.push_back(firstShapedType.getElementType()); + return mlir::success(); + } - for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { - if (!areTypesCompatible(lhsType, rhsType)) { - return false; - } - } + return mlir::failure(); +} - return true; +bool MinOp::isCompatibleReturnTypes(mlir::TypeRange lhs, mlir::TypeRange rhs) { + if (lhs.size() != rhs.size()) { + return false; } - mlir::ParseResult MinOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - mlir::OpAsmParser::UnresolvedOperand first; - mlir::Type firstType; + for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) { + if (!areTypesCompatible(lhsType, rhsType)) { + return false; + } + } - mlir::OpAsmParser::UnresolvedOperand second; - mlir::Type secondType; + return true; +} - size_t numOperands = 1; +mlir::ParseResult MinOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::OpAsmParser::UnresolvedOperand first; + mlir::Type firstType; - if (parser.parseOperand(first)) { - return mlir::failure(); - } + mlir::OpAsmParser::UnresolvedOperand second; + mlir::Type secondType; - if (mlir::succeeded(parser.parseOptionalComma())) { - numOperands = 2; + size_t numOperands = 1; - if (parser.parseOperand(second)) { - return mlir::failure(); - } - } + if (parser.parseOperand(first)) { + return mlir::failure(); + } - if (parser.parseOptionalAttrDict(result.attributes)) { - return mlir::failure(); - } + if (mlir::succeeded(parser.parseOptionalComma())) { + numOperands = 2; - if (parser.parseColon()) { + if (parser.parseOperand(second)) { return mlir::failure(); } + } - if (numOperands == 1) { - if (parser.parseType(firstType) || - parser.resolveOperand(first, firstType, result.operands)) { - return mlir::failure(); - } - } else { - if (parser.parseLParen() || - parser.parseType(firstType) || - parser.resolveOperand(first, firstType, result.operands) || - parser.parseComma() || - parser.parseType(secondType) || - parser.resolveOperand(second, secondType, result.operands) || - parser.parseRParen()) { - return mlir::failure(); - } - } + if (parser.parseOptionalAttrDict(result.attributes)) { + return mlir::failure(); + } - mlir::Type resultType; + if (parser.parseColon()) { + return mlir::failure(); + } - if (parser.parseArrow() || - parser.parseType(resultType)) { + if (numOperands == 1) { + if (parser.parseType(firstType) || + parser.resolveOperand(first, firstType, result.operands)) { + return mlir::failure(); + } + } else { + if (parser.parseLParen() || parser.parseType(firstType) || + parser.resolveOperand(first, firstType, result.operands) || + parser.parseComma() || parser.parseType(secondType) || + parser.resolveOperand(second, secondType, result.operands) || + parser.parseRParen()) { return mlir::failure(); } + } - result.addTypes(resultType); + mlir::Type resultType; - return mlir::success(); + if (parser.parseArrow() || parser.parseType(resultType)) { + return mlir::failure(); } - void MinOp::print(mlir::OpAsmPrinter& printer) - { - printer << getFirst(); + result.addTypes(resultType); - if (getOperation()->getNumOperands() == 2) { - printer << ", " << getSecond(); - } + return mlir::success(); +} - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " : "; +void MinOp::print(mlir::OpAsmPrinter &printer) { 
+ printer << getFirst(); - if (getOperation()->getNumOperands() == 1) { - printer << getFirst().getType(); - } else { - printer << "(" << getFirst().getType() << ", " - << getSecond().getType() << ")"; - } + if (getOperation()->getNumOperands() == 2) { + printer << ", " << getSecond(); + } - printer << " -> " << getResult().getType(); + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : "; + + if (getOperation()->getNumOperands() == 1) { + printer << getFirst().getType(); + } else { + printer << "(" << getFirst().getType() << ", " << getSecond().getType() + << ")"; } - mlir::OpFoldResult MinOp::fold(FoldAdaptor adaptor) - { - if (adaptor.getOperands().size() == 2) { - auto first = adaptor.getFirst(); - auto second = adaptor.getSecond(); + printer << " -> " << getResult().getType(); +} + +mlir::OpFoldResult MinOp::fold(FoldAdaptor adaptor) { + if (adaptor.getOperands().size() == 2) { + auto first = adaptor.getFirst(); + auto second = adaptor.getSecond(); - if (!first || !second) { - return {}; - } + if (!first || !second) { + return {}; + } - auto resultType = getResult().getType(); + auto resultType = getResult().getType(); - if (isScalar(first) && isScalar(second)) { - if (isScalarIntegerLike(first) && isScalarIntegerLike(second)) { - int64_t firstValue = getScalarIntegerLikeValue(first); - int64_t secondValue = getScalarIntegerLikeValue(second); - return getAttr(resultType, std::min(firstValue, secondValue)); - } + if (isScalar(first) && isScalar(second)) { + if (isScalarIntegerLike(first) && isScalarIntegerLike(second)) { + int64_t firstValue = getScalarIntegerLikeValue(first); + int64_t secondValue = getScalarIntegerLikeValue(second); + return getAttr(resultType, std::min(firstValue, secondValue)); + } - if (isScalarFloatLike(first) && isScalarFloatLike(second)) { - double firstValue = getScalarFloatLikeValue(first); - double secondValue = getScalarFloatLikeValue(second); - return getAttr(resultType, std::min(firstValue, secondValue)); - } + if (isScalarFloatLike(first) && isScalarFloatLike(second)) { + double firstValue = getScalarFloatLikeValue(first); + double secondValue = getScalarFloatLikeValue(second); + return getAttr(resultType, std::min(firstValue, secondValue)); + } - if (isScalarIntegerLike(first) && isScalarFloatLike(second)) { - auto firstValue = - static_cast(getScalarIntegerLikeValue(first)); + if (isScalarIntegerLike(first) && isScalarFloatLike(second)) { + auto firstValue = static_cast(getScalarIntegerLikeValue(first)); - double secondValue = getScalarFloatLikeValue(second); + double secondValue = getScalarFloatLikeValue(second); - if (firstValue <= secondValue) { - return getAttr(resultType, firstValue); - } else { - return getAttr(resultType, secondValue); - } + if (firstValue <= secondValue) { + return getAttr(resultType, firstValue); + } else { + return getAttr(resultType, secondValue); } + } - if (isScalarFloatLike(first) && isScalarIntegerLike(second)) { - double firstValue = getScalarFloatLikeValue(first); + if (isScalarFloatLike(first) && isScalarIntegerLike(second)) { + double firstValue = getScalarFloatLikeValue(first); - auto secondValue = - static_cast(getScalarIntegerLikeValue(second)); + auto secondValue = + static_cast(getScalarIntegerLikeValue(second)); - if (firstValue <= secondValue) { - return getAttr(resultType, firstValue); - } else { - return getAttr(resultType, secondValue); - } + if (firstValue <= secondValue) { + return getAttr(resultType, firstValue); + } else { + return getAttr(resultType, secondValue); } } } - - 
return {};
-  }
-}
+
+  return {};
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// ModOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult ModOp::fold(FoldAdaptor adaptor)
-  {
-    auto x = adaptor.getX();
-    auto y = adaptor.getY();
-
-    if (!x || !y) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(x) && isScalar(y)) {
-      if (isScalarIntegerLike(x) && isScalarIntegerLike(y)) {
-        auto xValue = static_cast<double>(getScalarIntegerLikeValue(x));
-        auto yValue = static_cast<double>(getScalarIntegerLikeValue(y));
-
-        return getAttr(
-            resultType,
-            xValue - std::floor(xValue / yValue) * yValue);
-      }
-
-      if (isScalarFloatLike(x) && isScalarFloatLike(y)) {
-        double xValue = getScalarFloatLikeValue(x);
-        double yValue = getScalarFloatLikeValue(y);
-
-        return getAttr(
-            resultType,
-            xValue - std::floor(xValue / yValue) * yValue);
-      }
-
-      if (isScalarIntegerLike(x) && isScalarFloatLike(y)) {
-        auto xValue = static_cast<double>(getScalarIntegerLikeValue(x));
-        double yValue = getScalarFloatLikeValue(y);
-
-        return getAttr(
-            resultType,
-            xValue - std::floor(xValue / yValue) * yValue);
-      }
-
-      if (isScalarFloatLike(x) && isScalarIntegerLike(y)) {
-        double xValue = getScalarFloatLikeValue(x);
-        auto yValue = static_cast<double>(getScalarIntegerLikeValue(y));
-
-        return getAttr(
-            resultType,
-            xValue - std::floor(xValue / yValue) * yValue);
-      }
-    }
-
-    return {};
-  }
-}
+namespace mlir::bmodelica {
+mlir::OpFoldResult ModOp::fold(FoldAdaptor adaptor) {
+  auto x = adaptor.getX();
+  auto y = adaptor.getY();
+
+  if (!x || !y) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(x) && isScalar(y)) {
+    if (isScalarIntegerLike(x) && isScalarIntegerLike(y)) {
+      auto xValue = static_cast<double>(getScalarIntegerLikeValue(x));
+      auto yValue = static_cast<double>(getScalarIntegerLikeValue(y));
+
+      return getAttr(resultType, xValue - std::floor(xValue / yValue) * yValue);
+    }
+
+    if (isScalarFloatLike(x) && isScalarFloatLike(y)) {
+      double xValue = getScalarFloatLikeValue(x);
+      double yValue = getScalarFloatLikeValue(y);
+
+      return getAttr(resultType, xValue - std::floor(xValue / yValue) * yValue);
+    }
+
+    if (isScalarIntegerLike(x) && isScalarFloatLike(y)) {
+      auto xValue = static_cast<double>(getScalarIntegerLikeValue(x));
+      double yValue = getScalarFloatLikeValue(y);
+
+      return getAttr(resultType, xValue - std::floor(xValue / yValue) * yValue);
+    }
+
+    if (isScalarFloatLike(x) && isScalarIntegerLike(y)) {
+      double xValue = getScalarFloatLikeValue(x);
+      auto yValue = static_cast<double>(getScalarIntegerLikeValue(y));
+
+      return getAttr(resultType, xValue - std::floor(xValue / yValue) * yValue);
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// RemOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult RemOp::fold(FoldAdaptor adaptor)
-  {
-    auto x = adaptor.getX();
-    auto y = adaptor.getY();
-
-    if (!x || !y) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(x) && isScalar(y)) {
-      if (isScalarIntegerLike(x) && isScalarIntegerLike(y)) {
-        int64_t xValue = getScalarIntegerLikeValue(x);
-        int64_t yValue = getScalarIntegerLikeValue(y);
-        return getAttr(resultType, xValue % yValue);
-      }
-
-      if (isScalarFloatLike(x) && isScalarFloatLike(y)) {
-        double xValue = getScalarFloatLikeValue(x);
-        double yValue = getScalarFloatLikeValue(y);
-        return getAttr(resultType, std::fmod(xValue, yValue));
-      }
-
-      if (isScalarIntegerLike(x) && isScalarFloatLike(y)) {
-        auto xValue = static_cast<double>(getScalarIntegerLikeValue(x));
-        double yValue = getScalarFloatLikeValue(y);
-        return getAttr(resultType, std::fmod(xValue, yValue));
-      }
-
-      if (isScalarFloatLike(x) && isScalarIntegerLike(y)) {
-        double xValue = getScalarFloatLikeValue(x);
-        auto yValue = static_cast<double>(getScalarIntegerLikeValue(y));
-        return getAttr(resultType, std::fmod(xValue, yValue));
-      }
-    }
-
-    return {};
-  }
-}
+namespace mlir::bmodelica {
+mlir::OpFoldResult RemOp::fold(FoldAdaptor adaptor) {
+  auto x = adaptor.getX();
+  auto y = adaptor.getY();
+
+  if (!x || !y) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(x) && isScalar(y)) {
+    if (isScalarIntegerLike(x) && isScalarIntegerLike(y)) {
+      int64_t xValue = getScalarIntegerLikeValue(x);
+      int64_t yValue = getScalarIntegerLikeValue(y);
+      return getAttr(resultType, xValue % yValue);
+    }
+
+    if (isScalarFloatLike(x) && isScalarFloatLike(y)) {
+      double xValue = getScalarFloatLikeValue(x);
+      double yValue = getScalarFloatLikeValue(y);
+      return getAttr(resultType, std::fmod(xValue, yValue));
+    }
+
+    if (isScalarIntegerLike(x) && isScalarFloatLike(y)) {
+      auto xValue = static_cast<double>(getScalarIntegerLikeValue(x));
+      double yValue = getScalarFloatLikeValue(y);
+      return getAttr(resultType, std::fmod(xValue, yValue));
+    }
+
+    if (isScalarFloatLike(x) && isScalarIntegerLike(y)) {
+      double xValue = getScalarFloatLikeValue(x);
+      auto yValue = static_cast<double>(getScalarIntegerLikeValue(y));
+      return getAttr(resultType, std::fmod(xValue, yValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
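Note on the two folds above: `mod` implements a floored remainder (`x - floor(x/y) * y`, so the result takes the sign of `y`), while `rem` is a truncated remainder (the result takes the sign of `x`). They only differ when the operands have opposite signs. A standalone sketch of the difference, outside this patch (helper names are illustrative):

    #include <cmath>
    #include <cstdint>

    // Floored modulo, as computed by ModOp::fold: floorMod(5, -3) == -1.
    double floorMod(double x, double y) { return x - std::floor(x / y) * y; }

    // Truncated remainder, as computed by RemOp::fold: truncRem(5, -3) == 2.
    int64_t truncRem(int64_t x, int64_t y) { return x % y; }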
//===---------------------------------------------------------------------===//
// SignOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult SignOp::fold(FoldAdaptor adaptor)
-  {
-    auto operand = adaptor.getOperand();
-
-    if (!operand) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(operand)) {
-      if (isScalarIntegerLike(operand)) {
-        int64_t value = getScalarIntegerLikeValue(operand);
-
-        if (value == 0) {
-          return getAttr(resultType, static_cast<int64_t>(0));
-        } else if (value > 0) {
-          return getAttr(resultType, static_cast<int64_t>(1));
-        } else {
-          return getAttr(resultType, static_cast<int64_t>(-1));
-        }
-      }
-
-      if (isScalarFloatLike(operand)) {
-        double value = getScalarFloatLikeValue(operand);
-
-        if (value == 0) {
-          return getAttr(resultType, static_cast<int64_t>(0));
-        } else if (value > 0) {
-          return getAttr(resultType, static_cast<int64_t>(1));
-        } else {
-          return getAttr(resultType, static_cast<int64_t>(-1));
-        }
-      }
-    }
-
-    return {};
-  }
-}
+namespace mlir::bmodelica {
+mlir::OpFoldResult SignOp::fold(FoldAdaptor adaptor) {
+  auto operand = adaptor.getOperand();
+
+  if (!operand) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(operand)) {
+    if (isScalarIntegerLike(operand)) {
+      int64_t value = getScalarIntegerLikeValue(operand);
+
+      if (value == 0) {
+        return getAttr(resultType, static_cast<int64_t>(0));
+      } else if (value > 0) {
+        return getAttr(resultType, static_cast<int64_t>(1));
+      } else {
+        return getAttr(resultType, static_cast<int64_t>(-1));
+      }
+    }
+
+    if (isScalarFloatLike(operand)) {
+      double value = getScalarFloatLikeValue(operand);
+
+      if (value == 0) {
+        return getAttr(resultType, static_cast<int64_t>(0));
+      } else if (value > 0) {
+        return getAttr(resultType, static_cast<int64_t>(1));
+      } else {
+        return getAttr(resultType, static_cast<int64_t>(-1));
+      }
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// SinOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult SinOp::fold(FoldAdaptor adaptor)
-  {
-    auto operand = adaptor.getOperand();
-
-    if (!operand) {
-      return {};
-    }
+namespace
mlir::bmodelica { +mlir::OpFoldResult SinOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + auto resultType = getResult().getType(); - return getAttr(resultType, std::sin(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::sin(operandValue)); - } + return getAttr(resultType, std::sin(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::sin(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SinhOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult SinhOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); - - if (!operand) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult SinhOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + auto resultType = getResult().getType(); - return getAttr(resultType, std::sinh(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::sinh(operandValue)); - } + return getAttr(resultType, std::sinh(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, std::sinh(operandValue)); + } } + + return {}; } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SizeOp -namespace mlir::bmodelica -{ - mlir::ParseResult SizeOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - mlir::OpAsmParser::UnresolvedOperand array; - mlir::Type tensorType; - - mlir::OpAsmParser::UnresolvedOperand dimension; - mlir::Type dimensionType; - - size_t numOperands = 1; +namespace mlir::bmodelica { +mlir::ParseResult SizeOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::OpAsmParser::UnresolvedOperand array; + mlir::Type tensorType; - if (parser.parseOperand(array)) { - return mlir::failure(); - } + mlir::OpAsmParser::UnresolvedOperand dimension; + mlir::Type dimensionType; - if (mlir::succeeded(parser.parseOptionalComma())) { - numOperands = 2; + size_t numOperands = 1; - if (parser.parseOperand(dimension)) { - return mlir::failure(); - } - } + if (parser.parseOperand(array)) { + return mlir::failure(); + } - if (parser.parseOptionalAttrDict(result.attributes)) { - return mlir::failure(); - } + if (mlir::succeeded(parser.parseOptionalComma())) { + numOperands = 2; - if (parser.parseColon()) { + if (parser.parseOperand(dimension)) { return mlir::failure(); } + } - if (numOperands == 1) { - if (parser.parseType(tensorType) || - 
parser.resolveOperand(array, tensorType, result.operands)) { - return mlir::failure(); - } - } else { - if (parser.parseLParen() || - parser.parseType(tensorType) || - parser.resolveOperand(array, tensorType, result.operands) || - parser.parseComma() || - parser.parseType(dimensionType) || - parser.resolveOperand(dimension, dimensionType, result.operands) || - parser.parseRParen()) { - return mlir::failure(); - } - } + if (parser.parseOptionalAttrDict(result.attributes)) { + return mlir::failure(); + } - mlir::Type resultType; + if (parser.parseColon()) { + return mlir::failure(); + } - if (parser.parseArrow() || - parser.parseType(resultType)) { + if (numOperands == 1) { + if (parser.parseType(tensorType) || + parser.resolveOperand(array, tensorType, result.operands)) { + return mlir::failure(); + } + } else { + if (parser.parseLParen() || parser.parseType(tensorType) || + parser.resolveOperand(array, tensorType, result.operands) || + parser.parseComma() || parser.parseType(dimensionType) || + parser.resolveOperand(dimension, dimensionType, result.operands) || + parser.parseRParen()) { return mlir::failure(); } + } - result.addTypes(resultType); + mlir::Type resultType; - return mlir::success(); + if (parser.parseArrow() || parser.parseType(resultType)) { + return mlir::failure(); } - void SizeOp::print(mlir::OpAsmPrinter& printer) - { - printer << " " << getArray(); + result.addTypes(resultType); - if (getOperation()->getNumOperands() == 2) { - printer << ", " << getDimension(); - } + return mlir::success(); +} - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " : "; +void SizeOp::print(mlir::OpAsmPrinter &printer) { + printer << " " << getArray(); - if (getOperation()->getNumOperands() == 1) { - printer << getArray().getType(); - } else { - printer << "(" << getArray().getType() << ", " - << getDimension().getType() << ")"; - } + if (getOperation()->getNumOperands() == 2) { + printer << ", " << getDimension(); + } - printer << " -> " << getResult().getType(); + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " : "; + + if (getOperation()->getNumOperands() == 1) { + printer << getArray().getType(); + } else { + printer << "(" << getArray().getType() << ", " << getDimension().getType() + << ")"; } + + printer << " -> " << getResult().getType(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SqrtOp -namespace mlir::bmodelica -{ - mlir::OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) - { - auto operand = adaptor.getOperand(); - - if (!operand) { - return {}; - } +namespace mlir::bmodelica { +mlir::OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) { + auto operand = adaptor.getOperand(); - auto resultType = getResult().getType(); + if (!operand) { + return {}; + } - if (isScalar(operand)) { - if (isScalarIntegerLike(operand)) { - auto operandValue = - static_cast(getScalarIntegerLikeValue(operand)); + auto resultType = getResult().getType(); - return getAttr(resultType, std::sqrt(operandValue)); - } + if (isScalar(operand)) { + if (isScalarIntegerLike(operand)) { + auto operandValue = + static_cast(getScalarIntegerLikeValue(operand)); - if (isScalarFloatLike(operand)) { - double operandValue = getScalarFloatLikeValue(operand); - return getAttr(resultType, std::sqrt(operandValue)); - } + return getAttr(resultType, std::sqrt(operandValue)); } - return {}; + if (isScalarFloatLike(operand)) { + double operandValue = getScalarFloatLikeValue(operand); + return getAttr(resultType, 
std::sqrt(operandValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// TanOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult TanOp::fold(FoldAdaptor adaptor)
-  {
-    auto operand = adaptor.getOperand();
-
-    if (!operand) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(operand)) {
-      if (isScalarIntegerLike(operand)) {
-        auto operandValue =
-            static_cast<double>(getScalarIntegerLikeValue(operand));
-
-        return getAttr(resultType, std::tan(operandValue));
-      }
-
-      if (isScalarFloatLike(operand)) {
-        double operandValue = getScalarFloatLikeValue(operand);
-        return getAttr(resultType, std::tan(operandValue));
-      }
-    }
-
-    return {};
-  }
-}
+namespace mlir::bmodelica {
+mlir::OpFoldResult TanOp::fold(FoldAdaptor adaptor) {
+  auto operand = adaptor.getOperand();
+
+  if (!operand) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(operand)) {
+    if (isScalarIntegerLike(operand)) {
+      auto operandValue =
+          static_cast<double>(getScalarIntegerLikeValue(operand));
+
+      return getAttr(resultType, std::tan(operandValue));
+    }
+
+    if (isScalarFloatLike(operand)) {
+      double operandValue = getScalarFloatLikeValue(operand);
+      return getAttr(resultType, std::tan(operandValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// TanhOp

-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult TanhOp::fold(FoldAdaptor adaptor)
-  {
-    auto operand = adaptor.getOperand();
-
-    if (!operand) {
-      return {};
-    }
-
-    auto resultType = getResult().getType();
-
-    if (isScalar(operand)) {
-      if (isScalarIntegerLike(operand)) {
-        auto operandValue =
-            static_cast<double>(getScalarIntegerLikeValue(operand));
-
-        return getAttr(resultType, std::tanh(operandValue));
-      }
-
-      if (isScalarFloatLike(operand)) {
-        double operandValue = getScalarFloatLikeValue(operand);
-        return getAttr(resultType, std::tanh(operandValue));
-      }
-    }
-
-    return {};
-  }
-}
+namespace mlir::bmodelica {
+mlir::OpFoldResult TanhOp::fold(FoldAdaptor adaptor) {
+  auto operand = adaptor.getOperand();
+
+  if (!operand) {
+    return {};
+  }
+
+  auto resultType = getResult().getType();
+
+  if (isScalar(operand)) {
+    if (isScalarIntegerLike(operand)) {
+      auto operandValue =
+          static_cast<double>(getScalarIntegerLikeValue(operand));
+
+      return getAttr(resultType, std::tanh(operandValue));
+    }
+
+    if (isScalarFloatLike(operand)) {
+      double operandValue = getScalarFloatLikeValue(operand);
+      return getAttr(resultType, std::tanh(operandValue));
+    }
+  }
+
+  return {};
+}
+} // namespace mlir::bmodelica
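All of these unary folds (sin, sinh, sqrt, tan, tanh, ...) share the same shape: an integer-like constant is widened to `double`, the `<cmath>` function is applied, and `getAttr` converts the result back to the op's declared result type. A condensed sketch of that shared pattern, reusing the helper names already used in this file; the function-pointer parameter and the `foldUnary` name are illustrative, not part of this patch:

    // Sketch: common structure of the unary floating-point folds above.
    mlir::OpFoldResult foldUnary(mlir::Attribute operand, mlir::Type resultType,
                                 double (*fn)(double)) {
      if (!operand) {
        return {}; // Null OpFoldResult: the op could not be folded.
      }

      if (isScalarIntegerLike(operand)) {
        auto value = static_cast<double>(getScalarIntegerLikeValue(operand));
        return getAttr(resultType, fn(value));
      }

      if (isScalarFloatLike(operand)) {
        return getAttr(resultType, fn(getScalarFloatLikeValue(operand)));
      }

      return {};
    }

    // Usage, e.g.: foldUnary(operand, resultType, [](double v) { return std::sin(v); });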
//===---------------------------------------------------------------------===//
// TransposeOp

-namespace mlir::bmodelica
-{
-  mlir::LogicalResult TransposeOp::inferReturnTypes(
-      mlir::MLIRContext* context,
-      std::optional<mlir::Location> location,
-      mlir::ValueRange operands,
-      mlir::DictionaryAttr attributes,
-      mlir::OpaqueProperties properties,
-      mlir::RegionRange regions,
-      llvm::SmallVectorImpl<mlir::Type>& returnTypes)
-  {
-    Adaptor adaptor(operands, attributes, properties, regions);
-
-    auto matrixShapedType =
-        adaptor.getMatrix().getType().dyn_cast<mlir::ShapedType>();
-
-    if (!matrixShapedType || matrixShapedType.getRank() != 2) {
-      return mlir::failure();
-    }
-
-    llvm::SmallVector<int64_t, 2> shape;
-    shape.push_back(matrixShapedType.getDimSize(1));
-    shape.push_back(matrixShapedType.getDimSize(0));
-
-    returnTypes.push_back(
-        mlir::cast<mlir::Type>(matrixShapedType.clone(shape)));
-
-    return mlir::success();
-  }
-
-  bool TransposeOp::isCompatibleReturnTypes(
-      mlir::TypeRange lhs, mlir::TypeRange rhs)
-  {
-    if (lhs.size() != rhs.size()) {
-      return false;
-    }
-
-    for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
-      if (!areTypesCompatible(lhsType, rhsType)) {
-        return false;
-      }
-    }
-
-    return true;
-  }
-}
+namespace mlir::bmodelica {
+mlir::LogicalResult TransposeOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    mlir::OpaqueProperties properties, mlir::RegionRange regions,
+    llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+
+  auto matrixShapedType =
+      adaptor.getMatrix().getType().dyn_cast<mlir::ShapedType>();
+
+  if (!matrixShapedType || matrixShapedType.getRank() != 2) {
+    return mlir::failure();
+  }
+
+  llvm::SmallVector<int64_t, 2> shape;
+  shape.push_back(matrixShapedType.getDimSize(1));
+  shape.push_back(matrixShapedType.getDimSize(0));
+
+  returnTypes.push_back(mlir::cast<mlir::Type>(matrixShapedType.clone(shape)));
+
+  return mlir::success();
+}
+
+bool TransposeOp::isCompatibleReturnTypes(mlir::TypeRange lhs,
+                                          mlir::TypeRange rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+
+  for (auto [lhsType, rhsType] : llvm::zip(lhs, rhs)) {
+    if (!areTypesCompatible(lhsType, rhsType)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// ReductionOp

-namespace mlir::bmodelica
-{
-  mlir::ParseResult ReductionOp::parse(
-      mlir::OpAsmParser& parser, mlir::OperationState& result)
-  {
-    auto loc = parser.getCurrentLocation();
-    std::string action;
-
-    if (parser.parseString(&action) ||
-        parser.parseComma()) {
-      return mlir::failure();
-    }
-
-    result.addAttribute(
-        getActionAttrName(result.name),
-        parser.getBuilder().getStringAttr(action));
-
-    llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterables;
-    llvm::SmallVector<mlir::Type> iterablesTypes;
-
-    llvm::SmallVector<mlir::OpAsmParser::Argument> inductions;
-
-    mlir::Region* expressionRegion = result.addRegion();
-    mlir::Type resultType;
-
-    if (parser.parseKeyword("iterables") ||
-        parser.parseEqual() ||
-        parser.parseOperandList(
-            iterables, mlir::AsmParser::Delimiter::Square) ||
-        parser.parseComma() ||
-        parser.parseKeyword("inductions") ||
-        parser.parseEqual() ||
-        parser.parseArgumentList(
-            inductions, mlir::AsmParser::Delimiter::Square, true) ||
-        parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
-        parser.parseRegion(*expressionRegion, inductions) ||
-        parser.parseColon() ||
-        parser.parseLParen() ||
-        parser.parseTypeList(iterablesTypes) ||
-        parser.parseRParen() ||
-        parser.resolveOperands(
-            iterables, iterablesTypes, loc, result.operands) ||
-        parser.parseArrow() ||
-        parser.parseType(resultType)) {
-      return mlir::failure();
-    }
-
-    result.addTypes(resultType);
-
-    return mlir::success();
+namespace mlir::bmodelica {
+mlir::ParseResult ReductionOp::parse(mlir::OpAsmParser &parser,
+                                     mlir::OperationState &result) {
+  auto loc = parser.getCurrentLocation();
+  std::string action;
+
+  if (parser.parseString(&action) || parser.parseComma()) {
+    return mlir::failure();
+  }
+
+  result.addAttribute(getActionAttrName(result.name),
+                      parser.getBuilder().getStringAttr(action));
+
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterables;
+  llvm::SmallVector<mlir::Type> iterablesTypes;
+
+  llvm::SmallVector<mlir::OpAsmParser::Argument> inductions;
+
+  mlir::Region *expressionRegion = result.addRegion();
+  mlir::Type resultType;
+
+  if (parser.parseKeyword("iterables") || parser.parseEqual() ||
+      parser.parseOperandList(iterables, mlir::AsmParser::Delimiter::Square) ||
+      parser.parseComma() || parser.parseKeyword("inductions") ||
+      parser.parseEqual() ||
+      parser.parseArgumentList(inductions,
mlir::AsmParser::Delimiter::Square, + true) || + parser.parseOptionalAttrDictWithKeyword(result.attributes) || + parser.parseRegion(*expressionRegion, inductions) || + parser.parseColon() || parser.parseLParen() || + parser.parseTypeList(iterablesTypes) || parser.parseRParen() || + parser.resolveOperands(iterables, iterablesTypes, loc, result.operands) || + parser.parseArrow() || parser.parseType(resultType)) { + return mlir::failure(); } - void ReductionOp::print(mlir::OpAsmPrinter& printer) - { - printer << " \"" << getAction() - << "\", iterables = [" << getIterables() - << "], inductions = ["; + result.addTypes(resultType); - for (size_t i = 0, e = getInductions().size(); i < e; ++i) { - if (i != 0) { - printer << ", "; - } + return mlir::success(); +} + +void ReductionOp::print(mlir::OpAsmPrinter &printer) { + printer << " \"" << getAction() << "\", iterables = [" << getIterables() + << "], inductions = ["; - printer.printRegionArgument(getInductions()[i]); + for (size_t i = 0, e = getInductions().size(); i < e; ++i) { + if (i != 0) { + printer << ", "; } - printer << "] "; + printer.printRegionArgument(getInductions()[i]); + } - llvm::SmallVector elidedAttrs; - elidedAttrs.push_back(getActionAttrName().getValue()); + printer << "] "; - printer.printOptionalAttrDictWithKeyword( - getOperation()->getAttrs(), elidedAttrs); + llvm::SmallVector elidedAttrs; + elidedAttrs.push_back(getActionAttrName().getValue()); - printer.printRegion(getExpressionRegion(), false); - printer << " : "; + printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(), + elidedAttrs); - auto iterables = getIterables(); - printer << "("; + printer.printRegion(getExpressionRegion(), false); + printer << " : "; - for (size_t i = 0, e = iterables.size(); i < e; ++i) { - if (i != 0) { - printer << ", "; - } + auto iterables = getIterables(); + printer << "("; - printer << iterables[i].getType(); + for (size_t i = 0, e = iterables.size(); i < e; ++i) { + if (i != 0) { + printer << ", "; } - printer << ") -> "; - - printer << getResult().getType(); + printer << iterables[i].getType(); } - mlir::Block* ReductionOp::createExpressionBlock(mlir::OpBuilder& builder) - { - mlir::OpBuilder::InsertionGuard guard(builder); + printer << ") -> "; - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; + printer << getResult().getType(); +} - for (mlir::Value iterable : getIterables()) { - auto iterableType = iterable.getType().cast(); - argTypes.push_back(iterableType.getInductionType()); - argLocs.push_back(builder.getUnknownLoc()); - } +mlir::Block *ReductionOp::createExpressionBlock(mlir::OpBuilder &builder) { + mlir::OpBuilder::InsertionGuard guard(builder); - return builder.createBlock(&getExpressionRegion(), {}, argTypes, argLocs); + llvm::SmallVector argTypes; + llvm::SmallVector argLocs; + + for (mlir::Value iterable : getIterables()) { + auto iterableType = iterable.getType().cast(); + argTypes.push_back(iterableType.getInductionType()); + argLocs.push_back(builder.getUnknownLoc()); } + + return builder.createBlock(&getExpressionRegion(), {}, argTypes, argLocs); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // Modeling operations @@ -8010,4314 +7373,3905 @@ namespace mlir::bmodelica //===---------------------------------------------------------------------===// // PackageOp -namespace mlir::bmodelica -{ - void PackageOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - llvm::StringRef name) - { - 
state.addRegion()->emplaceBlock(); - - state.attributes.push_back(builder.getNamedAttr( - mlir::SymbolTable::getSymbolAttrName(), - builder.getStringAttr(name))); - } - - mlir::ParseResult PackageOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - mlir::StringAttr nameAttr; - - if (parser.parseSymbolName( - nameAttr, - mlir::SymbolTable::getSymbolAttrName(), - result.attributes) || - parser.parseOptionalAttrDictWithKeyword(result.attributes)) { - return mlir::failure(); - } +namespace mlir::bmodelica { +void PackageOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name) { + state.addRegion()->emplaceBlock(); - mlir::Region* bodyRegion = result.addRegion(); + state.attributes.push_back(builder.getNamedAttr( + mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); +} - if (parser.parseRegion(*bodyRegion)) { - return mlir::failure(); - } +mlir::ParseResult PackageOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::StringAttr nameAttr; - if (bodyRegion->empty()) { - bodyRegion->emplaceBlock(); - } + if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), + result.attributes) || + parser.parseOptionalAttrDictWithKeyword(result.attributes)) { + return mlir::failure(); + } - return mlir::success(); + mlir::Region *bodyRegion = result.addRegion(); + + if (parser.parseRegion(*bodyRegion)) { + return mlir::failure(); } - void PackageOp::print(mlir::OpAsmPrinter& printer) - { - printer << " "; - printer.printSymbolName(getSymName()); - printer << " "; + if (bodyRegion->empty()) { + bodyRegion->emplaceBlock(); + } - llvm::SmallVector elidedAttrs; - elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName()); + return mlir::success(); +} - printer.printOptionalAttrDictWithKeyword( - getOperation()->getAttrs(), elidedAttrs); +void PackageOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getSymName()); + printer << " "; - printer.printRegion(getBodyRegion()); - } + llvm::SmallVector elidedAttrs; + elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName()); - mlir::Block* PackageOp::bodyBlock() - { - assert(getBodyRegion().hasOneBlock()); - return &getBodyRegion().front(); - } + printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(), + elidedAttrs); + + printer.printRegion(getBodyRegion()); +} + +mlir::Block *PackageOp::bodyBlock() { + assert(getBodyRegion().hasOneBlock()); + return &getBodyRegion().front(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ModelOp -namespace -{ - struct InitialModelMergePattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - ModelOp op, mlir::PatternRewriter& rewriter) const override - { - llvm::SmallVector initialOps; +namespace { +struct InitialModelMergePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - for (InitialOp initialOp : op.getOps()) { - initialOps.push_back(initialOp); - } + mlir::LogicalResult + matchAndRewrite(ModelOp op, mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector initialOps; - if (initialOps.size() <= 1) { - return mlir::failure(); - } + for (InitialOp initialOp : op.getOps()) { + initialOps.push_back(initialOp); + } - for (size_t i = 1, e = initialOps.size(); i < e; ++i) { - rewriter.mergeBlocks(initialOps[i].getBody(), - initialOps[0].getBody()); + if 
(initialOps.size() <= 1) {
+      return mlir::failure();
+    }

-      for (size_t i = 1, e = initialOps.size(); i < e; ++i) {
-        rewriter.mergeBlocks(initialOps[i].getBody(),
-                             initialOps[0].getBody());
-
-        rewriter.eraseOp(initialOps[i]);
-      }
-
-      return mlir::success();
-    }
-  };
+    for (size_t i = 1, e = initialOps.size(); i < e; ++i) {
+      rewriter.mergeBlocks(initialOps[i].getBody(), initialOps[0].getBody());
+
+      rewriter.eraseOp(initialOps[i]);
+    }
+
+    return mlir::success();
+  }
+};

-  struct MainModelMergePattern
-      : public mlir::OpRewritePattern<ModelOp>
-  {
-    using mlir::OpRewritePattern<ModelOp>::OpRewritePattern;
-
-    mlir::LogicalResult matchAndRewrite(
-        ModelOp op, mlir::PatternRewriter& rewriter) const override
-    {
-      llvm::SmallVector<DynamicOp> dynamicOps;
-
-      for (DynamicOp dynamicOp : op.getOps<DynamicOp>()) {
-        dynamicOps.push_back(dynamicOp);
-      }
-
-      if (dynamicOps.size() <= 1) {
-        return mlir::failure();
-      }
-
-      for (size_t i = 1, e = dynamicOps.size(); i < e; ++i) {
-        rewriter.mergeBlocks(dynamicOps[i].getBody(),
-                             dynamicOps[0].getBody());
-
-        rewriter.eraseOp(dynamicOps[i]);
-      }
-
-      return mlir::success();
-    }
-  };
-}
+struct MainModelMergePattern : public mlir::OpRewritePattern<ModelOp> {
+  using mlir::OpRewritePattern<ModelOp>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(ModelOp op, mlir::PatternRewriter &rewriter) const override {
+    llvm::SmallVector<DynamicOp> dynamicOps;
+
+    for (DynamicOp dynamicOp : op.getOps<DynamicOp>()) {
+      dynamicOps.push_back(dynamicOp);
+    }
+
+    if (dynamicOps.size() <= 1) {
+      return mlir::failure();
+    }
+
+    for (size_t i = 1, e = dynamicOps.size(); i < e; ++i) {
+      rewriter.mergeBlocks(dynamicOps[i].getBody(), dynamicOps[0].getBody());
+
+      rewriter.eraseOp(dynamicOps[i]);
+    }
+
+    return mlir::success();
+  }
+};
+} // namespace
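These two patterns collapse a model's multiple `initial` and `dynamic` blocks into a single block each. They are exposed through `ModelOp::getCanonicalizationPatterns` below, so a plausible way to run them standalone is the usual greedy rewrite; this driver is illustrative and not part of the patch:

    // Hypothetical driver applying the ModelOp canonicalization patterns.
    mlir::RewritePatternSet patterns(context);
    ModelOp::getCanonicalizationPatterns(patterns, context);

    if (mlir::failed(
            mlir::applyPatternsAndFoldGreedily(modelOp, std::move(patterns)))) {
      // Handle the failure, e.g. signalPassFailure() when inside a pass.
    }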
-namespace mlir::bmodelica
-{
-  void ModelOp::getCanonicalizationPatterns(
-      mlir::RewritePatternSet& patterns, mlir::MLIRContext* context)
-  {
-    patterns.add<InitialModelMergePattern, MainModelMergePattern>(context);
-  }
-
-  mlir::RegionKind ModelOp::getRegionKind(unsigned index)
-  {
-    return mlir::RegionKind::Graph;
-  }
-
-  void ModelOp::getCleaningPatterns(
-      mlir::RewritePatternSet& patterns,
-      mlir::MLIRContext* context)
-  {
-    getCanonicalizationPatterns(patterns, context);
-    EquationTemplateOp::getCanonicalizationPatterns(patterns, context);
-    InitialOp::getCanonicalizationPatterns(patterns, context);
-    DynamicOp::getCanonicalizationPatterns(patterns, context);
-    SCCOp::getCanonicalizationPatterns(patterns, context);
-  }
-
-  void ModelOp::collectVariables(llvm::SmallVectorImpl<VariableOp>& variables)
-  {
-    for (VariableOp variableOp : getVariables()) {
-      variables.push_back(variableOp);
-    }
-  }
-
-  void ModelOp::collectInitialAlgorithms(
-      llvm::SmallVectorImpl<AlgorithmOp>& algorithms)
-  {
-    for (InitialOp initialOp : getOps<InitialOp>()) {
-      initialOp.collectAlgorithms(algorithms);
-    }
-  }
-
-  void ModelOp::collectMainAlgorithms(
-      llvm::SmallVectorImpl<AlgorithmOp>& algorithms)
-  {
-    for (DynamicOp dynamicOp : getOps<DynamicOp>()) {
-      dynamicOp.collectAlgorithms(algorithms);
-    }
-  }
-
-  void ModelOp::collectInitialSCCs(llvm::SmallVectorImpl<SCCOp>& SCCs)
-  {
-    for (InitialOp initialOp : getOps<InitialOp>()) {
-      initialOp.collectSCCs(SCCs);
-    }
-  }
-
-  void ModelOp::collectMainSCCs(llvm::SmallVectorImpl<SCCOp>& SCCs)
-  {
-    for (DynamicOp dynamicOp : getOps<DynamicOp>()) {
-      dynamicOp.collectSCCs(SCCs);
-    }
-  }
-
-  void ModelOp::collectSCCGroups(
-      llvm::SmallVectorImpl<SCCGroupOp>& initialSCCGroups,
-      llvm::SmallVectorImpl<SCCGroupOp>& SCCGroups)
-  {
-    for (SCCGroupOp op : getOps<SCCGroupOp>()) {
-      SCCGroups.push_back(op);
-    }
-  }
-}
+namespace mlir::bmodelica {
+void ModelOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
+                                          mlir::MLIRContext *context) {
+  patterns.add<InitialModelMergePattern, MainModelMergePattern>(context);
+}
+
+mlir::RegionKind ModelOp::getRegionKind(unsigned index) {
+  return mlir::RegionKind::Graph;
+}
+
+void ModelOp::getCleaningPatterns(mlir::RewritePatternSet &patterns,
+                                  mlir::MLIRContext *context) {
+  getCanonicalizationPatterns(patterns, context);
+  EquationTemplateOp::getCanonicalizationPatterns(patterns, context);
+  InitialOp::getCanonicalizationPatterns(patterns, context);
+  DynamicOp::getCanonicalizationPatterns(patterns, context);
+  SCCOp::getCanonicalizationPatterns(patterns, context);
+}
+
+void ModelOp::collectVariables(llvm::SmallVectorImpl<VariableOp> &variables) {
+  for (VariableOp variableOp : getVariables()) {
+    variables.push_back(variableOp);
+  }
+}
+
+void ModelOp::collectInitialAlgorithms(
+    llvm::SmallVectorImpl<AlgorithmOp> &algorithms) {
+  for (InitialOp initialOp : getOps<InitialOp>()) {
+    initialOp.collectAlgorithms(algorithms);
+  }
+}
+
+void ModelOp::collectMainAlgorithms(
+    llvm::SmallVectorImpl<AlgorithmOp> &algorithms) {
+  for (DynamicOp dynamicOp : getOps<DynamicOp>()) {
+    dynamicOp.collectAlgorithms(algorithms);
+  }
+}
+
+void ModelOp::collectInitialSCCs(llvm::SmallVectorImpl<SCCOp> &SCCs) {
+  for (InitialOp initialOp : getOps<InitialOp>()) {
+    initialOp.collectSCCs(SCCs);
+  }
+}
+
+void ModelOp::collectMainSCCs(llvm::SmallVectorImpl<SCCOp> &SCCs) {
+  for (DynamicOp dynamicOp : getOps<DynamicOp>()) {
+    dynamicOp.collectSCCs(SCCs);
+  }
+}
+
+void ModelOp::collectSCCGroups(
+    llvm::SmallVectorImpl<SCCGroupOp> &initialSCCGroups,
+    llvm::SmallVectorImpl<SCCGroupOp> &SCCGroups) {
+  for (SCCGroupOp op : getOps<SCCGroupOp>()) {
+    SCCGroups.push_back(op);
+  }
+}
+} // namespace mlir::bmodelica
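The `collect*` helpers let passes gather a model's components without walking its region manually. A minimal usage sketch (the `modelOp` variable and the surrounding setup are assumed, not part of this patch):

    // Gather the variables and the SCCs of the dynamic part of a model.
    llvm::SmallVector<VariableOp> variables;
    modelOp.collectVariables(variables);

    llvm::SmallVector<SCCOp> mainSCCs;
    modelOp.collectMainSCCs(mainSCCs);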
//===---------------------------------------------------------------------===//
// OperatorRecordOp

-namespace mlir::bmodelica
-{
-  void OperatorRecordOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      llvm::StringRef name)
-  {
-    state.addRegion()->emplaceBlock();
-
-    state.attributes.push_back(builder.getNamedAttr(
-        mlir::SymbolTable::getSymbolAttrName(),
-        builder.getStringAttr(name)));
-  }
-
-  mlir::ParseResult OperatorRecordOp::parse(
-      mlir::OpAsmParser& parser, mlir::OperationState& result)
-  {
-    mlir::StringAttr nameAttr;
-
-    if (parser.parseSymbolName(
-            nameAttr,
-            mlir::SymbolTable::getSymbolAttrName(),
-            result.attributes)) {
-      return mlir::failure();
-    }
-
-    mlir::Region* bodyRegion = result.addRegion();
-
-    if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
-        parser.parseRegion(*bodyRegion)) {
-      return mlir::failure();
-    }
-
-    if (bodyRegion->empty()) {
-      bodyRegion->emplaceBlock();
-    }
-
-    return mlir::success();
-  }
-
-  void OperatorRecordOp::print(mlir::OpAsmPrinter& printer)
-  {
-    printer << " ";
-    printer.printSymbolName(getSymName());
-    printer << " ";
-
-    llvm::SmallVector<llvm::StringRef> elidedAttrs;
-    elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName());
-
-    printer.printOptionalAttrDictWithKeyword(
-        getOperation()->getAttrs(), elidedAttrs);
-
-    printer.printRegion(getBodyRegion());
-  }
-
-  mlir::Block* OperatorRecordOp::bodyBlock()
-  {
-    assert(getBodyRegion().hasOneBlock());
-    return &getBodyRegion().front();
-  }
-}
+namespace mlir::bmodelica {
+void OperatorRecordOp::build(mlir::OpBuilder &builder,
+                             mlir::OperationState &state,
+                             llvm::StringRef name) {
+  state.addRegion()->emplaceBlock();
+
+  state.attributes.push_back(builder.getNamedAttr(
+      mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
+}
+
+mlir::ParseResult OperatorRecordOp::parse(mlir::OpAsmParser &parser,
+                                          mlir::OperationState &result) {
+  mlir::StringAttr nameAttr;
+
+  if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
+                             result.attributes)) {
+    return mlir::failure();
+  }
+
+  mlir::Region *bodyRegion = result.addRegion();
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
+      parser.parseRegion(*bodyRegion)) {
+    return mlir::failure();
+  }
+
+  if (bodyRegion->empty()) {
+    bodyRegion->emplaceBlock();
+  }
+
+  return mlir::success();
+}
+
+void OperatorRecordOp::print(mlir::OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  printer << " ";
+
+  llvm::SmallVector<llvm::StringRef> elidedAttrs;
+  elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName());
+
+  printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(),
+                                           elidedAttrs);
+
+  printer.printRegion(getBodyRegion());
+}
+
+mlir::Block *OperatorRecordOp::bodyBlock() {
+  assert(getBodyRegion().hasOneBlock());
+  return &getBodyRegion().front();
+}
+} // namespace mlir::bmodelica
+
//===---------------------------------------------------------------------===//
// StartOp

-namespace mlir::bmodelica
-{
-  VariableOp StartOp::getVariableOp(mlir::SymbolTableCollection& symbolTable)
-  {
-    auto cls = getOperation()->getParentOfType<ClassInterface>();
-    return symbolTable.lookupSymbolIn<VariableOp>(cls, getVariableAttr());
-  }
-
-  mlir::LogicalResult StartOp::getAccesses(
-      llvm::SmallVectorImpl<VariableAccess>& result,
-      mlir::SymbolTableCollection& symbolTable)
-  {
-    auto yieldOp = mlir::cast<YieldOp>(getBody()->getTerminator());
-
-    llvm::DenseMap<mlir::Value, unsigned int> inductionsPositionMap;
-
-    if (mlir::failed(searchAccesses(
-            result, symbolTable, inductionsPositionMap,
-            yieldOp.getValues()[0],
-            EquationPath(EquationPath::RIGHT, 0)))) {
-      return mlir::failure();
-    }
-
-    return mlir::success();
-  }
-
-  mlir::LogicalResult StartOp::searchAccesses(
-      llvm::SmallVectorImpl<VariableAccess>& accesses,
-      mlir::SymbolTableCollection& symbolTable,
-      llvm::DenseMap<mlir::Value, unsigned int>& inductionsPositionMap,
-      mlir::Value value,
-      EquationPath path)
-  {
-    mlir::Operation* definingOp = value.getDefiningOp();
-
-    if (!definingOp) {
-      return mlir::success();
-    }
-
-    AdditionalInductions additionalInductions;
-    llvm::SmallVector<std::unique_ptr<DimensionAccess>, 10> dimensionAccesses;
-
-    if (auto expressionInt =
-            mlir::dyn_cast<EquationExpressionOpInterface>(definingOp)) {
-      return expressionInt.getEquationAccesses(
-          accesses, symbolTable, inductionsPositionMap,
-          additionalInductions, dimensionAccesses,
-          std::move(path));
-    }
-
-    return mlir::failure();
-  }
-}
+namespace mlir::bmodelica {
+VariableOp StartOp::getVariableOp(mlir::SymbolTableCollection &symbolTable) {
+  auto cls = getOperation()->getParentOfType<ClassInterface>();
+  return symbolTable.lookupSymbolIn<VariableOp>(cls, getVariableAttr());
+}
+
+mlir::LogicalResult
+StartOp::getAccesses(llvm::SmallVectorImpl<VariableAccess> &result,
+                     mlir::SymbolTableCollection &symbolTable) {
+  auto yieldOp = mlir::cast<YieldOp>(getBody()->getTerminator());
+
+  llvm::DenseMap<mlir::Value, unsigned int> inductionsPositionMap;
+
+  if (mlir::failed(searchAccesses(result, symbolTable, inductionsPositionMap,
+                                  yieldOp.getValues()[0],
+                                  EquationPath(EquationPath::RIGHT, 0)))) {
+    return mlir::failure();
+  }
+
+  return mlir::success();
+}
+
+mlir::LogicalResult StartOp::searchAccesses(
+    llvm::SmallVectorImpl<VariableAccess> &accesses,
+    mlir::SymbolTableCollection &symbolTable,
+    llvm::DenseMap<mlir::Value, unsigned int> &inductionsPositionMap,
+    mlir::Value value, EquationPath path) {
+  mlir::Operation *definingOp = value.getDefiningOp();
+
+  if (!definingOp) {
+    return mlir::success();
+  }
+
+  AdditionalInductions additionalInductions;
+  llvm::SmallVector<std::unique_ptr<DimensionAccess>, 10> dimensionAccesses;
+
+  if (auto expressionInt =
+          mlir::dyn_cast<EquationExpressionOpInterface>(definingOp)) {
+    return expressionInt.getEquationAccesses(
+        accesses, symbolTable, inductionsPositionMap, additionalInductions,
+        dimensionAccesses, std::move(path));
+  }
+
+  return mlir::failure();
+}
+} // namespace mlir::bmodelica

//===---------------------------------------------------------------------===//
// DefaultOp

-namespace mlir::bmodelica
-{
-  VariableOp DefaultOp::getVariableOp(mlir::SymbolTableCollection& symbolTable)
-  {
-    auto cls =
getOperation()->getParentOfType(); - return symbolTable.lookupSymbolIn(cls, getVariableAttr()); - } +namespace mlir::bmodelica { +VariableOp DefaultOp::getVariableOp(mlir::SymbolTableCollection &symbolTable) { + auto cls = getOperation()->getParentOfType(); + return symbolTable.lookupSymbolIn(cls, getVariableAttr()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // BindingEquationOp -namespace mlir::bmodelica -{ - VariableOp BindingEquationOp::getVariableOp( - mlir::SymbolTableCollection& symbolTable) - { - auto cls = getOperation()->getParentOfType(); - return symbolTable.lookupSymbolIn(cls, getVariableAttr()); - } +namespace mlir::bmodelica { +VariableOp +BindingEquationOp::getVariableOp(mlir::SymbolTableCollection &symbolTable) { + auto cls = getOperation()->getParentOfType(); + return symbolTable.lookupSymbolIn(cls, getVariableAttr()); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // ForEquationOp -namespace -{ - struct EmptyForEquationOpErasePattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; +namespace { +struct EmptyForEquationOpErasePattern + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - mlir::LogicalResult match(ForEquationOp op) const override - { - return mlir::LogicalResult::success(op.getOps().empty()); - } + mlir::LogicalResult match(ForEquationOp op) const override { + return mlir::LogicalResult::success(op.getOps().empty()); + } - void rewrite( - ForEquationOp op, mlir::PatternRewriter& rewriter) const override - { - rewriter.eraseOp(op); - } - }; + void rewrite(ForEquationOp op, + mlir::PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + } +}; +} // namespace + +namespace mlir::bmodelica { +void ForEquationOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + long from, long to, long step) { + mlir::OpBuilder::InsertionGuard guard(builder); + + state.addAttribute(getFromAttrName(state.name), builder.getIndexAttr(from)); + + state.addAttribute(getToAttrName(state.name), builder.getIndexAttr(to)); + + state.addAttribute(getStepAttrName(state.name), builder.getIndexAttr(step)); + + mlir::Region *bodyRegion = state.addRegion(); + + builder.createBlock(bodyRegion, {}, builder.getIndexType(), + builder.getUnknownLoc()); } -namespace mlir::bmodelica -{ - void ForEquationOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - long from, - long to, - long step) - { - mlir::OpBuilder::InsertionGuard guard(builder); +void ForEquationOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add(context); +} + +mlir::Block *ForEquationOp::bodyBlock() { + assert(getBodyRegion().getBlocks().size() == 1); + return &getBodyRegion().front(); +} - state.addAttribute( - getFromAttrName(state.name), builder.getIndexAttr(from)); +mlir::Value ForEquationOp::induction() { + assert(getBodyRegion().getNumArguments() != 0); + return getBodyRegion().getArgument(0); +} - state.addAttribute(getToAttrName(state.name), builder.getIndexAttr(to)); +mlir::ParseResult ForEquationOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); - state.addAttribute( - getStepAttrName(state.name), builder.getIndexAttr(step)); + mlir::OpAsmParser::Argument induction; - mlir::Region* bodyRegion = state.addRegion(); + int64_t from; + int64_t to; + int64_t 
step = 1; - builder.createBlock( - bodyRegion, {}, builder.getIndexType(), builder.getUnknownLoc()); + if (parser.parseArgument(induction) || parser.parseEqual() || + parser.parseInteger(from) || parser.parseKeyword("to") || + parser.parseInteger(to)) { + return mlir::failure(); } - void ForEquationOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); + if (mlir::succeeded(parser.parseOptionalKeyword("step"))) { + if (parser.parseInteger(step)) { + return mlir::failure(); + } } - mlir::Block* ForEquationOp::bodyBlock() - { - assert(getBodyRegion().getBlocks().size() == 1); - return &getBodyRegion().front(); + induction.type = builder.getIndexType(); + + result.attributes.append("from", builder.getIndexAttr(from)); + result.attributes.append("to", builder.getIndexAttr(to)); + result.attributes.append("step", builder.getIndexAttr(step)); + + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { + return mlir::failure(); } - mlir::Value ForEquationOp::induction() - { - assert(getBodyRegion().getNumArguments() != 0); - return getBodyRegion().getArgument(0); + mlir::Region *bodyRegion = result.addRegion(); + + if (parser.parseRegion(*bodyRegion, induction)) { + return mlir::failure(); } - mlir::ParseResult ForEquationOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto& builder = parser.getBuilder(); + return mlir::success(); +} - mlir::OpAsmParser::Argument induction; +void ForEquationOp::print(mlir::OpAsmPrinter &printer) { + printer << " " << induction() << " = " << getFrom() << " to " << getTo(); - int64_t from; - int64_t to; - int64_t step = 1; + if (auto step = getStep(); step != 1) { + printer << " step " << step; + } - if (parser.parseArgument(induction) || - parser.parseEqual() || - parser.parseInteger(from) || - parser.parseKeyword("to") || - parser.parseInteger(to)) { - return mlir::failure(); - } + printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(), + {"from", "to", "step"}); - if (mlir::succeeded(parser.parseOptionalKeyword("step"))) { - if (parser.parseInteger(step)) { - return mlir::failure(); - } - } + printer << " "; + printer.printRegion(getBodyRegion(), false); +} +} // namespace mlir::bmodelica - induction.type = builder.getIndexType(); +//===---------------------------------------------------------------------===// +// EquationTemplateOp - result.attributes.append("from", builder.getIndexAttr(from)); - result.attributes.append("to", builder.getIndexAttr(to)); - result.attributes.append("step", builder.getIndexAttr(step)); +namespace { +struct UnusedEquationTemplatePattern + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { - return mlir::failure(); + mlir::LogicalResult + matchAndRewrite(EquationTemplateOp op, + mlir::PatternRewriter &rewriter) const override { + if (op->use_empty()) { + rewriter.eraseOp(op); + return mlir::success(); } - mlir::Region* bodyRegion = result.addRegion(); + return mlir::failure(); + } +}; +} // namespace - if (parser.parseRegion(*bodyRegion, induction)) { - return mlir::failure(); - } +namespace mlir::bmodelica { +mlir::ParseResult EquationTemplateOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + llvm::SmallVector inductions; + mlir::Region *bodyRegion = result.addRegion(); - return mlir::success(); + if (parser.parseKeyword("inductions") || parser.parseEqual() || + 
parser.parseArgumentList(inductions, + mlir::OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDictWithKeyword(result.attributes)) { + return mlir::failure(); } - void ForEquationOp::print(mlir::OpAsmPrinter& printer) - { - printer << " " << induction() << " = " << getFrom() << " to " << getTo(); + for (auto &induction : inductions) { + induction.type = parser.getBuilder().getIndexType(); + } - if (auto step = getStep(); step != 1) { - printer << " step " << step; - } + if (parser.parseRegion(*bodyRegion, inductions)) { + return mlir::failure(); + } + + if (bodyRegion->empty()) { + mlir::OpBuilder builder(bodyRegion); + + llvm::SmallVector argTypes(inductions.size(), + builder.getIndexType()); - printer.printOptionalAttrDictWithKeyword( - getOperation()->getAttrs(), {"from", "to", "step"}); + llvm::SmallVector argLocs(inductions.size(), + builder.getUnknownLoc()); - printer << " "; - printer.printRegion(getBodyRegion(), false); + builder.createBlock(bodyRegion, {}, argTypes, argLocs); } + + result.addTypes(EquationType::get(parser.getContext())); + return mlir::success(); } -//===---------------------------------------------------------------------===// -// EquationTemplateOp +void EquationTemplateOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer << "inductions = ["; + printer.printOperands(getInductionVariables()); + printer << "]"; + printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); + printer << " "; + printer.printRegion(getBodyRegion(), false); +} -namespace -{ - struct UnusedEquationTemplatePattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - EquationTemplateOp op, mlir::PatternRewriter& rewriter) const override - { - if (op->use_empty()) { - rewriter.eraseOp(op); - return mlir::success(); - } +void EquationTemplateOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add(context); +} - return mlir::failure(); - } - }; +mlir::Block *EquationTemplateOp::createBody(unsigned int numOfInductions) { + mlir::OpBuilder builder(getContext()); + + llvm::SmallVector argTypes(numOfInductions, + builder.getIndexType()); + + llvm::SmallVector argLocs(numOfInductions, + builder.getUnknownLoc()); + + return builder.createBlock(&getBodyRegion(), {}, argTypes, argLocs); } -namespace mlir::bmodelica -{ - mlir::ParseResult EquationTemplateOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - llvm::SmallVector inductions; - mlir::Region* bodyRegion = result.addRegion(); - - if (parser.parseKeyword("inductions") || - parser.parseEqual() || - parser.parseArgumentList( - inductions, mlir::OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDictWithKeyword(result.attributes)) { - return mlir::failure(); +void EquationTemplateOp::printInline(llvm::raw_ostream &os) { + llvm::DenseMap inductions; + auto inductionVars = getInductionVariables(); + + for (size_t i = 0, e = inductionVars.size(); i < e; ++i) { + inductions[inductionVars[i]] = static_cast(i); + } + + if (auto expressionOp = mlir::cast( + getBody()->getTerminator())) { + expressionOp.printExpression(os, inductions); + } +} + +mlir::ValueRange EquationTemplateOp::getInductionVariables() { + return getBodyRegion().getArguments(); +} + +llvm::SmallVector +EquationTemplateOp::getInductionVariablesAtPath(const EquationPath &path) { + llvm::SmallVector result; + auto equationInductions = getInductionVariables(); + 
result.append(equationInductions.begin(), equationInductions.end()); + + mlir::Block *bodyBlock = getBody(); + EquationPath::EquationSide side = path.getEquationSide(); + + auto equationSidesOp = + mlir::cast(bodyBlock->getTerminator()); + + mlir::Value value = side == EquationPath::LEFT + ? equationSidesOp.getLhsValues()[path[0]] + : equationSidesOp.getRhsValues()[path[0]]; + + for (size_t i = 1, e = path.size(); i < e; ++i) { + mlir::Operation *op = value.getDefiningOp(); + assert(op != nullptr && "Invalid equation path"); + auto expressionInt = mlir::cast(op); + auto additionalInductions = expressionInt.getAdditionalInductions(); + result.append(additionalInductions.begin(), additionalInductions.end()); + value = expressionInt.getExpressionElement(path[i]); + } + + return result; +} + +llvm::DenseMap +EquationTemplateOp::getInductionsPositionMap() { + mlir::ValueRange inductionVariables = getInductionVariables(); + llvm::DenseMap inductionsPositionMap; + + for (auto inductionVariable : llvm::enumerate(inductionVariables)) { + inductionsPositionMap[inductionVariable.value()] = + inductionVariable.index(); + } + + return inductionsPositionMap; +} + +mlir::LogicalResult +EquationTemplateOp::getAccesses(llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTable) { + auto equationSidesOp = + mlir::cast(getBody()->getTerminator()); + + // Get the induction variables and number them. + auto inductionsPositionMap = getInductionsPositionMap(); + + // Search the accesses starting from the left-hand side of the equation. + if (mlir::failed(searchAccesses(result, symbolTable, inductionsPositionMap, + equationSidesOp.getLhsValues()[0], + EquationPath(EquationPath::LEFT, 0)))) { + return mlir::failure(); + } + + // Search the accesses starting from the right-hand side of the equation. 
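
Once both equation sides have been visited (the right-hand side follows just below), callers of getAccesses obtain the complete access list. A hypothetical driver, in which only EquationTemplateOp, VariableAccess, and getAccesses are taken from this patch and everything else is assumed scaffolding:

  static void dumpAccessedVariables(EquationTemplateOp templateOp) {
    mlir::SymbolTableCollection symbolTables;
    llvm::SmallVector<VariableAccess> accesses;

    if (mlir::failed(templateOp.getAccesses(accesses, symbolTables))) {
      return;
    }

    for (const VariableAccess &access : accesses) {
      // Each access records the accessed variable, the path inside the
      // equation, and the access function mapping equation indices to
      // variable indices.
      llvm::errs() << access.getVariable().getRootReference().getValue()
                   << "\n";
    }
  }
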
+ if (mlir::failed(searchAccesses(result, symbolTable, inductionsPositionMap, + equationSidesOp.getRhsValues()[0], + EquationPath(EquationPath::RIGHT, 0)))) { + return mlir::failure(); + } + + return mlir::success(); +} + +mlir::LogicalResult EquationTemplateOp::getWriteAccesses( + llvm::SmallVectorImpl &result, + const IndexSet &equationIndices, llvm::ArrayRef accesses, + const VariableAccess &matchedAccess) { + const AccessFunction &matchedAccessFunction = + matchedAccess.getAccessFunction(); + + IndexSet matchedVariableIndices = matchedAccessFunction.map(equationIndices); + + for (const VariableAccess &access : accesses) { + if (access.getVariable() != matchedAccess.getVariable()) { + continue; } - for (auto& induction : inductions) { - induction.type = parser.getBuilder().getIndexType(); - } + const AccessFunction &accessFunction = access.getAccessFunction(); + + IndexSet accessedVariableIndices = accessFunction.map(equationIndices); - if (parser.parseRegion(*bodyRegion, inductions)) { - return mlir::failure(); + if (matchedVariableIndices.empty() && accessedVariableIndices.empty()) { + result.push_back(access); + } else if (matchedVariableIndices.overlaps(accessedVariableIndices)) { + result.push_back(access); } + } - if (bodyRegion->empty()) { - mlir::OpBuilder builder(bodyRegion); + return mlir::success(); +} - llvm::SmallVector argTypes( - inductions.size(), builder.getIndexType()); +mlir::LogicalResult EquationTemplateOp::getReadAccesses( + llvm::SmallVectorImpl &result, + const IndexSet &equationIndices, llvm::ArrayRef accesses, + const VariableAccess &matchedAccess) { + const AccessFunction &matchedAccessFunction = + matchedAccess.getAccessFunction(); - llvm::SmallVector argLocs( - inductions.size(), builder.getUnknownLoc()); + IndexSet matchedVariableIndices = matchedAccessFunction.map(equationIndices); - builder.createBlock(bodyRegion, {}, argTypes, argLocs); - } + for (const VariableAccess &access : accesses) { + if (access.getVariable() != matchedAccess.getVariable()) { + result.push_back(access); + } else { + const AccessFunction &accessFunction = access.getAccessFunction(); - result.addTypes(EquationType::get(parser.getContext())); - return mlir::success(); - } + IndexSet accessedVariableIndices = accessFunction.map(equationIndices); - void EquationTemplateOp::print(mlir::OpAsmPrinter& printer) - { - printer << " "; - printer << "inductions = ["; - printer.printOperands(getInductionVariables()); - printer << "]"; - printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); - printer << " "; - printer.printRegion(getBodyRegion(), false); + if (!matchedVariableIndices.empty() && !accessedVariableIndices.empty()) { + if (!matchedVariableIndices.contains(accessedVariableIndices)) { + result.push_back(access); + } + } + } } - void EquationTemplateOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); - } + return mlir::success(); +} - mlir::Block* EquationTemplateOp::createBody(unsigned int numOfInductions) - { - mlir::OpBuilder builder(getContext()); +mlir::Value EquationTemplateOp::getValueAtPath(const EquationPath &path) { + mlir::Block *bodyBlock = getBody(); + EquationPath::EquationSide side = path.getEquationSide(); - llvm::SmallVector argTypes( - numOfInductions, builder.getIndexType()); + auto equationSidesOp = + mlir::cast(bodyBlock->getTerminator()); - llvm::SmallVector argLocs( - numOfInductions, builder.getUnknownLoc()); + mlir::Value value = side == EquationPath::LEFT + ? 
equationSidesOp.getLhsValues()[path[0]] + : equationSidesOp.getRhsValues()[path[0]]; - return builder.createBlock(&getBodyRegion(), {}, argTypes, argLocs); + for (size_t i = 1, e = path.size(); i < e; ++i) { + mlir::Operation *op = value.getDefiningOp(); + assert(op != nullptr && "Invalid equation path"); + auto expressionInt = mlir::cast(op); + value = expressionInt.getExpressionElement(path[i]); } - void EquationTemplateOp::printInline(llvm::raw_ostream& os) - { - llvm::DenseMap inductions; - auto inductionVars = getInductionVariables(); + return value; +} - for (size_t i = 0, e = inductionVars.size(); i < e; ++i) { - inductions[inductionVars[i]] = static_cast(i); - } +std::optional +EquationTemplateOp::getAccessAtPath(mlir::SymbolTableCollection &symbolTable, + const EquationPath &path) { + // Get the induction variables and number them. + mlir::ValueRange inductionVariables = getInductionVariables(); + llvm::DenseMap inductionsPositionMap; - if (auto expressionOp = mlir::cast( - getBody()->getTerminator())) { - expressionOp.printExpression(os, inductions); - } + for (auto inductionVariable : llvm::enumerate(inductionVariables)) { + inductionsPositionMap[inductionVariable.value()] = + inductionVariable.index(); } - mlir::ValueRange EquationTemplateOp::getInductionVariables() - { - return getBodyRegion().getArguments(); - } + // Get the access. + llvm::SmallVector accesses; + mlir::Value access = getValueAtPath(path); - llvm::SmallVector - EquationTemplateOp::getInductionVariablesAtPath(const EquationPath& path) - { - llvm::SmallVector result; - auto equationInductions = getInductionVariables(); - result.append(equationInductions.begin(), equationInductions.end()); + if (mlir::failed(searchAccesses(accesses, symbolTable, inductionsPositionMap, + access, path))) { + return std::nullopt; + } - mlir::Block* bodyBlock = getBody(); - EquationPath::EquationSide side = path.getEquationSide(); + assert(accesses.size() == 1); + return accesses[0]; +} - auto equationSidesOp = - mlir::cast(bodyBlock->getTerminator()); +mlir::LogicalResult EquationTemplateOp::searchAccesses( + llvm::SmallVectorImpl &accesses, + mlir::SymbolTableCollection &symbolTable, + llvm::DenseMap &explicitInductionsPositionMap, + mlir::Value value, EquationPath path) { + mlir::Operation *definingOp = value.getDefiningOp(); - mlir::Value value = side == EquationPath::LEFT - ? 
equationSidesOp.getLhsValues()[path[0]] - : equationSidesOp.getRhsValues()[path[0]]; + if (!definingOp) { + return mlir::success(); + } - for (size_t i = 1, e = path.size(); i < e; ++i) { - mlir::Operation* op = value.getDefiningOp(); - assert(op != nullptr && "Invalid equation path"); - auto expressionInt = mlir::cast(op); - auto additionalInductions = expressionInt.getAdditionalInductions(); - result.append(additionalInductions.begin(), additionalInductions.end()); - value = expressionInt.getExpressionElement(path[i]); - } + AdditionalInductions additionalInductions; + llvm::SmallVector, 10> dimensionAccesses; - return result; + if (auto expressionInt = + mlir::dyn_cast(definingOp)) { + return expressionInt.getEquationAccesses( + accesses, symbolTable, explicitInductionsPositionMap, + additionalInductions, dimensionAccesses, std::move(path)); } - llvm::DenseMap - EquationTemplateOp::getInductionsPositionMap() - { - mlir::ValueRange inductionVariables = getInductionVariables(); - llvm::DenseMap inductionsPositionMap; + return mlir::failure(); +} + +mlir::LogicalResult EquationTemplateOp::cloneWithReplacedAccess( + mlir::RewriterBase &rewriter, + std::optional> equationIndices, + const VariableAccess &access, EquationTemplateOp replacementEquation, + const VariableAccess &replacementAccess, + llvm::SmallVectorImpl> &results) { + mlir::OpBuilder::InsertionGuard guard(rewriter); - for (auto inductionVariable : llvm::enumerate(inductionVariables)) { - inductionsPositionMap[inductionVariable.value()] = - inductionVariable.index(); + // Erase the operations in case of unrecoverable failure. + auto cleanOnFailure = llvm::make_scope_exit([&]() { + for (const auto &result : results) { + rewriter.eraseOp(result.second); } + }); - return inductionsPositionMap; + // The set of indices that are yet to be processed. + IndexSet remainingEquationIndices; + + if (equationIndices) { + remainingEquationIndices = equationIndices->get(); } - mlir::LogicalResult EquationTemplateOp::getAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTable) - { - auto equationSidesOp = - mlir::cast(getBody()->getTerminator()); + // Determine the access functions. + mlir::Value destinationValue = getValueAtPath(access.getPath()); + int64_t destinationRank = 0; - // Get the induction variables and number them. - auto inductionsPositionMap = getInductionsPositionMap(); + if (auto destinationShapedType = + destinationValue.getType().dyn_cast()) { + destinationRank = destinationShapedType.getRank(); + } - // Search the accesses starting from the left-hand side of the equation. - if (mlir::failed(searchAccesses( - result, symbolTable, inductionsPositionMap, - equationSidesOp.getLhsValues()[0], - EquationPath(EquationPath::LEFT, 0)))) { - return mlir::failure(); - } + mlir::Value sourceValue = + replacementEquation.getValueAtPath(replacementAccess.getPath()); - // Search the accesses starting from the right-hand side of the equation. - if (mlir::failed(searchAccesses( - result, symbolTable, inductionsPositionMap, - equationSidesOp.getRhsValues()[0], - EquationPath(EquationPath::RIGHT, 0)))) { - return mlir::failure(); - } + int64_t sourceRank = 0; - return mlir::success(); + if (auto sourceShapedType = + sourceValue.getType().dyn_cast()) { + sourceRank = sourceShapedType.getRank(); + } + + if (destinationRank > sourceRank) { + // The access to be replaced requires indices of the variables that are + // potentially not handled by the source equation. 
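
The cleanOnFailure object above uses llvm::make_scope_exit as an ad-hoc transaction: its callback erases everything recorded in results unless release() is called once the rewrite is known to have succeeded. The idiom in a self-contained sketch (the names and the int payload are illustrative only):

  #include "llvm/ADT/ScopeExit.h"
  #include <vector>

  static bool tryRewrite(std::vector<int> &results) {
    // Arm the rollback: fires when the scope is left without release().
    auto cleanup = llvm::make_scope_exit([&] { results.clear(); });

    results.push_back(1); // Fallible, partial work.

    if (results.empty()) {
      return false; // `cleanup` fires here and discards the partial work.
    }

    cleanup.release(); // Success path: disarm the rollback.
    return true;
  }
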
+ return mlir::failure(); } - mlir::LogicalResult EquationTemplateOp::getWriteAccesses( - llvm::SmallVectorImpl& result, - const IndexSet& equationIndices, - llvm::ArrayRef accesses, - const VariableAccess& matchedAccess) - { - const AccessFunction& matchedAccessFunction = - matchedAccess.getAccessFunction(); + auto destinationAccessFunction = access.getAccessFunction().clone(); - IndexSet matchedVariableIndices = - matchedAccessFunction.map(equationIndices); + // The extra subscription indices to be applied to the replacement value. + llvm::SmallVector additionalSubscriptionIndices; - for (const VariableAccess& access : accesses) { - if (access.getVariable() != matchedAccess.getVariable()) { - continue; - } + if (destinationRank < sourceRank) { + // The access to be replaced specifies more indices than the ones given + // by the source equation. This means that the source equation writes to + // more indices than the requested ones. Inlining the source equation + // results in possibly wasted additional computations, but does lead to + // a correct result. - const AccessFunction& accessFunction = access.getAccessFunction(); + auto destinationDimensionAccesses = + destinationAccessFunction->getGeneralizedAccesses(); - IndexSet accessedVariableIndices = - accessFunction.map(equationIndices); + destinationAccessFunction = + AccessFunction::build(destinationAccessFunction->getContext(), + destinationAccessFunction->getNumOfDims(), + llvm::ArrayRef(destinationDimensionAccesses) + .drop_back(sourceRank - destinationRank)); - if (matchedVariableIndices.empty() && accessedVariableIndices.empty()) { - result.push_back(access); - } else if (matchedVariableIndices.overlaps(accessedVariableIndices)) { - result.push_back(access); - } - } + // If the destination access has more indices than the source one, + // then collect the additional ones and apply them to the + // replacement value. 
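
The loop that follows walks the chain of defining operations bottom-up, so subscripts are met innermost-first and have to be flipped at the end. The core of the idea, reduced to a single operation kind (TensorExtractOp and its accessors come from the patch; the rank cap and the TensorViewOp case are omitted here):

  llvm::SmallVector<mlir::Value> collected;
  mlir::Operation *op = destinationValue.getDefiningOp();

  while (auto extractOp = mlir::dyn_cast_or_null<TensorExtractOp>(op)) {
    // Visit the subscripts of each extract right-to-left...
    for (mlir::Value index : llvm::reverse(extractOp.getIndices())) {
      collected.push_back(index);
    }

    op = extractOp.getTensor().getDefiningOp();
  }

  // ...so that one global reversal restores source order.
  std::reverse(collected.begin(), collected.end());
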
+ int64_t rankDifference = sourceRank - destinationRank; + mlir::Operation *replacedValueOp = destinationValue.getDefiningOp(); - return mlir::success(); - } + auto allAdditionalIndicesCollected = [&]() -> bool { + return rankDifference == + static_cast(additionalSubscriptionIndices.size()); + }; - mlir::LogicalResult EquationTemplateOp::getReadAccesses( - llvm::SmallVectorImpl& result, - const IndexSet& equationIndices, - llvm::ArrayRef accesses, - const VariableAccess& matchedAccess) - { - const AccessFunction& matchedAccessFunction = - matchedAccess.getAccessFunction(); + while (mlir::isa(replacedValueOp) && + !allAdditionalIndicesCollected()) { + if (auto extractOp = mlir::dyn_cast(replacedValueOp)) { + size_t numOfIndices = extractOp.getIndices().size(); - IndexSet matchedVariableIndices = - matchedAccessFunction.map(equationIndices); + for (size_t i = 0; i < numOfIndices && !allAdditionalIndicesCollected(); + ++i) { + additionalSubscriptionIndices.push_back( + extractOp.getIndices()[numOfIndices - i - 1]); + } - for (const VariableAccess& access : accesses) { - if (access.getVariable() != matchedAccess.getVariable()) { - result.push_back(access); - } else { - const AccessFunction& accessFunction = - access.getAccessFunction(); + replacedValueOp = extractOp.getTensor().getDefiningOp(); + continue; + } - IndexSet accessedVariableIndices = - accessFunction.map(equationIndices); + if (auto viewOp = mlir::dyn_cast(replacedValueOp)) { + size_t numOfSubscripts = viewOp.getSubscriptions().size(); - if (!matchedVariableIndices.empty() && - !accessedVariableIndices.empty()) { - if (!matchedVariableIndices.contains(accessedVariableIndices)) { - result.push_back(access); - } + for (size_t i = 0; + i < numOfSubscripts && !allAdditionalIndicesCollected(); ++i) { + additionalSubscriptionIndices.push_back( + viewOp.getSubscriptions()[numOfSubscripts - i - 1]); } + + replacedValueOp = viewOp.getSource().getDefiningOp(); + continue; } + + return mlir::failure(); } - return mlir::success(); - } + assert(allAdditionalIndicesCollected()); - mlir::Value EquationTemplateOp::getValueAtPath(const EquationPath& path) - { - mlir::Block* bodyBlock = getBody(); - EquationPath::EquationSide side = path.getEquationSide(); + // Indices have been collected in reverse order, due to the bottom-up + // visit of the operations tree. + std::reverse(additionalSubscriptionIndices.begin(), + additionalSubscriptionIndices.end()); + } - auto equationSidesOp = - mlir::cast(bodyBlock->getTerminator()); + VariableAccess destinationAccess(access.getPath(), access.getVariable(), + std::move(destinationAccessFunction)); - mlir::Value value = side == EquationPath::LEFT - ? equationSidesOp.getLhsValues()[path[0]] - : equationSidesOp.getRhsValues()[path[0]]; + // Try to perform a vectorized replacement first. + if (mlir::failed(cloneWithReplacedVectorizedAccess( + rewriter, equationIndices, access, replacementEquation, + replacementAccess, additionalSubscriptionIndices, results, + remainingEquationIndices))) { + return mlir::failure(); + } - for (size_t i = 1, e = path.size(); i < e; ++i) { - mlir::Operation* op = value.getDefiningOp(); - assert(op != nullptr && "Invalid equation path"); - auto expressionInt = mlir::cast(op); - value = expressionInt.getExpressionElement(path[i]); - } + // Perform scalar replacements on the remaining equation indices. 
+ // TODO + // for (Point scalarEquationIndices : remainingEquationIndices) { + //} - return value; + if (remainingEquationIndices.empty()) { + cleanOnFailure.release(); + return mlir::success(); } - std::optional EquationTemplateOp::getAccessAtPath( - mlir::SymbolTableCollection& symbolTable, - const EquationPath& path) - { - // Get the induction variables and number them. - mlir::ValueRange inductionVariables = getInductionVariables(); - llvm::DenseMap inductionsPositionMap; + return mlir::failure(); +} - for (auto inductionVariable : llvm::enumerate(inductionVariables)) { - inductionsPositionMap[inductionVariable.value()] = - inductionVariable.index(); - } +mlir::LogicalResult EquationTemplateOp::cloneWithReplacedVectorizedAccess( + mlir::RewriterBase &rewriter, + std::optional> equationIndices, + const VariableAccess &access, EquationTemplateOp replacementEquation, + const VariableAccess &replacementAccess, + llvm::ArrayRef additionalSubscriptions, + llvm::SmallVectorImpl> &results, + IndexSet &remainingEquationIndices) { + const AccessFunction &destinationAccessFunction = access.getAccessFunction(); - // Get the access. - llvm::SmallVector accesses; - mlir::Value access = getValueAtPath(path); + const AccessFunction &sourceAccessFunction = + replacementAccess.getAccessFunction(); - if (mlir::failed(searchAccesses( - accesses, symbolTable, inductionsPositionMap, access, path))) { - return std::nullopt; - } + auto transformation = getReplacementTransformationAccess( + destinationAccessFunction, sourceAccessFunction); - assert(accesses.size() == 1); - return accesses[0]; + if (transformation) { + if (mlir::failed(cloneWithReplacedVectorizedAccess( + rewriter, equationIndices, access, replacementEquation, + replacementAccess, *transformation, additionalSubscriptions, + results, remainingEquationIndices))) { + return mlir::failure(); + } } - mlir::LogicalResult EquationTemplateOp::searchAccesses( - llvm::SmallVectorImpl& accesses, - mlir::SymbolTableCollection& symbolTable, - llvm::DenseMap& explicitInductionsPositionMap, - mlir::Value value, - EquationPath path) - { - mlir::Operation* definingOp = value.getDefiningOp(); + return mlir::success(); +} - if (!definingOp) { - return mlir::success(); +mlir::LogicalResult EquationTemplateOp::cloneWithReplacedVectorizedAccess( + mlir::RewriterBase &rewriter, + std::optional> equationIndices, + const VariableAccess &access, EquationTemplateOp replacementEquation, + const VariableAccess &replacementAccess, + const AccessFunction &transformation, + llvm::ArrayRef additionalSubscriptions, + llvm::SmallVectorImpl> &results, + IndexSet &remainingEquationIndices) { + if (equationIndices && !equationIndices->get().empty()) { + for (const MultidimensionalRange &range : + llvm::make_range(equationIndices->get().rangesBegin(), + equationIndices->get().rangesEnd())) { + if (mlir::failed(cloneWithReplacedVectorizedAccess( + rewriter, std::reference_wrapper(range), access, + replacementEquation, replacementAccess, transformation, + additionalSubscriptions, results, remainingEquationIndices))) { + return mlir::failure(); + } } - AdditionalInductions additionalInductions; - llvm::SmallVector, 10> dimensionAccesses; + return mlir::success(); + } - if (auto expressionInt = - mlir::dyn_cast(definingOp)) { - return expressionInt.getEquationAccesses( - accesses, symbolTable, explicitInductionsPositionMap, - additionalInductions, dimensionAccesses, - std::move(path)); - } + return cloneWithReplacedVectorizedAccess( + rewriter, + std::optional>( + std::nullopt), + 
access, replacementEquation, replacementAccess, transformation, + additionalSubscriptions, results, remainingEquationIndices); +} - return mlir::failure(); - } +mlir::LogicalResult EquationTemplateOp::cloneWithReplacedVectorizedAccess( + mlir::RewriterBase &rewriter, + std::optional> + equationIndices, + const VariableAccess &access, EquationTemplateOp replacementEquation, + const VariableAccess &replacementAccess, + const AccessFunction &transformation, + llvm::ArrayRef additionalSubscriptions, + llvm::SmallVectorImpl> &results, + IndexSet &remainingEquationIndices) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(getOperation()); + mlir::IRMapping mapping; - mlir::LogicalResult EquationTemplateOp::cloneWithReplacedAccess( - mlir::RewriterBase& rewriter, - std::optional> equationIndices, - const VariableAccess& access, - EquationTemplateOp replacementEquation, - const VariableAccess& replacementAccess, - llvm::SmallVectorImpl>& results) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); + // Create the equation template. + auto newEquationTemplateOp = rewriter.create(getLoc()); + newEquationTemplateOp->setAttrs(getOperation()->getAttrDictionary()); - // Erase the operations in case of unrecoverable failure. - auto cleanOnFailure = llvm::make_scope_exit([&]() { - for (const auto& result : results) { - rewriter.eraseOp(result.second); - } - }); + if (equationIndices) { + remainingEquationIndices -= equationIndices->get(); + results.emplace_back(IndexSet(equationIndices->get()), + newEquationTemplateOp); + } else { + results.emplace_back(IndexSet(), newEquationTemplateOp); + } - // The set of indices that are yet to be processed. - IndexSet remainingEquationIndices; + mlir::Block *newEquationBodyBlock = + newEquationTemplateOp.createBody(getInductionVariables().size()); - if (equationIndices) { - remainingEquationIndices = equationIndices->get(); - } + rewriter.setInsertionPointToStart(newEquationBodyBlock); - // Determine the access functions. - mlir::Value destinationValue = getValueAtPath(access.getPath()); - int64_t destinationRank = 0; + // The optional additional subscription indices. + llvm::SmallVector additionalMappedSubscriptions; - if (auto destinationShapedType = - destinationValue.getType().dyn_cast()) { - destinationRank = destinationShapedType.getRank(); - } + // Clone the operations composing the destination equation. + for (auto [oldInduction, newInduction] : + llvm::zip(getInductionVariables(), + newEquationTemplateOp.getInductionVariables())) { + mapping.map(oldInduction, newInduction); + } - mlir::Value sourceValue = - replacementEquation.getValueAtPath(replacementAccess.getPath()); + for (auto &op : getOps()) { + rewriter.clone(op, mapping); + } - int64_t sourceRank = 0; + mlir::Value originalReplacedValue = getValueAtPath(access.getPath()); + mlir::Value mappedReplacedValue = mapping.lookup(originalReplacedValue); + rewriter.setInsertionPointAfterValue(mappedReplacedValue); - if (auto sourceShapedType = - sourceValue.getType().dyn_cast()) { - sourceRank = sourceShapedType.getRank(); - } + // Clone the operations composing the replacement equation. + if (mlir::failed(mapInductionVariables( + rewriter, replacementEquation.getLoc(), mapping, replacementEquation, + newEquationTemplateOp, access.getPath(), transformation))) { + return mlir::failure(); + } - if (destinationRank > sourceRank) { - // The access to be replaced requires indices of the variables that are - // potentially not handled by the source equation. 
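
The transformation access used above composes the destination access with the inverse of the source access, yielding a map from the inductions of the equation being rewritten to the inductions of the replacement equation. A small illustrative case: if the destination reads x[i + 3] and the replacement equation defines x[j + 1], inverting the source access gives j = v - 1 for a variable index v, and composing gives j = i + 2; the value needed at iteration i of the destination is produced at iteration i + 2 of the source.
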
- return mlir::failure(); + for (auto &replacementOp : replacementEquation.getOps()) { + if (!mlir::isa(replacementOp)) { + rewriter.clone(replacementOp, mapping); } + } - auto destinationAccessFunction = access.getAccessFunction().clone(); - - // The extra subscription indices to be applied to the replacement value. - llvm::SmallVector additionalSubscriptionIndices; - - if (destinationRank < sourceRank) { - // The access to be replaced specifies more indices than the ones given - // by the source equation. This means that the source equation writes to - // more indices than the requested ones. Inlining the source equation - // results in possibly wasted additional computations, but does lead to - // a correct result. - - auto destinationDimensionAccesses = - destinationAccessFunction->getGeneralizedAccesses(); - - destinationAccessFunction = AccessFunction::build( - destinationAccessFunction->getContext(), - destinationAccessFunction->getNumOfDims(), - llvm::ArrayRef(destinationDimensionAccesses).drop_back( - sourceRank - destinationRank)); - - // If the destination access has more indices than the source one, - // then collect the additional ones and apply them to the - // replacement value. - int64_t rankDifference = sourceRank - destinationRank; - mlir::Operation* replacedValueOp = destinationValue.getDefiningOp(); - - auto allAdditionalIndicesCollected = [&]() -> bool { - return rankDifference == - static_cast(additionalSubscriptionIndices.size()); - }; - - while (mlir::isa(replacedValueOp) && - !allAdditionalIndicesCollected()) { - if (auto extractOp = - mlir::dyn_cast(replacedValueOp)) { - size_t numOfIndices = extractOp.getIndices().size(); - - for (size_t i = 0; i < numOfIndices && - !allAdditionalIndicesCollected(); ++i) { - additionalSubscriptionIndices.push_back( - extractOp.getIndices()[numOfIndices - i - 1]); - } - - replacedValueOp = extractOp.getTensor().getDefiningOp(); - continue; - } + // Get the replacement value. + mlir::Value replacement = mapping.lookup(replacementEquation.getValueAtPath( + EquationPath(EquationPath::RIGHT, replacementAccess.getPath()[0]))); - if (auto viewOp = mlir::dyn_cast(replacedValueOp)) { - size_t numOfSubscripts = viewOp.getSubscriptions().size(); + rewriter.replaceAllUsesWith(mappedReplacedValue, replacement); + return mlir::success(); +} - for (size_t i = 0; i < numOfSubscripts && - !allAdditionalIndicesCollected(); ++i) { - additionalSubscriptionIndices.push_back( - viewOp.getSubscriptions()[numOfSubscripts - i - 1]); - } +std::unique_ptr +EquationTemplateOp::getReplacementTransformationAccess( + const AccessFunction &destinationAccess, + const AccessFunction &sourceAccess) { + if (auto sourceInverseAccess = sourceAccess.inverse()) { + return destinationAccess.combine(*sourceInverseAccess); + } - replacedValueOp = viewOp.getSource().getDefiningOp(); - continue; - } + // Check if the source access is invertible by removing the constant + // accesses. - return mlir::failure(); - } + if (!sourceAccess.isAffine() || !destinationAccess.isAffine()) { + return nullptr; + } - assert(allAdditionalIndicesCollected()); + // Determine the constant results to be removed. + mlir::AffineMap sourceAffineMap = sourceAccess.getAffineMap(); + llvm::SmallVector constantExprPositions; - // Indices have been collected in reverse order, due to the bottom-up - // visit of the operations tree. 
- std::reverse(additionalSubscriptionIndices.begin(), - additionalSubscriptionIndices.end()); + for (size_t i = 0, e = sourceAffineMap.getNumResults(); i < e; ++i) { + if (mlir::isa(sourceAffineMap.getResult(i))) { + constantExprPositions.push_back(i); } + } - VariableAccess destinationAccess( - access.getPath(), access.getVariable(), - std::move(destinationAccessFunction)); + // Compute the reduced access functions. + auto reducedSourceAccessFunction = + AccessFunction::build(mlir::compressUnusedDims( + sourceAccess.getAffineMap().dropResults(constantExprPositions))); - // Try to perform a vectorized replacement first. - if (mlir::failed(cloneWithReplacedVectorizedAccess( - rewriter, equationIndices, access, replacementEquation, - replacementAccess, additionalSubscriptionIndices, results, - remainingEquationIndices))) { - return mlir::failure(); - } + auto reducedSourceInverseAccessFunction = + reducedSourceAccessFunction->inverse(); - // Perform scalar replacements on the remaining equation indices. - // TODO - //for (Point scalarEquationIndices : remainingEquationIndices) { - //} + if (!reducedSourceInverseAccessFunction) { + return nullptr; + } - if (remainingEquationIndices.empty()) { - cleanOnFailure.release(); - return mlir::success(); - } + auto reducedDestinationAccessFunction = AccessFunction::build( + destinationAccess.getAffineMap().dropResults(constantExprPositions)); + + auto combinedReducedAccess = reducedDestinationAccessFunction->combine( + *reducedSourceInverseAccessFunction); + + mlir::AffineMap combinedAffineMap = combinedReducedAccess->getAffineMap(); + return AccessFunction::build(mlir::AffineMap::get( + destinationAccess.getNumOfDims(), 0, combinedAffineMap.getResults(), + combinedAffineMap.getContext())); +} + +mlir::LogicalResult EquationTemplateOp::mapInductionVariables( + mlir::OpBuilder &builder, mlir::Location loc, mlir::IRMapping &mapping, + EquationTemplateOp source, EquationTemplateOp destination, + const EquationPath &destinationPath, const AccessFunction &transformation) { + if (!transformation.isAffine()) { return mlir::failure(); } - mlir::LogicalResult EquationTemplateOp::cloneWithReplacedVectorizedAccess( - mlir::RewriterBase& rewriter, - std::optional> equationIndices, - const VariableAccess& access, - EquationTemplateOp replacementEquation, - const VariableAccess& replacementAccess, - llvm::ArrayRef additionalSubscriptions, - llvm::SmallVectorImpl< - std::pair>& results, - IndexSet& remainingEquationIndices) - { - const AccessFunction& destinationAccessFunction = - access.getAccessFunction(); + mlir::AffineMap affineMap = transformation.getAffineMap(); - const AccessFunction& sourceAccessFunction = - replacementAccess.getAccessFunction(); + if (affineMap.getNumResults() < source.getInductionVariables().size()) { + return mlir::failure(); + } - auto transformation = getReplacementTransformationAccess( - destinationAccessFunction, sourceAccessFunction); + llvm::SmallVector affineMapResults; - if (transformation) { - if (mlir::failed(cloneWithReplacedVectorizedAccess( - rewriter, equationIndices, access, replacementEquation, - replacementAccess, *transformation, additionalSubscriptions, - results, remainingEquationIndices))) { - return mlir::failure(); - } - } + auto inductionVariables = + destination.getInductionVariablesAtPath(destinationPath); - return mlir::success(); + if (mlir::failed(materializeAffineMap( + builder, loc, affineMap, inductionVariables, affineMapResults))) { + return mlir::failure(); } - mlir::LogicalResult 
EquationTemplateOp::cloneWithReplacedVectorizedAccess( - mlir::RewriterBase& rewriter, - std::optional> equationIndices, - const VariableAccess& access, - EquationTemplateOp replacementEquation, - const VariableAccess& replacementAccess, - const AccessFunction& transformation, - llvm::ArrayRef additionalSubscriptions, - llvm::SmallVectorImpl>& results, - IndexSet& remainingEquationIndices) - { - if (equationIndices && !equationIndices->get().empty()) { - for (const MultidimensionalRange& range : llvm::make_range( - equationIndices->get().rangesBegin(), - equationIndices->get().rangesEnd())) { - if (mlir::failed(cloneWithReplacedVectorizedAccess( - rewriter, std::reference_wrapper(range), access, - replacementEquation, replacementAccess, transformation, - additionalSubscriptions, results, remainingEquationIndices))) { - return mlir::failure(); - } - } + auto sourceInductionVariables = source.getInductionVariables(); - return mlir::success(); - } + for (size_t i = 0, e = sourceInductionVariables.size(); i < e; ++i) { + mapping.map(sourceInductionVariables[i], affineMapResults[i]); + } - return cloneWithReplacedVectorizedAccess( - rewriter, - std::optional< - std::reference_wrapper>(std::nullopt), - access, replacementEquation, replacementAccess, transformation, - additionalSubscriptions, results, remainingEquationIndices); - } - - mlir::LogicalResult EquationTemplateOp::cloneWithReplacedVectorizedAccess( - mlir::RewriterBase& rewriter, - std::optional< - std::reference_wrapper> equationIndices, - const VariableAccess& access, - EquationTemplateOp replacementEquation, - const VariableAccess& replacementAccess, - const AccessFunction& transformation, - llvm::ArrayRef additionalSubscriptions, - llvm::SmallVectorImpl>& results, - IndexSet& remainingEquationIndices) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(getOperation()); - mlir::IRMapping mapping; - - // Create the equation template. - auto newEquationTemplateOp = rewriter.create(getLoc()); - newEquationTemplateOp->setAttrs(getOperation()->getAttrDictionary()); - - if (equationIndices) { - remainingEquationIndices -= equationIndices->get(); - results.emplace_back(IndexSet(equationIndices->get()), newEquationTemplateOp); - } else { - results.emplace_back(IndexSet(), newEquationTemplateOp); - } + return mlir::success(); +} - mlir::Block* newEquationBodyBlock = - newEquationTemplateOp.createBody(getInductionVariables().size()); +IndexSet EquationTemplateOp::applyAccessFunction( + const AccessFunction &accessFunction, + std::optional equationIndices, + const EquationPath &path) { + IndexSet result; - rewriter.setInsertionPointToStart(newEquationBodyBlock); + if (equationIndices) { + result = accessFunction.map(IndexSet(*equationIndices)); + } - // The optional additional subscription indices. - llvm::SmallVector additionalMappedSubscriptions; + return result; +} - // Clone the operations composing the destination equation. 
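
Both the removed and the reformatted clone logic rely on the same MLIR idiom: pre-populate an IRMapping with the old-to-new block argument pairs, then let clone() remap operands on the fly. In isolation, over generic blocks rather than this specific op:

  mlir::IRMapping mapping;

  for (auto [oldArg, newArg] :
       llvm::zip(oldBlock->getArguments(), newBlock->getArguments())) {
    mapping.map(oldArg, newArg);
  }

  rewriter.setInsertionPointToStart(newBlock);

  for (mlir::Operation &op : oldBlock->getOperations()) {
    // Operands found in `mapping` are substituted in the clone;
    // everything else is referenced unchanged.
    rewriter.clone(op, mapping);
  }
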
- for (auto [oldInduction, newInduction] : llvm::zip( - getInductionVariables(), - newEquationTemplateOp.getInductionVariables())) { - mapping.map(oldInduction, newInduction); - } +mlir::LogicalResult EquationTemplateOp::explicitate( + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, + std::optional equationIndices, + const EquationPath &path) { + mlir::OpBuilder::InsertionGuard guard(rewriter); - for (auto& op : getOps()) { - rewriter.clone(op, mapping); - } + // Get all the paths that lead to accesses with the same accessed variable + // and function. + auto requestedAccess = getAccessAtPath(symbolTableCollection, path); - mlir::Value originalReplacedValue = getValueAtPath(access.getPath()); - mlir::Value mappedReplacedValue = mapping.lookup(originalReplacedValue); - rewriter.setInsertionPointAfterValue(mappedReplacedValue); + if (!requestedAccess) { + return mlir::failure(); + } - // Clone the operations composing the replacement equation. - if (mlir::failed(mapInductionVariables( - rewriter, replacementEquation.getLoc(), - mapping, replacementEquation, newEquationTemplateOp, - access.getPath(), transformation))) { - return mlir::failure(); - } + const AccessFunction &requestedAccessFunction = + requestedAccess->getAccessFunction(); - for (auto& replacementOp : replacementEquation.getOps()) { - if (!mlir::isa(replacementOp)) { - rewriter.clone(replacementOp, mapping); - } - } + IndexSet requestedIndices = + applyAccessFunction(requestedAccessFunction, equationIndices, path); - // Get the replacement value. - mlir::Value replacement = - mapping.lookup(replacementEquation.getValueAtPath( - EquationPath(EquationPath::RIGHT, replacementAccess.getPath()[0]))); + llvm::SmallVector accesses; - rewriter.replaceAllUsesWith(mappedReplacedValue, replacement); - return mlir::success(); + if (mlir::failed(getAccesses(accesses, symbolTableCollection))) { + return mlir::failure(); } - std::unique_ptr - EquationTemplateOp::getReplacementTransformationAccess( - const AccessFunction& destinationAccess, - const AccessFunction& sourceAccess) - { - if (auto sourceInverseAccess = sourceAccess.inverse()) { - return destinationAccess.combine(*sourceInverseAccess); - } - - // Check if the source access is invertible by removing the constant - // accesses. + llvm::SmallVector filteredAccesses; - if (!sourceAccess.isAffine() || !destinationAccess.isAffine()) { - return nullptr; + for (const VariableAccess &access : accesses) { + if (requestedAccess->getVariable() != access.getVariable()) { + continue; } - // Determine the constant results to be removed. - mlir::AffineMap sourceAffineMap = sourceAccess.getAffineMap(); - llvm::SmallVector constantExprPositions; + const AccessFunction ¤tAccessFunction = access.getAccessFunction(); - for (size_t i = 0, e = sourceAffineMap.getNumResults(); i < e; ++i) { - if (mlir::isa(sourceAffineMap.getResult(i))) { - constantExprPositions.push_back(i); - } - } + IndexSet currentIndices = applyAccessFunction( + currentAccessFunction, equationIndices, access.getPath()); - // Compute the reduced access functions. 
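
The reduction computed here makes an otherwise non-invertible access invertible: constant results carry no induction information, so they are dropped and the dimensions they leave unused are compressed away. For example, s(i, j) = (j, 0) is not invertible as-is; dropping the constant result leaves (j), and compressing the unused dimension i yields the one-dimensional identity, which inverts trivially. The same result positions are then dropped from the destination access so the two reduced functions stay aligned before being combined.
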
- auto reducedSourceAccessFunction = - AccessFunction::build(mlir::compressUnusedDims( - sourceAccess.getAffineMap().dropResults(constantExprPositions))); + if (requestedIndices == currentIndices) { + filteredAccesses.push_back(access); + } + } - auto reducedSourceInverseAccessFunction = - reducedSourceAccessFunction->inverse(); + assert(!filteredAccesses.empty()); - if (!reducedSourceInverseAccessFunction) { - return nullptr; - } + // If there is only one access, then it is sufficient to follow the path + // and invert the operations. - auto reducedDestinationAccessFunction = AccessFunction::build( - destinationAccess.getAffineMap().dropResults(constantExprPositions)); + auto terminator = mlir::cast(getBody()->getTerminator()); - auto combinedReducedAccess = reducedDestinationAccessFunction->combine( - *reducedSourceInverseAccessFunction); + auto lhsOp = terminator.getLhs().getDefiningOp(); + auto rhsOp = terminator.getRhs().getDefiningOp(); - mlir::AffineMap combinedAffineMap = combinedReducedAccess->getAffineMap(); + rewriter.setInsertionPoint(lhsOp); - return AccessFunction::build(mlir::AffineMap::get( - destinationAccess.getNumOfDims(), 0, - combinedAffineMap.getResults(), combinedAffineMap.getContext())); + if (rhsOp->isBeforeInBlock(lhsOp)) { + rewriter.setInsertionPoint(rhsOp); } - mlir::LogicalResult EquationTemplateOp::mapInductionVariables( - mlir::OpBuilder& builder, - mlir::Location loc, - mlir::IRMapping& mapping, - EquationTemplateOp source, - EquationTemplateOp destination, - const EquationPath& destinationPath, - const AccessFunction& transformation) - { - if (!transformation.isAffine()) { - return mlir::failure(); + if (filteredAccesses.size() == 1) { + for (size_t i = 1, e = path.size(); i < e; ++i) { + if (mlir::failed( + explicitateLeaf(rewriter, path[i], path.getEquationSide()))) { + return mlir::failure(); + } } - mlir::AffineMap affineMap = transformation.getAffineMap(); + if (path.getEquationSide() == EquationPath::RIGHT) { + llvm::SmallVector lhsValues; + llvm::SmallVector rhsValues; - if (affineMap.getNumResults() < source.getInductionVariables().size()) { - return mlir::failure(); - } + rewriter.setInsertionPointAfter(terminator); - llvm::SmallVector affineMapResults; + rewriter.create( + terminator->getLoc(), terminator.getRhs(), terminator.getLhs()); - auto inductionVariables = - destination.getInductionVariablesAtPath(destinationPath); + rewriter.eraseOp(terminator); + } + } else { + // If there are multiple accesses, then we must group all of them and + // extract the common multiplying factor. 
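
Concretely, with a single scalar unknown x appearing on both sides: starting from 2 * x + y = x + 5, grouping moves every occurrence of x to the left, (2 - 1) * x = 5 - y, and the equation becomes explicit as x = (5 - y) / (2 - 1).
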
- if (mlir::failed(materializeAffineMap( - builder, loc, affineMap, inductionVariables, affineMapResults))) { + if (mlir::failed(groupLeftHandSide(rewriter, symbolTableCollection, + equationIndices, *requestedAccess))) { return mlir::failure(); } + } - auto sourceInductionVariables = source.getInductionVariables(); - - for (size_t i = 0, e = sourceInductionVariables.size(); i < e; ++i) { - mapping.map(sourceInductionVariables[i], affineMapResults[i]); - } + return mlir::success(); +} - return mlir::success(); - } +EquationTemplateOp EquationTemplateOp::cloneAndExplicitate( + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, + std::optional equationIndices, + const EquationPath &path) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(getOperation()); - IndexSet EquationTemplateOp::applyAccessFunction( - const AccessFunction& accessFunction, - std::optional equationIndices, - const EquationPath& path) - { - IndexSet result; + auto clonedOp = + mlir::cast(rewriter.clone(*getOperation())); - if (equationIndices) { - result = accessFunction.map(IndexSet(*equationIndices)); - } + auto cleanOnFailure = + llvm::make_scope_exit([&]() { rewriter.eraseOp(clonedOp); }); - return result; + if (mlir::failed(clonedOp.explicitate(rewriter, symbolTableCollection, + equationIndices, path))) { + return nullptr; } - mlir::LogicalResult EquationTemplateOp::explicitate( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - std::optional equationIndices, - const EquationPath& path) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - - // Get all the paths that lead to accesses with the same accessed variable - // and function. - auto requestedAccess = getAccessAtPath(symbolTableCollection, path); + cleanOnFailure.release(); + return clonedOp; +} - if (!requestedAccess) { - return mlir::failure(); - } +mlir::LogicalResult +EquationTemplateOp::explicitateLeaf(mlir::RewriterBase &rewriter, + size_t argumentIndex, + EquationPath::EquationSide side) { + mlir::OpBuilder::InsertionGuard guard(rewriter); - const AccessFunction& requestedAccessFunction = - requestedAccess->getAccessFunction(); + auto equationSidesOp = + mlir::cast(getBody()->getTerminator()); - IndexSet requestedIndices = applyAccessFunction( - requestedAccessFunction, - equationIndices, - path); + mlir::Value oldLhsValue = equationSidesOp.getLhsValues()[0]; + mlir::Value oldRhsValue = equationSidesOp.getRhsValues()[0]; - llvm::SmallVector accesses; + mlir::Value toExplicitate = + side == EquationPath::LEFT ? oldLhsValue : oldRhsValue; - if (mlir::failed(getAccesses(accesses, symbolTableCollection))) { - return mlir::failure(); - } + mlir::Value otherExp = + side == EquationPath::RIGHT ? 
oldLhsValue : oldRhsValue; - llvm::SmallVector filteredAccesses; + mlir::Operation *op = toExplicitate.getDefiningOp(); + auto invertibleOp = mlir::dyn_cast(op); - for (const VariableAccess& access : accesses) { - if (requestedAccess->getVariable() != access.getVariable()) { - continue; - } + if (!invertibleOp) { + return mlir::failure(); + } - const AccessFunction& currentAccessFunction = access.getAccessFunction(); + rewriter.setInsertionPoint(invertibleOp); - IndexSet currentIndices = applyAccessFunction( - currentAccessFunction, - equationIndices, - access.getPath()); + if (auto otherExpOp = otherExp.getDefiningOp(); + otherExpOp && invertibleOp->isBeforeInBlock(otherExpOp)) { + rewriter.setInsertionPointAfter(otherExpOp); + } - if (requestedIndices == currentIndices) { - filteredAccesses.push_back(access); - } - } + mlir::Value invertedOpResult = + invertibleOp.inverse(rewriter, argumentIndex, otherExp); - assert(!filteredAccesses.empty()); + if (!invertedOpResult) { + return mlir::failure(); + } - // If there is only one access, then it is sufficient to follow the path - // and invert the operations. + llvm::SmallVector newLhsValues; + llvm::SmallVector newRhsValues; - auto terminator = - mlir::cast(getBody()->getTerminator()); + if (side == EquationPath::LEFT) { + newLhsValues.push_back(op->getOperand(argumentIndex)); + } else { + newLhsValues.push_back(invertedOpResult); + } - auto lhsOp = terminator.getLhs().getDefiningOp(); - auto rhsOp = terminator.getRhs().getDefiningOp(); + if (side == EquationPath::LEFT) { + newRhsValues.push_back(invertedOpResult); + } else { + newRhsValues.push_back(op->getOperand(argumentIndex)); + } - rewriter.setInsertionPoint(lhsOp); + // Create the new terminator. + rewriter.setInsertionPoint(equationSidesOp); - if (rhsOp->isBeforeInBlock(lhsOp)) { - rewriter.setInsertionPoint(rhsOp); - } + auto oldLhs = + mlir::cast(equationSidesOp.getLhs().getDefiningOp()); - if (filteredAccesses.size() == 1) { - for (size_t i = 1, e = path.size(); i < e; ++i) { - if (mlir::failed(explicitateLeaf( - rewriter, path[i], path.getEquationSide()))) { - return mlir::failure(); - } - } + auto oldRhs = + mlir::cast(equationSidesOp.getRhs().getDefiningOp()); - if (path.getEquationSide() == EquationPath::RIGHT) { - llvm::SmallVector lhsValues; - llvm::SmallVector rhsValues; + rewriter.replaceOpWithNewOp(oldLhs, newLhsValues); + rewriter.replaceOpWithNewOp(oldRhs, newRhsValues); - rewriter.setInsertionPointAfter(terminator); + return mlir::success(); +} - rewriter.create( - terminator->getLoc(), terminator.getRhs(), terminator.getLhs()); +static mlir::LogicalResult removeSubtractions(mlir::RewriterBase &rewriter, + mlir::Operation *root) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + mlir::Operation *op = root; - rewriter.eraseOp(terminator); - } - } else { - // If there are multiple accesses, then we must group all of them and - // extract the common multiplying factor. 
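
The leaf-by-leaf explicitation above hinges on the invertible-operation interface: at each path step, the operation enclosing the requested operand is peeled off by applying its inverse to the opposite side. For an addition a + b = c, asking for operand index 0 produces the inverse c - b, so the equation becomes a = c - b; repeating this along the path leaves the matched access alone on one side.
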
+ if (!op) { + return mlir::success(); + } - if (mlir::failed(groupLeftHandSide( - rewriter, symbolTableCollection, equationIndices, - *requestedAccess))) { + if (!mlir::isa(op) && !mlir::isa(op)) { + for (mlir::Value operand : op->getOperands()) { + if (mlir::failed(removeSubtractions(rewriter, operand.getDefiningOp()))) { return mlir::failure(); } } - - return mlir::success(); } - EquationTemplateOp EquationTemplateOp::cloneAndExplicitate( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - std::optional equationIndices, - const EquationPath& path) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(getOperation()); - - auto clonedOp = - mlir::cast(rewriter.clone(*getOperation())); - - auto cleanOnFailure = llvm::make_scope_exit([&]() { - rewriter.eraseOp(clonedOp); - }); + if (auto subOp = mlir::dyn_cast(op)) { + rewriter.setInsertionPoint(subOp); + mlir::Value rhs = subOp.getRhs(); - if (mlir::failed(clonedOp.explicitate( - rewriter, symbolTableCollection, equationIndices, path))) { - return nullptr; - } + mlir::Value negatedRhs = + rewriter.create(rhs.getLoc(), rhs.getType(), rhs); - cleanOnFailure.release(); - return clonedOp; + rewriter.replaceOpWithNewOp(subOp, subOp.getResult().getType(), + subOp.getLhs(), negatedRhs); } - mlir::LogicalResult EquationTemplateOp::explicitateLeaf( - mlir::RewriterBase& rewriter, - size_t argumentIndex, - EquationPath::EquationSide side) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - - auto equationSidesOp = - mlir::cast(getBody()->getTerminator()); - - mlir::Value oldLhsValue = equationSidesOp.getLhsValues()[0]; - mlir::Value oldRhsValue = equationSidesOp.getRhsValues()[0]; - - mlir::Value toExplicitate = side == EquationPath::LEFT - ? oldLhsValue : oldRhsValue; + return mlir::success(); +} - mlir::Value otherExp = side == EquationPath::RIGHT - ? 
oldLhsValue : oldRhsValue; +static mlir::LogicalResult distributeMulAndDivOps(mlir::RewriterBase &rewriter, + mlir::Operation *root) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + mlir::Operation *op = root; - mlir::Operation* op = toExplicitate.getDefiningOp(); - auto invertibleOp = mlir::dyn_cast(op); + if (!op) { + return mlir::success(); + } - if (!invertibleOp) { + for (auto operand : op->getOperands()) { + if (mlir::failed( + distributeMulAndDivOps(rewriter, operand.getDefiningOp()))) { return mlir::failure(); } + } - rewriter.setInsertionPoint(invertibleOp); + if (auto distributableOp = mlir::dyn_cast(op)) { + if (!mlir::isa(op)) { + rewriter.setInsertionPoint(distributableOp); + llvm::SmallVector results; - if (auto otherExpOp = otherExp.getDefiningOp(); - otherExpOp && invertibleOp->isBeforeInBlock(otherExpOp)) { - rewriter.setInsertionPointAfter(otherExpOp); + if (mlir::succeeded(distributableOp.distribute(results, rewriter))) { + for (size_t i = 0, e = distributableOp->getNumResults(); i < e; ++i) { + mlir::Value oldValue = distributableOp->getResult(i); + mlir::Value newValue = results[i]; + rewriter.replaceAllUsesWith(oldValue, newValue); + } + } } + } + + return mlir::success(); +} + +static mlir::LogicalResult pushNegateOps(mlir::RewriterBase &rewriter, + mlir::Operation *root) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + mlir::Operation *op = root; - mlir::Value invertedOpResult = - invertibleOp.inverse(rewriter, argumentIndex, otherExp); + if (!op) { + return mlir::success(); + } - if (!invertedOpResult) { + for (mlir::Value operand : op->getOperands()) { + if (mlir::failed(pushNegateOps(rewriter, operand.getDefiningOp()))) { return mlir::failure(); } + } - llvm::SmallVector newLhsValues; - llvm::SmallVector newRhsValues; + if (auto distributableOp = mlir::dyn_cast(op)) { + rewriter.setInsertionPoint(distributableOp); + llvm::SmallVector results; - if (side == EquationPath::LEFT) { - newLhsValues.push_back(op->getOperand(argumentIndex)); - } else { - newLhsValues.push_back(invertedOpResult); + if (mlir::succeeded(distributableOp.distribute(results, rewriter))) { + rewriter.replaceOp(distributableOp, results); } + } - if (side == EquationPath::LEFT) { - newRhsValues.push_back(invertedOpResult); - } else { - newRhsValues.push_back(op->getOperand(argumentIndex)); - } + return mlir::success(); +} + +mlir::LogicalResult EquationTemplateOp::collectSummedValues( + llvm::SmallVectorImpl> &result, + mlir::Value root, EquationPath path) { + if (auto definingOp = root.getDefiningOp()) { + if (auto addOp = mlir::dyn_cast(definingOp)) { + if (mlir::failed(collectSummedValues(result, addOp.getLhs(), path + 0))) { + return mlir::failure(); + } - // Create the new terminator. 
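
Taken together, these helpers normalize an expression tree ahead of factor extraction. On x - 2 * (y + z), assuming they are applied in this order (the calling code is not part of this hunk): removeSubtractions rewrites it to x + (-(2 * (y + z))); pushNegateOps distributes the negation, giving x + (-2) * (y + z); distributeMulAndDivOps multiplies through, giving x + (-2) * y + (-2) * z; collectSummedValues then returns the three addends together with their paths.
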
- rewriter.setInsertionPoint(equationSidesOp); + if (mlir::failed(collectSummedValues(result, addOp.getRhs(), path + 1))) { + return mlir::failure(); + } - auto oldLhs = mlir::cast( - equationSidesOp.getLhs().getDefiningOp()); + return mlir::success(); + } + } - auto oldRhs = mlir::cast( - equationSidesOp.getRhs().getDefiningOp()); + result.push_back(std::make_pair(root, path)); + return mlir::success(); +} - rewriter.replaceOpWithNewOp(oldLhs, newLhsValues); - rewriter.replaceOpWithNewOp(oldRhs, newRhsValues); +static void foldValue(mlir::RewriterBase &rewriter, mlir::Value value, + mlir::Block *block) { + mlir::OperationFolder helper(value.getContext()); + llvm::SmallVector visitStack; + llvm::SmallVector ops; + llvm::DenseSet processed; - return mlir::success(); + if (auto definingOp = value.getDefiningOp()) { + visitStack.push_back(definingOp); } - static mlir::LogicalResult removeSubtractions( - mlir::RewriterBase& rewriter, mlir::Operation* root) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - mlir::Operation* op = root; + while (!visitStack.empty()) { + auto op = visitStack.pop_back_val(); + ops.push_back(op); - if (!op) { - return mlir::success(); - } - - if (!mlir::isa(op) && !mlir::isa(op)) { - for (mlir::Value operand : op->getOperands()) { - if (mlir::failed(removeSubtractions( - rewriter, operand.getDefiningOp()))) { - return mlir::failure(); - } + for (const auto &operand : op->getOperands()) { + if (auto definingOp = operand.getDefiningOp()) { + visitStack.push_back(definingOp); } } + } - if (auto subOp = mlir::dyn_cast(op)) { - rewriter.setInsertionPoint(subOp); - mlir::Value rhs = subOp.getRhs(); - - mlir::Value negatedRhs = rewriter.create( - rhs.getLoc(), rhs.getType(), rhs); + llvm::SmallVector constants; - rewriter.replaceOpWithNewOp( - subOp, subOp.getResult().getType(), subOp.getLhs(), negatedRhs); + for (mlir::Operation *op : llvm::reverse(ops)) { + if (processed.contains(op)) { + continue; } - return mlir::success(); + processed.insert(op); + + if (mlir::failed(helper.tryToFold(op))) { + break; + } } - static mlir::LogicalResult distributeMulAndDivOps( - mlir::RewriterBase& rewriter, mlir::Operation* root) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - mlir::Operation* op = root; + for (auto *op : llvm::reverse(constants)) { + op->moveBefore(block, block->begin()); + } +} - if (!op) { - return mlir::success(); - } +static std::optional isZeroAttr(mlir::Attribute attribute) { + if (auto booleanAttr = attribute.dyn_cast()) { + return !booleanAttr.getValue(); + } - for (auto operand : op->getOperands()) { - if (mlir::failed(distributeMulAndDivOps( - rewriter, operand.getDefiningOp()))) { - return mlir::failure(); - } - } + if (auto integerAttr = attribute.dyn_cast()) { + return integerAttr.getValue() == 0; + } - if (auto distributableOp = mlir::dyn_cast(op)) { - if (!mlir::isa(op)) { - rewriter.setInsertionPoint(distributableOp); - llvm::SmallVector results; + if (auto realAttr = attribute.dyn_cast()) { + return realAttr.getValue().isZero(); + } - if (mlir::succeeded(distributableOp.distribute(results, rewriter))) { - for (size_t i = 0, e = distributableOp->getNumResults(); - i < e; ++i) { - mlir::Value oldValue = distributableOp->getResult(i); - mlir::Value newValue = results[i]; - rewriter.replaceAllUsesWith(oldValue, newValue); - } - } - } - } + if (auto integerAttr = attribute.cast()) { + return integerAttr.getValue() == 0; + } - return mlir::success(); + if (auto floatAttr = attribute.cast()) { + return floatAttr.getValueAsDouble() == 0; } - 
static mlir::LogicalResult pushNegateOps( - mlir::RewriterBase& rewriter, mlir::Operation* root) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - mlir::Operation* op = root; + return std::nullopt; +} + +std::optional> +EquationTemplateOp::getMultiplyingFactor( + mlir::OpBuilder &builder, + mlir::SymbolTableCollection &symbolTableCollection, + llvm::DenseMap &inductionsPositionMap, + const IndexSet &equationIndices, mlir::Value value, + llvm::StringRef variable, const IndexSet &variableIndices, + EquationPath path) { + mlir::OpBuilder::InsertionGuard guard(builder); - if (!op) { - return mlir::success(); + auto isAccessToVarFn = [&](mlir::Value value, llvm::StringRef variable) { + mlir::Operation *definingOp = value.getDefiningOp(); + + if (!definingOp) { + return false; } - for (mlir::Value operand : op->getOperands()) { - if (mlir::failed(pushNegateOps(rewriter, operand.getDefiningOp()))) { - return mlir::failure(); + while (definingOp) { + if (auto op = mlir::dyn_cast(definingOp)) { + return op.getVariable() == variable; } - } - if (auto distributableOp = mlir::dyn_cast(op)) { - rewriter.setInsertionPoint(distributableOp); - llvm::SmallVector results; + if (auto op = mlir::dyn_cast(definingOp)) { + definingOp = op.getTensor().getDefiningOp(); + continue; + } - if (mlir::succeeded(distributableOp.distribute(results, rewriter))) { - rewriter.replaceOp(distributableOp, results); + if (auto op = mlir::dyn_cast(definingOp)) { + definingOp = op.getSource().getDefiningOp(); + continue; } - } - return mlir::success(); - } + return false; + } - mlir::LogicalResult EquationTemplateOp::collectSummedValues( - llvm::SmallVectorImpl>& result, - mlir::Value root, - EquationPath path) - { - if (auto definingOp = root.getDefiningOp()) { - if (auto addOp = mlir::dyn_cast(definingOp)) { - if (mlir::failed(collectSummedValues( - result, addOp.getLhs(), path + 0))) { - return mlir::failure(); - } + return false; + }; - if (mlir::failed(collectSummedValues( - result, addOp.getRhs(), path + 1))) { - return mlir::failure(); - } + if (isAccessToVarFn(value, variable)) { + llvm::SmallVector accesses; - return mlir::success(); - } + if (mlir::failed(searchAccesses(accesses, symbolTableCollection, + inductionsPositionMap, value, path)) || + accesses.size() != 1) { + return std::nullopt; } - result.push_back(std::make_pair(root, path)); - return mlir::success(); - } - - static void foldValue( - mlir::RewriterBase& rewriter, - mlir::Value value, - mlir::Block* block) - { - mlir::OperationFolder helper(value.getContext()); - llvm::SmallVector visitStack; - llvm::SmallVector ops; - llvm::DenseSet processed; + if (accesses[0].getVariable().getRootReference() == variable) { + const AccessFunction &accessFunction = accesses[0].getAccessFunction(); + auto accessedIndices = accessFunction.map(equationIndices); - if (auto definingOp = value.getDefiningOp()) { - visitStack.push_back(definingOp); - } + if (variableIndices == accessedIndices) { + if (auto constantMaterializableType = + value.getType() + .dyn_cast()) { - while (!visitStack.empty()) { - auto op = visitStack.pop_back_val(); - ops.push_back(op); + mlir::Value one = constantMaterializableType.materializeIntConstant( + builder, value.getLoc(), 1); - for (const auto& operand : op->getOperands()) { - if (auto definingOp = operand.getDefiningOp()) { - visitStack.push_back(definingOp); + return std::make_pair(static_cast(1), one); } + + return std::nullopt; } } + } - llvm::SmallVector constants; + mlir::Operation *op = value.getDefiningOp(); - for 
(mlir::Operation* op : llvm::reverse(ops)) { - if (processed.contains(op)) { - continue; - } + if (auto constantOp = mlir::dyn_cast(op)) { + return std::make_pair(static_cast(0), constantOp.getResult()); + } - processed.insert(op); + if (auto negateOp = mlir::dyn_cast(op)) { + auto operand = getMultiplyingFactor( + builder, symbolTableCollection, inductionsPositionMap, equationIndices, + negateOp.getOperand(), variable, variableIndices, path + 0); - if (mlir::failed(helper.tryToFold(op))) { - break; - } + if (!operand) { + return std::nullopt; } - for (auto* op : llvm::reverse(constants)) { - op->moveBefore(block, block->begin()); + if (!operand->second) { + return std::nullopt; } + + mlir::Value result = builder.create( + negateOp.getLoc(), negateOp.getResult().getType(), operand->second); + + return std::make_pair(operand->first, result); } - static std::optional isZeroAttr(mlir::Attribute attribute) - { - if (auto booleanAttr = attribute.dyn_cast()) { - return !booleanAttr.getValue(); - } + if (auto mulOp = mlir::dyn_cast(op)) { + auto lhs = getMultiplyingFactor( + builder, symbolTableCollection, inductionsPositionMap, equationIndices, + mulOp.getLhs(), variable, variableIndices, path + 0); - if (auto integerAttr = attribute.dyn_cast()) { - return integerAttr.getValue() == 0; - } + auto rhs = getMultiplyingFactor( + builder, symbolTableCollection, inductionsPositionMap, equationIndices, + mulOp.getRhs(), variable, variableIndices, path + 1); - if (auto realAttr = attribute.dyn_cast()) { - return realAttr.getValue().isZero(); + if (!lhs || !rhs) { + return std::nullopt; } - if (auto integerAttr = attribute.cast()) { - return integerAttr.getValue() == 0; + if (!lhs->second || !rhs->second) { + return std::make_pair(static_cast(0), mlir::Value()); } - if (auto floatAttr = attribute.cast()) { - return floatAttr.getValueAsDouble() == 0; - } + mlir::Value result = builder.create( + mulOp.getLoc(), mulOp.getResult().getType(), lhs->second, rhs->second); - return std::nullopt; + return std::make_pair(lhs->first + rhs->first, result); } - std::optional> - EquationTemplateOp::getMultiplyingFactor( - mlir::OpBuilder& builder, - mlir::SymbolTableCollection& symbolTableCollection, - llvm::DenseMap& inductionsPositionMap, - const IndexSet& equationIndices, - mlir::Value value, - llvm::StringRef variable, - const IndexSet& variableIndices, - EquationPath path) - { - mlir::OpBuilder::InsertionGuard guard(builder); + auto hasAccessToVar = [&](mlir::Value value, + EquationPath path) -> std::optional { + llvm::SmallVector accesses; - auto isAccessToVarFn = [&](mlir::Value value, llvm::StringRef variable) { - mlir::Operation* definingOp = value.getDefiningOp(); + if (mlir::failed(searchAccesses(accesses, symbolTableCollection, + inductionsPositionMap, value, path))) { + return std::nullopt; + } - if (!definingOp) { + bool hasAccess = llvm::any_of(accesses, [&](const VariableAccess &access) { + if (access.getVariable().getRootReference().getValue() != variable) { return false; } - while (definingOp) { - if (auto op = mlir::dyn_cast(definingOp)) { - return op.getVariable() == variable; - } - - if (auto op = mlir::dyn_cast(definingOp)) { - definingOp = op.getTensor().getDefiningOp(); - continue; - } - - if (auto op = mlir::dyn_cast(definingOp)) { - definingOp = op.getSource().getDefiningOp(); - continue; - } + const AccessFunction &accessFunction = access.getAccessFunction(); + IndexSet accessedIndices = accessFunction.map(equationIndices); - return false; + if (accessedIndices.empty() && 
variableIndices.empty()) { + return true; } - return false; - }; - - if (isAccessToVarFn(value, variable)) { - llvm::SmallVector accesses; - - if (mlir::failed(searchAccesses( - accesses, symbolTableCollection, inductionsPositionMap, - value, path)) || accesses.size() != 1) { - return std::nullopt; - } + return accessedIndices.overlaps(variableIndices); + }); - if (accesses[0].getVariable().getRootReference() == variable) { - const AccessFunction& accessFunction = accesses[0].getAccessFunction(); - auto accessedIndices = accessFunction.map(equationIndices); + if (hasAccess) { + return true; + } - if (variableIndices == accessedIndices) { - if (auto constantMaterializableType = - value.getType() - .dyn_cast()) { + return false; + }; - mlir::Value one = - constantMaterializableType.materializeIntConstant( - builder, value.getLoc(), 1); + if (auto divOp = mlir::dyn_cast(op)) { + auto dividend = getMultiplyingFactor( + builder, symbolTableCollection, inductionsPositionMap, equationIndices, + divOp.getLhs(), variable, variableIndices, path + 0); - return std::make_pair(static_cast(1), one); - } + if (!dividend) { + return std::nullopt; + } - return std::nullopt; - } - } + if (!dividend->second) { + return dividend; } - mlir::Operation* op = value.getDefiningOp(); + // Check that the right-hand side value has no access to the variable + // of interest. + auto rhsHasAccess = hasAccessToVar(divOp.getRhs(), path + 1); - if (auto constantOp = mlir::dyn_cast(op)) { - return std::make_pair( - static_cast(0), - constantOp.getResult()); + if (!rhsHasAccess || *rhsHasAccess) { + return std::nullopt; } - if (auto negateOp = mlir::dyn_cast(op)) { - auto operand = getMultiplyingFactor( - builder, symbolTableCollection, inductionsPositionMap, - equationIndices, negateOp.getOperand(), variable, variableIndices, - path + 0); - - if (!operand) { - return std::nullopt; - } + mlir::Value result = + builder.create(divOp.getLoc(), divOp.getResult().getType(), + dividend->second, divOp.getRhs()); - if (!operand->second) { - return std::nullopt; - } + return std::make_pair(dividend->first, result); + } - mlir::Value result = builder.create( - negateOp.getLoc(), negateOp.getResult().getType(), operand->second); + // Check that the value is not the result of an operation using the + // variable of interest. If it has such access, then we are not able to + // extract the multiplying factor. 
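// In that case the sentinel pair (1, <null factor>) is returned below.
// Summarizing the convention used throughout this function (inferred from
// its cases; the code does not state it explicitly), the pair encodes
// (degree in the requested variable, multiplying factor):
//   (0, value)       the term never touches the variable and is its own
//                    factor;
//   (1, one)         a direct access with matching indices; the factor is
//                    the materialized constant 1;
//   (d1+d2, f1*f2)   for MulOp, degrees add and factors multiply;
//   (1, null)        the variable occurs, but no factor can be extracted.
// Callers such as groupFactorsFn reject null factors and degrees greater
// than 1 (e.g. x * x), as such terms are not linear in the variable.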
+ if (hasAccessToVar(value, path)) { + return std::make_pair(static_cast(1), mlir::Value()); + } - return std::make_pair(operand->first, result); - } + return std::make_pair(static_cast(0), value); +} - if (auto mulOp = mlir::dyn_cast(op)) { - auto lhs = getMultiplyingFactor( - builder, symbolTableCollection, inductionsPositionMap, - equationIndices, mulOp.getLhs(), variable, variableIndices, - path + 0); +bool EquationTemplateOp::checkAccessEquivalence( + const IndexSet &equationIndices, const VariableAccess &firstAccess, + const VariableAccess &secondAccess) { + const AccessFunction &firstAccessFunction = firstAccess.getAccessFunction(); - auto rhs = getMultiplyingFactor( - builder, symbolTableCollection, inductionsPositionMap, - equationIndices, mulOp.getRhs(), variable, variableIndices, - path + 1); + const AccessFunction &secondAccessFunction = secondAccess.getAccessFunction(); - if (!lhs || !rhs) { - return std::nullopt; - } + IndexSet firstIndices = firstAccessFunction.map(equationIndices); + IndexSet secondIndices = secondAccessFunction.map(equationIndices); - if (!lhs->second || !rhs->second) { - return std::make_pair(static_cast(0), mlir::Value()); - } + if (firstIndices.empty() && secondIndices.empty()) { + return true; + } - mlir::Value result = builder.create( - mulOp.getLoc(), mulOp.getResult().getType(), - lhs->second, rhs->second); + if (firstAccessFunction == secondAccessFunction) { + return true; + } - return std::make_pair(lhs->first + rhs->first, result); - } + if (firstIndices.flatSize() == 1 && firstIndices == secondIndices) { + return true; + } - auto hasAccessToVar = [&](mlir::Value value, - EquationPath path) -> std::optional { - llvm::SmallVector accesses; + return false; +} - if (mlir::failed(searchAccesses( - accesses, symbolTableCollection, inductionsPositionMap, - value, path))) { - return std::nullopt; - } +mlir::LogicalResult EquationTemplateOp::groupLeftHandSide( + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, + std::optional equationRanges, + const VariableAccess &requestedAccess) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + auto inductionsPositionMap = getInductionsPositionMap(); + uint64_t viewElementIndex = requestedAccess.getPath()[0]; - bool hasAccess = llvm::any_of(accesses, [&](const VariableAccess& access) { - if (access.getVariable().getRootReference().getValue() != variable) { - return false; - } + IndexSet equationIndices; - const AccessFunction& accessFunction = access.getAccessFunction(); - IndexSet accessedIndices = accessFunction.map(equationIndices); + if (equationRanges) { + equationIndices += *equationRanges; + } - if (accessedIndices.empty() && variableIndices.empty()) { - return true; - } + auto requestedValue = getValueAtPath(requestedAccess.getPath()); - return accessedIndices.overlaps(variableIndices); - }); + // Determine whether the access to be grouped is inside both the equation's + // sides or just one of them. When the requested access is found, also + // check that the path goes through linear operations. If not, + // explicitation is not possible. 
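// The loop below answers the first question: it marks, per equation side,
// whether an index-compatible access to the requested variable exists,
// with checkAccessEquivalence (defined above) deciding compatibility.
// Some concrete cases (indices invented for illustration):
//   {i : 0..9}, x[i] vs x[i]      same access function        -> groupable
//   {i : 0..9}, x[i] vs x[9 - i]  same index set, but different functions
//                                 and flatSize() != 1         -> rejected
//   {i : 3..3}, x[i] vs x[3]      both map to the single point {3}
//                                 (flatSize() == 1)           -> groupable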
+ bool lhsHasAccess = false; + bool rhsHasAccess = false; - if (hasAccess) { - return true; - } + llvm::SmallVector accesses; - return false; - }; + if (mlir::failed(getAccesses(accesses, symbolTableCollection))) { + return mlir::failure(); + } - if (auto divOp = mlir::dyn_cast(op)) { - auto dividend = getMultiplyingFactor( - builder, symbolTableCollection, inductionsPositionMap, - equationIndices, divOp.getLhs(), variable, variableIndices, - path + 0); + const AccessFunction &requestedAccessFunction = + requestedAccess.getAccessFunction(); - if (!dividend) { - return std::nullopt; - } + auto requestedIndices = requestedAccessFunction.map(equationIndices); - if (!dividend->second) { - return dividend; - } + for (const VariableAccess &access : accesses) { + if (access.getVariable() != requestedAccess.getVariable()) { + continue; + } - // Check that the right-hand side value has no access to the variable - // of interest. - auto rhsHasAccess = hasAccessToVar(divOp.getRhs(), path + 1); + const AccessFunction ¤tAccessFunction = access.getAccessFunction(); + auto currentAccessIndices = currentAccessFunction.map(equationIndices); - if (!rhsHasAccess || *rhsHasAccess) { - return std::nullopt; + if ((requestedIndices.empty() && currentAccessIndices.empty()) || + requestedIndices.overlaps(currentAccessIndices)) { + if (!checkAccessEquivalence(equationIndices, requestedAccess, access)) { + return mlir::failure(); } - mlir::Value result = builder.create( - divOp.getLoc(), divOp.getResult().getType(), dividend->second, - divOp.getRhs()); + EquationPath::EquationSide side = access.getPath().getEquationSide(); + lhsHasAccess |= side == EquationPath::LEFT; + rhsHasAccess |= side == EquationPath::RIGHT; + } + } - return std::make_pair(dividend->first, result); + // Convert the expression to a sum of values. + auto convertToSumsFn = + [&](std::function()> rootFn) + -> mlir::LogicalResult { + if (auto root = rootFn(); mlir::failed( + removeSubtractions(rewriter, root.first.getDefiningOp()))) { + return mlir::failure(); } - // Check that the value is not the result of an operation using the - // variable of interest. If it has such access, then we are not able to - // extract the multiplying factor. 
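// A related subtlety handled earlier in this function: for DivOp only the
// dividend is recursed into, and hasAccessToVar must prove the divisor
// free of the variable. A quotient such as (3 * x) / y is linear in x only
// while y is x-free; something like (3 * x) / x is not, and the routine
// answers std::nullopt for it rather than fabricating a factor.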
- if (hasAccessToVar(value, path)) { - return std::make_pair(static_cast(1), mlir::Value()); + if (auto root = rootFn(); mlir::failed( + distributeMulAndDivOps(rewriter, root.first.getDefiningOp()))) { + return mlir::failure(); } - return std::make_pair(static_cast(0), value); - } + if (auto root = rootFn(); + mlir::failed(pushNegateOps(rewriter, root.first.getDefiningOp()))) { + return mlir::failure(); + } - bool EquationTemplateOp::checkAccessEquivalence( - const IndexSet& equationIndices, - const VariableAccess& firstAccess, - const VariableAccess& secondAccess) - { - const AccessFunction& firstAccessFunction = - firstAccess.getAccessFunction(); + return mlir::success(); + }; - const AccessFunction& secondAccessFunction = - secondAccess.getAccessFunction(); + llvm::SmallVector> lhsSummedValues; + llvm::SmallVector> rhsSummedValues; - IndexSet firstIndices = firstAccessFunction.map(equationIndices); - IndexSet secondIndices = secondAccessFunction.map(equationIndices); + if (lhsHasAccess) { + auto rootFn = [&]() -> std::pair { + auto equationSidesOp = + mlir::cast(getBody()->getTerminator()); - if (firstIndices.empty() && secondIndices.empty()) { - return true; - } + return std::make_pair(equationSidesOp.getLhsValues()[viewElementIndex], + EquationPath(EquationPath::LEFT, viewElementIndex)); + }; - if (firstAccessFunction == secondAccessFunction) { - return true; + if (mlir::failed(convertToSumsFn(rootFn))) { + return mlir::failure(); } - if (firstIndices.flatSize() == 1 && firstIndices == secondIndices) { - return true; + if (auto root = rootFn(); mlir::failed( + collectSummedValues(lhsSummedValues, root.first, root.second))) { + return mlir::failure(); } - - return false; } - mlir::LogicalResult EquationTemplateOp::groupLeftHandSide( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - std::optional equationRanges, - const VariableAccess& requestedAccess) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - auto inductionsPositionMap = getInductionsPositionMap(); - uint64_t viewElementIndex = requestedAccess.getPath()[0]; + if (rhsHasAccess) { + auto rootFn = [&]() -> std::pair { + auto equationSidesOp = + mlir::cast(getBody()->getTerminator()); - IndexSet equationIndices; + return std::make_pair( + equationSidesOp.getRhsValues()[viewElementIndex], + EquationPath(EquationPath::RIGHT, viewElementIndex)); + }; - if (equationRanges) { - equationIndices += *equationRanges; + if (mlir::failed(convertToSumsFn(rootFn))) { + return mlir::failure(); } - auto requestedValue = getValueAtPath(requestedAccess.getPath()); - - // Determine whether the access to be grouped is inside both the equation's - // sides or just one of them. When the requested access is found, also - // check that the path goes through linear operations. If not, - // explicitation is not possible. 
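// The linearity requirement is what convertToSumsFn (below) establishes
// before any grouping happens. On a small example, the three rewrites act
// roughly as follows (the exact intermediate shapes depend on the
// individual rewrites):
//   start:                  x - 2 * (y + z)
//   removeSubtractions:     x + (-(2 * (y + z)))
//   distributeMulAndDivOps: x + (-((2 * y) + (2 * z)))
//   pushNegateOps:          x + ((-(2 * y)) + (-(2 * z)))
// collectSummedValues then flattens the AddOp tree into the term list
//   { x, -(2 * y), -(2 * z) }
// together with the equation path leading to each term, ready for factor
// grouping.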
- bool lhsHasAccess = false; - bool rhsHasAccess = false; - - llvm::SmallVector accesses; - - if (mlir::failed(getAccesses(accesses, symbolTableCollection))) { + if (auto root = rootFn(); mlir::failed( + collectSummedValues(rhsSummedValues, root.first, root.second))) { return mlir::failure(); } + } - const AccessFunction& requestedAccessFunction = - requestedAccess.getAccessFunction(); - - auto requestedIndices = requestedAccessFunction.map(equationIndices); - - for (const VariableAccess& access : accesses) { - if (access.getVariable() != requestedAccess.getVariable()) { - continue; - } - - const AccessFunction& currentAccessFunction = access.getAccessFunction(); - auto currentAccessIndices = currentAccessFunction.map(equationIndices); - - if ((requestedIndices.empty() && currentAccessIndices.empty()) || - requestedIndices.overlaps(currentAccessIndices)) { - if (!checkAccessEquivalence( - equationIndices, requestedAccess, access)) { - return mlir::failure(); - } + auto containsAccessFn = + [&](bool &result, mlir::Value value, EquationPath path, + const VariableAccess &access) -> mlir::LogicalResult { + llvm::SmallVector accesses; - EquationPath::EquationSide side = access.getPath().getEquationSide(); - lhsHasAccess |= side == EquationPath::LEFT; - rhsHasAccess |= side == EquationPath::RIGHT; - } + if (mlir::failed(searchAccesses(accesses, symbolTableCollection, + inductionsPositionMap, value, path))) { + return mlir::failure(); } - // Convert the expression to a sum of values. - auto convertToSumsFn = - [&](std::function()> rootFn) - -> mlir::LogicalResult { - if (auto root = rootFn(); mlir::failed( - removeSubtractions(rewriter, root.first.getDefiningOp()))) { - return mlir::failure(); - } + const AccessFunction &accessFunction = access.getAccessFunction(); - if (auto root = rootFn(); mlir::failed( - distributeMulAndDivOps(rewriter, root.first.getDefiningOp()))) { - return mlir::failure(); + result = llvm::any_of(accesses, [&](const VariableAccess &acc) { + if (acc.getVariable() != access.getVariable()) { + return false; } - if (auto root = rootFn(); mlir::failed( - pushNegateOps(rewriter, root.first.getDefiningOp()))) { - return mlir::failure(); - } + IndexSet requestedIndices = accessFunction.map(equationIndices); - return mlir::success(); - }; + const AccessFunction ¤tAccessFunction = acc.getAccessFunction(); + auto currentIndices = currentAccessFunction.map(equationIndices); - llvm::SmallVector> lhsSummedValues; - llvm::SmallVector> rhsSummedValues; + assert(requestedIndices == currentIndices || + !requestedIndices.overlaps(currentIndices)); + return requestedIndices == currentIndices; + }); - if (lhsHasAccess) { - auto rootFn = [&]() -> std::pair { - auto equationSidesOp = - mlir::cast(getBody()->getTerminator()); + return mlir::success(); + }; - return std::make_pair( - equationSidesOp.getLhsValues()[viewElementIndex], - EquationPath(EquationPath::LEFT, viewElementIndex)); - }; + auto groupFactorsFn = [&](auto beginIt, auto endIt) -> mlir::Value { + mlir::Value result = rewriter.create( + getOperation()->getLoc(), RealAttr::get(rewriter.getContext(), 0)); - if (mlir::failed(convertToSumsFn(rootFn))) { - return mlir::failure(); - } + for (auto it = beginIt; it != endIt; ++it) { + auto factor = getMultiplyingFactor( + rewriter, symbolTableCollection, inductionsPositionMap, + equationIndices, it->first, + requestedAccess.getVariable().getRootReference().getValue(), + requestedIndices, it->second); - if (auto root = rootFn(); mlir::failed(collectSummedValues( - lhsSummedValues, 
root.first, root.second))) { - return mlir::failure(); + if (!factor) { + return nullptr; } - } - - if (rhsHasAccess) { - auto rootFn = [&]() -> std::pair { - auto equationSidesOp = - mlir::cast(getBody()->getTerminator()); - return std::make_pair( - equationSidesOp.getRhsValues()[viewElementIndex], - EquationPath(EquationPath::RIGHT, viewElementIndex)); - }; - - if (mlir::failed(convertToSumsFn(rootFn))) { - return mlir::failure(); + if (!factor->second || factor->first > 1) { + return nullptr; } - if (auto root = rootFn(); mlir::failed(collectSummedValues( - rhsSummedValues, root.first, root.second))) { - return mlir::failure(); - } + result = rewriter.create( + it->first.getLoc(), + getMostGenericScalarType(result.getType(), it->first.getType()), + result, factor->second); } - auto containsAccessFn = - [&](bool& result, - mlir::Value value, - EquationPath path, - const VariableAccess& access) -> mlir::LogicalResult { - llvm::SmallVector accesses; - - if (mlir::failed(searchAccesses( - accesses, symbolTableCollection, inductionsPositionMap, - value, path))) { - return mlir::failure(); - } - - const AccessFunction& accessFunction = access.getAccessFunction(); - - result = llvm::any_of(accesses, [&](const VariableAccess& acc) { - if (acc.getVariable() != access.getVariable()) { - return false; - } - - IndexSet requestedIndices = accessFunction.map(equationIndices); + return result; + }; - const AccessFunction& currentAccessFunction = acc.getAccessFunction(); - auto currentIndices = currentAccessFunction.map(equationIndices); + auto groupRemainingFn = [&](auto beginIt, auto endIt) -> mlir::Value { + auto zeroConstantOp = rewriter.create( + getOperation()->getLoc(), RealAttr::get(rewriter.getContext(), 0)); - assert(requestedIndices == currentIndices || - !requestedIndices.overlaps(currentIndices)); - return requestedIndices == currentIndices; - }); + mlir::Value result = zeroConstantOp.getResult(); - return mlir::success(); - }; + for (auto it = beginIt; it != endIt; ++it) { + mlir::Value value = it->first; - auto groupFactorsFn = [&](auto beginIt, auto endIt) -> mlir::Value { - mlir::Value result = rewriter.create( - getOperation()->getLoc(), RealAttr::get(rewriter.getContext(), 0)); + result = rewriter.create( + value.getLoc(), + getMostGenericScalarType(result.getType(), value.getType()), result, + value); + } - for (auto it = beginIt; it != endIt; ++it) { - auto factor = getMultiplyingFactor( - rewriter, symbolTableCollection, inductionsPositionMap, - equationIndices, it->first, - requestedAccess.getVariable().getRootReference().getValue(), - requestedIndices, it->second); + return result; + }; - if (!factor) { - return nullptr; - } + if (lhsHasAccess && rhsHasAccess) { + bool error = false; - if (!factor->second || factor->first > 1) { - return nullptr; - } + auto leftPos = llvm::partition(lhsSummedValues, [&](const auto &value) { + bool result = false; - result = rewriter.create( - it->first.getLoc(), - getMostGenericScalarType(result.getType(), it->first.getType()), - result, factor->second); + if (mlir::failed(containsAccessFn(result, value.first, value.second, + requestedAccess))) { + error = true; + return false; } return result; - }; - - auto groupRemainingFn = [&](auto beginIt, auto endIt) -> mlir::Value { - auto zeroConstantOp = rewriter.create( - getOperation()->getLoc(), RealAttr::get(rewriter.getContext(), 0)); - - mlir::Value result = zeroConstantOp.getResult(); + }); - for (auto it = beginIt; it != endIt; ++it) { - mlir::Value value = it->first; + auto rightPos = 
llvm::partition(rhsSummedValues, [&](const auto &value) { + bool result = false; - result = rewriter.create( - value.getLoc(), - getMostGenericScalarType(result.getType(), value.getType()), - result, value); + if (mlir::failed(containsAccessFn(result, value.first, value.second, + requestedAccess))) { + error = true; + return false; } return result; - }; + }); - if (lhsHasAccess && rhsHasAccess) { - bool error = false; + if (error) { + return mlir::failure(); + } - auto leftPos = llvm::partition(lhsSummedValues, [&](const auto& value) { - bool result = false; + mlir::Value lhsFactor = groupFactorsFn(lhsSummedValues.begin(), leftPos); + mlir::Value rhsFactor = groupFactorsFn(rhsSummedValues.begin(), rightPos); - if (mlir::failed(containsAccessFn( - result, value.first, value.second, requestedAccess))) { - error = true; - return false; - } + if (!lhsFactor || !rhsFactor) { + return mlir::failure(); + } - return result; - }); + mlir::Value lhsRemaining = groupRemainingFn(leftPos, lhsSummedValues.end()); + mlir::Value rhsRemaining = + groupRemainingFn(rightPos, rhsSummedValues.end()); - auto rightPos = llvm::partition(rhsSummedValues, [&](const auto& value) { - bool result = false; + auto rhs = rewriter.create( + getLoc(), requestedValue.getType(), + rewriter.create(getLoc(), + getMostGenericScalarType(rhsRemaining.getType(), + lhsRemaining.getType()), + rhsRemaining, lhsRemaining), + rewriter.create( + getLoc(), + getMostGenericScalarType(lhsFactor.getType(), rhsFactor.getType()), + lhsFactor, rhsFactor)); - if (mlir::failed(containsAccessFn( - result, value.first, value.second, requestedAccess))) { - error = true; - return false; - } + // Check if we are dividing by zero. + foldValue(rewriter, rhs.getRhs(), getBody()); - return result; - }); + if (auto divisorOp = + mlir::dyn_cast(rhs.getRhs().getDefiningOp())) { + std::optional isZero = isZeroAttr(divisorOp.getValue()); - if (error) { + if (!isZero || *isZero) { return mlir::failure(); } + } - mlir::Value lhsFactor = groupFactorsFn(lhsSummedValues.begin(), leftPos); - mlir::Value rhsFactor = groupFactorsFn(rhsSummedValues.begin(), rightPos); + auto equationSidesOp = + mlir::cast(getBody()->getTerminator()); - if (!lhsFactor || !rhsFactor) { - return mlir::failure(); - } + auto lhsOp = equationSidesOp.getLhs().getDefiningOp(); + auto rhsOp = equationSidesOp.getRhs().getDefiningOp(); - mlir::Value lhsRemaining = groupRemainingFn(leftPos, lhsSummedValues.end()); - mlir::Value rhsRemaining = groupRemainingFn(rightPos, rhsSummedValues.end()); + auto oldLhsValues = lhsOp.getValues(); + llvm::SmallVector newLhsValues(oldLhsValues.begin(), + oldLhsValues.end()); - auto rhs = rewriter.create( - getLoc(), requestedValue.getType(), - rewriter.create( - getLoc(), - getMostGenericScalarType( - rhsRemaining.getType(), lhsRemaining.getType()), - rhsRemaining, lhsRemaining), - rewriter.create( - getLoc(), - getMostGenericScalarType( - lhsFactor.getType(), rhsFactor.getType()), - lhsFactor, rhsFactor)); + auto oldRhsValues = rhsOp.getValues(); + llvm::SmallVector newRhsValues(oldRhsValues.begin(), + oldRhsValues.end()); - // Check if we are dividing by zero. 
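// The DivOp assembled above solves the grouped equation symbolically:
//   lhsFactor * x + lhsRemaining = rhsFactor * x + rhsRemaining
//   (lhsFactor - rhsFactor) * x  = rhsRemaining - lhsRemaining
//   x = (rhsRemaining - lhsRemaining) / (lhsFactor - rhsFactor)
// which is exactly the SubOp-over-SubOp operand pair handed to the DivOp.
// The divisor vanishes precisely when both sides carry the same factor of
// x, hence the division-by-zero check that follows.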
- foldValue(rewriter, rhs.getRhs(), getBody()); + newLhsValues[viewElementIndex] = requestedValue; + newRhsValues[viewElementIndex] = rhs.getResult(); - if (auto divisorOp = - mlir::dyn_cast(rhs.getRhs().getDefiningOp())) { - std::optional isZero = isZeroAttr(divisorOp.getValue()); + rewriter.setInsertionPoint(lhsOp); + rewriter.replaceOpWithNewOp(lhsOp, newLhsValues); - if (!isZero || *isZero) { - return mlir::failure(); - } - } + rewriter.setInsertionPoint(rhsOp); + rewriter.replaceOpWithNewOp(rhsOp, newRhsValues); - auto equationSidesOp = - mlir::cast(getBody()->getTerminator()); + return mlir::success(); + } - auto lhsOp = equationSidesOp.getLhs().getDefiningOp(); - auto rhsOp = equationSidesOp.getRhs().getDefiningOp(); + if (lhsHasAccess) { + bool error = false; - auto oldLhsValues = lhsOp.getValues(); - llvm::SmallVector newLhsValues( - oldLhsValues.begin(), oldLhsValues.end()); + auto leftPos = llvm::partition(lhsSummedValues, [&](const auto &value) { + bool result = false; - auto oldRhsValues = rhsOp.getValues(); - llvm::SmallVector newRhsValues( - oldRhsValues.begin(), oldRhsValues.end()); + if (mlir::failed(containsAccessFn(result, value.first, value.second, + requestedAccess))) { + error = true; + return false; + } - newLhsValues[viewElementIndex] = requestedValue; - newRhsValues[viewElementIndex] = rhs.getResult(); + return result; + }); - rewriter.setInsertionPoint(lhsOp); - rewriter.replaceOpWithNewOp(lhsOp, newLhsValues); + if (error) { + return mlir::failure(); + } - rewriter.setInsertionPoint(rhsOp); - rewriter.replaceOpWithNewOp(rhsOp, newRhsValues); + mlir::Value lhsFactor = groupFactorsFn(lhsSummedValues.begin(), leftPos); - return mlir::success(); + if (!lhsFactor) { + return mlir::failure(); } - if (lhsHasAccess) { - bool error = false; - - auto leftPos = llvm::partition(lhsSummedValues, [&](const auto& value) { - bool result = false; + mlir::Value lhsRemaining = groupRemainingFn(leftPos, lhsSummedValues.end()); - if (mlir::failed(containsAccessFn( - result, value.first, value.second, requestedAccess))) { - error = true; - return false; - } + auto equationSidesOp = + mlir::cast(getBody()->getTerminator()); - return result; - }); + auto rhs = rewriter.create( + getLoc(), requestedValue.getType(), + rewriter.create( + getLoc(), + getMostGenericScalarType( + equationSidesOp.getRhsValues()[viewElementIndex].getType(), + lhsRemaining.getType()), + equationSidesOp.getRhsValues()[viewElementIndex], lhsRemaining), + lhsFactor); - if (error) { - return mlir::failure(); - } + // Check if we are dividing by zero. 
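// The check works in two steps: foldValue constant-folds the divisor in
// place, and isZeroAttr then inspects the resulting constant. Only a
// divisor that folds to a recognized, provably nonzero constant passes;
// a zero constant, or an attribute kind isZeroAttr cannot classify
// (std::nullopt), aborts the grouping. A divisor that does not fold to a
// ConstantOp at all is accepted optimistically, since nothing can be
// concluded about it at compile time.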
+ foldValue(rewriter, rhs.getRhs(), getBody()); - mlir::Value lhsFactor = groupFactorsFn(lhsSummedValues.begin(), leftPos); + if (auto divisorOp = + mlir::dyn_cast(rhs.getRhs().getDefiningOp())) { + std::optional isZero = isZeroAttr(divisorOp.getValue()); - if (!lhsFactor) { + if (!isZero || *isZero) { return mlir::failure(); } + } - mlir::Value lhsRemaining = groupRemainingFn(leftPos, lhsSummedValues.end()); + auto lhsOp = equationSidesOp.getLhs().getDefiningOp(); + auto rhsOp = equationSidesOp.getRhs().getDefiningOp(); - auto equationSidesOp = - mlir::cast(getBody()->getTerminator()); + auto oldLhsValues = lhsOp.getValues(); + llvm::SmallVector newLhsValues(oldLhsValues.begin(), + oldLhsValues.end()); - auto rhs = rewriter.create( - getLoc(), requestedValue.getType(), - rewriter.create( - getLoc(), - getMostGenericScalarType( - equationSidesOp.getRhsValues()[viewElementIndex].getType(), - lhsRemaining.getType()), - equationSidesOp.getRhsValues()[viewElementIndex], - lhsRemaining), - lhsFactor); - - // Check if we are dividing by zero. - foldValue(rewriter, rhs.getRhs(), getBody()); - - if (auto divisorOp = - mlir::dyn_cast(rhs.getRhs().getDefiningOp())) { - std::optional isZero = isZeroAttr(divisorOp.getValue()); - - if (!isZero || *isZero) { - return mlir::failure(); - } - } + auto oldRhsValues = rhsOp.getValues(); + llvm::SmallVector newRhsValues(oldRhsValues.begin(), + oldRhsValues.end()); - auto lhsOp = equationSidesOp.getLhs().getDefiningOp(); - auto rhsOp = equationSidesOp.getRhs().getDefiningOp(); + newLhsValues[viewElementIndex] = requestedValue; + newRhsValues[viewElementIndex] = rhs.getResult(); - auto oldLhsValues = lhsOp.getValues(); - llvm::SmallVector newLhsValues( - oldLhsValues.begin(), oldLhsValues.end()); + rewriter.setInsertionPoint(lhsOp); + rewriter.replaceOpWithNewOp(lhsOp, newLhsValues); - auto oldRhsValues = rhsOp.getValues(); - llvm::SmallVector newRhsValues( - oldRhsValues.begin(), oldRhsValues.end()); + rewriter.setInsertionPoint(rhsOp); + rewriter.replaceOpWithNewOp(rhsOp, newRhsValues); - newLhsValues[viewElementIndex] = requestedValue; - newRhsValues[viewElementIndex] = rhs.getResult(); + return mlir::success(); + } - rewriter.setInsertionPoint(lhsOp); - rewriter.replaceOpWithNewOp(lhsOp, newLhsValues); + if (rhsHasAccess) { + bool error = false; - rewriter.setInsertionPoint(rhsOp); - rewriter.replaceOpWithNewOp(rhsOp, newRhsValues); + auto rightPos = llvm::partition(rhsSummedValues, [&](const auto &value) { + bool result = false; - return mlir::success(); - } + if (mlir::failed(containsAccessFn(result, value.first, value.second, + requestedAccess))) { + error = true; + return false; + } - if (rhsHasAccess) { - bool error = false; + return result; + }); - auto rightPos = llvm::partition(rhsSummedValues, [&](const auto& value) { - bool result = false; + if (error) { + return mlir::failure(); + } - if (mlir::failed(containsAccessFn( - result, value.first, value.second, requestedAccess))) { - error = true; - return false; - } + mlir::Value rhsFactor = groupFactorsFn(rhsSummedValues.begin(), rightPos); - return result; - }); + if (!rhsFactor) { + return mlir::failure(); + } - if (error) { - return mlir::failure(); - } + mlir::Value rhsRemaining = + groupRemainingFn(rightPos, rhsSummedValues.end()); - mlir::Value rhsFactor = groupFactorsFn(rhsSummedValues.begin(), rightPos); + auto equationSidesOp = + mlir::cast(getBody()->getTerminator()); - if (!rhsFactor) { - return mlir::failure(); - } + auto rhs = rewriter.create( + getLoc(), requestedValue.getType(), + 
rewriter.create( + getLoc(), + getMostGenericScalarType( + equationSidesOp.getLhsValues()[viewElementIndex].getType(), + rhsRemaining.getType()), + equationSidesOp.getLhsValues()[viewElementIndex], rhsRemaining), + rhsFactor); - mlir::Value rhsRemaining = groupRemainingFn(rightPos, rhsSummedValues.end()); + // Check if we are dividing by zero. + foldValue(rewriter, rhs.getRhs(), getBody()); - auto equationSidesOp = - mlir::cast(getBody()->getTerminator()); + if (auto divisorOp = + mlir::dyn_cast(rhs.getRhs().getDefiningOp())) { + std::optional isZero = isZeroAttr(divisorOp.getValue()); - auto rhs = rewriter.create( - getLoc(), requestedValue.getType(), - rewriter.create( - getLoc(), - getMostGenericScalarType( - equationSidesOp.getLhsValues()[viewElementIndex].getType(), - rhsRemaining.getType()), - equationSidesOp.getLhsValues()[viewElementIndex], - rhsRemaining), - rhsFactor); - - // Check if we are dividing by zero. - foldValue(rewriter, rhs.getRhs(), getBody()); - - if (auto divisorOp = - mlir::dyn_cast(rhs.getRhs().getDefiningOp())) { - std::optional isZero = isZeroAttr(divisorOp.getValue()); - - if (!isZero || *isZero) { - return mlir::failure(); - } + if (!isZero || *isZero) { + return mlir::failure(); } + } - auto lhsOp = equationSidesOp.getLhs().getDefiningOp(); - auto rhsOp = equationSidesOp.getRhs().getDefiningOp(); - - auto oldLhsValues = lhsOp.getValues(); - llvm::SmallVector newLhsValues( - oldLhsValues.begin(), oldLhsValues.end()); + auto lhsOp = equationSidesOp.getLhs().getDefiningOp(); + auto rhsOp = equationSidesOp.getRhs().getDefiningOp(); - auto oldRhsValues = rhsOp.getValues(); - llvm::SmallVector newRhsValues( - oldRhsValues.begin(), oldRhsValues.end()); + auto oldLhsValues = lhsOp.getValues(); + llvm::SmallVector newLhsValues(oldLhsValues.begin(), + oldLhsValues.end()); - newLhsValues[viewElementIndex] = requestedValue; - newRhsValues[viewElementIndex] = rhs.getResult(); + auto oldRhsValues = rhsOp.getValues(); + llvm::SmallVector newRhsValues(oldRhsValues.begin(), + oldRhsValues.end()); - rewriter.setInsertionPoint(lhsOp); - rewriter.replaceOpWithNewOp(lhsOp, newLhsValues); + newLhsValues[viewElementIndex] = requestedValue; + newRhsValues[viewElementIndex] = rhs.getResult(); - rewriter.setInsertionPoint(rhsOp); - rewriter.replaceOpWithNewOp(rhsOp, newRhsValues); + rewriter.setInsertionPoint(lhsOp); + rewriter.replaceOpWithNewOp(lhsOp, newLhsValues); - return mlir::success(); - } + rewriter.setInsertionPoint(rhsOp); + rewriter.replaceOpWithNewOp(rhsOp, newRhsValues); - llvm_unreachable("Access not found"); - return mlir::failure(); + return mlir::success(); } + + llvm_unreachable("Access not found"); + return mlir::failure(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // EquationFunctionOp -namespace mlir::bmodelica -{ - mlir::ParseResult EquationFunctionOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto buildFuncType = - [](mlir::Builder& builder, - llvm::ArrayRef argTypes, - llvm::ArrayRef results, - mlir::function_interface_impl::VariadicFlag, - std::string&) { - return builder.getFunctionType(argTypes, results); - }; +namespace mlir::bmodelica { +mlir::ParseResult EquationFunctionOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, 
results); }; - return mlir::function_interface_impl::parseFunctionOp( - parser, result, false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); - } + return mlir::function_interface_impl::parseFunctionOp( + parser, result, false, getFunctionTypeAttrName(result.name), + buildFuncType, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); +} - void EquationFunctionOp::print(OpAsmPrinter& printer) - { - mlir::function_interface_impl::printFunctionOp( - printer, *this, false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); - } +void EquationFunctionOp::print(OpAsmPrinter &printer) { + mlir::function_interface_impl::printFunctionOp( + printer, *this, false, getFunctionTypeAttrName(), getArgAttrsAttrName(), + getResAttrsAttrName()); +} - void EquationFunctionOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - llvm::StringRef name, - uint64_t numOfInductions, - llvm::ArrayRef attrs, - llvm::ArrayRef argAttrs) - { - state.addAttribute( - mlir::SymbolTable::getSymbolAttrName(), - builder.getStringAttr(name)); +void EquationFunctionOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + llvm::StringRef name, uint64_t numOfInductions, + llvm::ArrayRef attrs, + llvm::ArrayRef argAttrs) { + state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); - llvm::SmallVector argTypes( - numOfInductions * 2, builder.getIndexType()); + llvm::SmallVector argTypes(numOfInductions * 2, + builder.getIndexType()); - auto functionType = builder.getFunctionType(argTypes, std::nullopt); + auto functionType = builder.getFunctionType(argTypes, std::nullopt); - state.addAttribute( - getFunctionTypeAttrName(state.name), - mlir::TypeAttr::get(functionType)); + state.addAttribute(getFunctionTypeAttrName(state.name), + mlir::TypeAttr::get(functionType)); - state.attributes.append(attrs.begin(), attrs.end()); - state.addRegion(); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); - if (argAttrs.empty()) { - return; - } + if (argAttrs.empty()) { + return; + } - assert(functionType.getNumInputs() == argAttrs.size()); + assert(functionType.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( - builder, state, argAttrs, std::nullopt, - getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); - } + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, std::nullopt, getArgAttrsAttrName(state.name), + getResAttrsAttrName(state.name)); +} - mlir::Value EquationFunctionOp::getLowerBound(uint64_t induction) - { - return getArgument(induction * 2); - } +mlir::Value EquationFunctionOp::getLowerBound(uint64_t induction) { + return getArgument(induction * 2); +} - mlir::Value EquationFunctionOp::getUpperBound(uint64_t induction) - { - return getArgument(induction * 2 + 1); - } +mlir::Value EquationFunctionOp::getUpperBound(uint64_t induction) { + return getArgument(induction * 2 + 1); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // InitialOp -namespace -{ - struct EmptyInitialModelPattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - InitialOp op, mlir::PatternRewriter& rewriter) const override - { - if (op.getBody()->empty()) { - rewriter.eraseOp(op); - return mlir::success(); - } +namespace { +struct 
EmptyInitialModelPattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - return mlir::failure(); + mlir::LogicalResult + matchAndRewrite(InitialOp op, + mlir::PatternRewriter &rewriter) const override { + if (op.getBody()->empty()) { + rewriter.eraseOp(op); + return mlir::success(); } - }; -} -namespace mlir::bmodelica -{ - void InitialOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); + return mlir::failure(); } +}; +} // namespace - void InitialOp::collectSCCs(llvm::SmallVectorImpl& SCCs) - { - for (SCCOp scc : getOps()) { - SCCs.push_back(scc); - } +namespace mlir::bmodelica { +void InitialOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + patterns.add(context); +} + +void InitialOp::collectSCCs(llvm::SmallVectorImpl &SCCs) { + for (SCCOp scc : getOps()) { + SCCs.push_back(scc); } +} - void InitialOp::collectAlgorithms( - llvm::SmallVectorImpl& algorithms) - { - for (AlgorithmOp algorithm : getOps()) { - algorithms.push_back(algorithm); - } +void InitialOp::collectAlgorithms( + llvm::SmallVectorImpl &algorithms) { + for (AlgorithmOp algorithm : getOps()) { + algorithms.push_back(algorithm); } } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // DynamicOp -namespace -{ - struct EmptyMainModelPattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - DynamicOp op, mlir::PatternRewriter& rewriter) const override - { - if (op.getBody()->empty()) { - rewriter.eraseOp(op); - return mlir::success(); - } +namespace { +struct EmptyMainModelPattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - return mlir::failure(); + mlir::LogicalResult + matchAndRewrite(DynamicOp op, + mlir::PatternRewriter &rewriter) const override { + if (op.getBody()->empty()) { + rewriter.eraseOp(op); + return mlir::success(); } - }; -} -namespace mlir::bmodelica -{ - void DynamicOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); + return mlir::failure(); } +}; +} // namespace - void DynamicOp::collectSCCs(llvm::SmallVectorImpl& SCCs) - { - for (SCCOp scc : getOps()) { - SCCs.push_back(scc); - } +namespace mlir::bmodelica { +void DynamicOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + patterns.add(context); +} + +void DynamicOp::collectSCCs(llvm::SmallVectorImpl &SCCs) { + for (SCCOp scc : getOps()) { + SCCs.push_back(scc); } - - void DynamicOp::collectAlgorithms( - llvm::SmallVectorImpl& algorithms) - { - for (AlgorithmOp algorithm : getOps()) { - algorithms.push_back(algorithm); - } +} + +void DynamicOp::collectAlgorithms( + llvm::SmallVectorImpl &algorithms) { + for (AlgorithmOp algorithm : getOps()) { + algorithms.push_back(algorithm); } } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // StartEquationInstanceOp -namespace mlir::bmodelica -{ - void StartEquationInstanceOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - EquationTemplateOp equationTemplate) - { - build(builder, state, equationTemplate.getResult(), nullptr); - } - - mlir::LogicalResult StartEquationInstanceOp::verify() - { - auto indicesRank = - [&](std::optional ranges) -> size_t { - if (!ranges) { - 
return 0; - } +namespace mlir::bmodelica { +void StartEquationInstanceOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + EquationTemplateOp equationTemplate) { + build(builder, state, equationTemplate.getResult(), nullptr); +} - return ranges->getValue().rank(); - }; +mlir::LogicalResult StartEquationInstanceOp::verify() { + auto indicesRank = + [&](std::optional ranges) -> size_t { + if (!ranges) { + return 0; + } - // Check the indices for the explicit inductions. - size_t numOfExplicitInductions = getInductionVariables().size(); + return ranges->getValue().rank(); + }; - if (size_t explicitIndicesRank = indicesRank(getIndices()); - numOfExplicitInductions != explicitIndicesRank) { - return emitOpError() - << "Unexpected rank of iteration indices (expected " - << numOfExplicitInductions << ", got " << explicitIndicesRank << ")"; - } + // Check the indices for the explicit inductions. + size_t numOfExplicitInductions = getInductionVariables().size(); - return mlir::success(); + if (size_t explicitIndicesRank = indicesRank(getIndices()); + numOfExplicitInductions != explicitIndicesRank) { + return emitOpError() << "Unexpected rank of iteration indices (expected " + << numOfExplicitInductions << ", got " + << explicitIndicesRank << ")"; } - EquationTemplateOp StartEquationInstanceOp::getTemplate() - { - auto result = getBase().getDefiningOp(); - assert(result != nullptr); - return result; - } + return mlir::success(); +} - void StartEquationInstanceOp::printInline(llvm::raw_ostream& os) - { - getTemplate().printInline(os); - } +EquationTemplateOp StartEquationInstanceOp::getTemplate() { + auto result = getBase().getDefiningOp(); + assert(result != nullptr); + return result; +} - mlir::ValueRange StartEquationInstanceOp::getInductionVariables() - { - return getTemplate().getInductionVariables(); - } +void StartEquationInstanceOp::printInline(llvm::raw_ostream &os) { + getTemplate().printInline(os); +} - IndexSet StartEquationInstanceOp::getIterationSpace() - { - if (auto indices = getIndices()) { - return {indices->getValue()}; - } +mlir::ValueRange StartEquationInstanceOp::getInductionVariables() { + return getTemplate().getInductionVariables(); +} - return {}; +IndexSet StartEquationInstanceOp::getIterationSpace() { + if (auto indices = getIndices()) { + return {indices->getValue()}; } - std::optional StartEquationInstanceOp::getWriteAccess( - mlir::SymbolTableCollection& symbolTableCollection) - { - return getAccessAtPath(symbolTableCollection, - EquationPath(EquationPath::LEFT, 0)); - } + return {}; +} - mlir::LogicalResult StartEquationInstanceOp::getAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTable) - { - return getTemplate().getAccesses(result, symbolTable); - } +std::optional StartEquationInstanceOp::getWriteAccess( + mlir::SymbolTableCollection &symbolTableCollection) { + return getAccessAtPath(symbolTableCollection, + EquationPath(EquationPath::LEFT, 0)); +} - mlir::LogicalResult StartEquationInstanceOp::getReadAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - llvm::ArrayRef accesses) - { - return getReadAccesses( - result, symbolTableCollection, getIterationSpace(), accesses); - } +mlir::LogicalResult StartEquationInstanceOp::getAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTable) { + return getTemplate().getAccesses(result, symbolTable); +} - mlir::LogicalResult StartEquationInstanceOp::getReadAccesses( - llvm::SmallVectorImpl& result, - 
mlir::SymbolTableCollection& symbolTableCollection, - const IndexSet& equationIndices, - llvm::ArrayRef accesses) - { - std::optional writeAccess = - getWriteAccess(symbolTableCollection); +mlir::LogicalResult StartEquationInstanceOp::getReadAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + llvm::ArrayRef accesses) { + return getReadAccesses(result, symbolTableCollection, getIterationSpace(), + accesses); +} - if (!writeAccess) { - return mlir::failure(); - } +mlir::LogicalResult StartEquationInstanceOp::getReadAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + const IndexSet &equationIndices, llvm::ArrayRef accesses) { + std::optional writeAccess = + getWriteAccess(symbolTableCollection); - return getTemplate().getReadAccesses( - result, equationIndices, accesses, *writeAccess); + if (!writeAccess) { + return mlir::failure(); } - std::optional StartEquationInstanceOp::getAccessAtPath( - mlir::SymbolTableCollection& symbolTable, - const EquationPath& path) - { - return getTemplate().getAccessAtPath(symbolTable, path); - } + return getTemplate().getReadAccesses(result, equationIndices, accesses, + *writeAccess); } +std::optional StartEquationInstanceOp::getAccessAtPath( + mlir::SymbolTableCollection &symbolTable, const EquationPath &path) { + return getTemplate().getAccessAtPath(symbolTable, path); +} +} // namespace mlir::bmodelica + //===---------------------------------------------------------------------===// // EquationInstanceOp -namespace mlir::bmodelica -{ - void EquationInstanceOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - EquationTemplateOp equationTemplate) - { - build(builder, state, equationTemplate.getResult(), nullptr); - } - - mlir::LogicalResult EquationInstanceOp::verify() - { - auto indicesRank = - [&](std::optional ranges) -> size_t { - if (!ranges) { - return 0; - } - - return ranges->getValue().rank(); - }; - - // Check the indices for the explicit inductions. - size_t numOfExplicitInductions = getInductionVariables().size(); +namespace mlir::bmodelica { +void EquationInstanceOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + EquationTemplateOp equationTemplate) { + build(builder, state, equationTemplate.getResult(), nullptr); +} - if (size_t explicitIndicesRank = indicesRank(getIndices()); - numOfExplicitInductions != explicitIndicesRank) { - return emitOpError() - << "Unexpected rank of iteration indices (expected " - << numOfExplicitInductions << ", got " << explicitIndicesRank << ")"; +mlir::LogicalResult EquationInstanceOp::verify() { + auto indicesRank = + [&](std::optional ranges) -> size_t { + if (!ranges) { + return 0; } - return mlir::success(); - } + return ranges->getValue().rank(); + }; - EquationTemplateOp EquationInstanceOp::getTemplate() - { - auto result = getBase().getDefiningOp(); - assert(result != nullptr); - return result; - } + // Check the indices for the explicit inductions. 
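// By example: a template declaring inductions (i, j) must be instantiated
// with a two-dimensional range such as [0,4]x[0,9]; instantiating it with
// a one-dimensional [0,4] makes the check below fail with
//   "Unexpected rank of iteration indices (expected 2, got 1)"
// while an instance carrying no indices attribute is treated as rank 0 and
// is only valid for templates without explicit inductions.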
+ size_t numOfExplicitInductions = getInductionVariables().size(); - void EquationInstanceOp::printInline(llvm::raw_ostream& os) - { - getTemplate().printInline(os); + if (size_t explicitIndicesRank = indicesRank(getIndices()); + numOfExplicitInductions != explicitIndicesRank) { + return emitOpError() << "Unexpected rank of iteration indices (expected " + << numOfExplicitInductions << ", got " + << explicitIndicesRank << ")"; } - mlir::ValueRange EquationInstanceOp::getInductionVariables() - { - return getTemplate().getInductionVariables(); - } + return mlir::success(); +} - IndexSet EquationInstanceOp::getIterationSpace() - { - if (auto indices = getIndices()) { - return IndexSet(indices->getValue()); - } +EquationTemplateOp EquationInstanceOp::getTemplate() { + auto result = getBase().getDefiningOp(); + assert(result != nullptr); + return result; +} - return {}; - } +void EquationInstanceOp::printInline(llvm::raw_ostream &os) { + getTemplate().printInline(os); +} - mlir::LogicalResult EquationInstanceOp::getAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTable) - { - return getTemplate().getAccesses(result, symbolTable); - } +mlir::ValueRange EquationInstanceOp::getInductionVariables() { + return getTemplate().getInductionVariables(); +} - std::optional EquationInstanceOp::getAccessAtPath( - mlir::SymbolTableCollection& symbolTable, - const EquationPath& path) - { - return getTemplate().getAccessAtPath(symbolTable, path); +IndexSet EquationInstanceOp::getIterationSpace() { + if (auto indices = getIndices()) { + return IndexSet(indices->getValue()); } - mlir::LogicalResult EquationInstanceOp::cloneWithReplacedAccess( - mlir::RewriterBase& rewriter, - std::optional> equationIndices, - const VariableAccess& access, - EquationTemplateOp replacementEquation, - const VariableAccess& replacementAccess, - llvm::SmallVectorImpl& results) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); + return {}; +} - auto cleanTemplatesOnExit = llvm::make_scope_exit([&]() { - llvm::SmallVector templateOps; +mlir::LogicalResult +EquationInstanceOp::getAccesses(llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTable) { + return getTemplate().getAccesses(result, symbolTable); +} - for (EquationInstanceOp equationOp : results) { - templateOps.push_back(equationOp.getTemplate()); - } +std::optional +EquationInstanceOp::getAccessAtPath(mlir::SymbolTableCollection &symbolTable, + const EquationPath &path) { + return getTemplate().getAccessAtPath(symbolTable, path); +} - (void) cleanEquationTemplates(rewriter, templateOps); - }); +mlir::LogicalResult EquationInstanceOp::cloneWithReplacedAccess( + mlir::RewriterBase &rewriter, + std::optional> equationIndices, + const VariableAccess &access, EquationTemplateOp replacementEquation, + const VariableAccess &replacementAccess, + llvm::SmallVectorImpl &results) { + mlir::OpBuilder::InsertionGuard guard(rewriter); - llvm::SmallVector> templateResults; + auto cleanTemplatesOnExit = llvm::make_scope_exit([&]() { + llvm::SmallVector templateOps; - if (mlir::failed(getTemplate().cloneWithReplacedAccess( - rewriter, equationIndices, access, replacementEquation, - replacementAccess, templateResults))) { - return mlir::failure(); + for (EquationInstanceOp equationOp : results) { + templateOps.push_back(equationOp.getTemplate()); } - rewriter.setInsertionPointAfter(getOperation()); + (void)cleanEquationTemplates(rewriter, templateOps); + }); + + llvm::SmallVector> templateResults; + + if 
(mlir::failed(getTemplate().cloneWithReplacedAccess( + rewriter, equationIndices, access, replacementEquation, + replacementAccess, templateResults))) { + return mlir::failure(); + } + + rewriter.setInsertionPointAfter(getOperation()); - auto temporaryClonedOp = mlir::cast( - rewriter.clone(*getOperation())); + auto temporaryClonedOp = + mlir::cast(rewriter.clone(*getOperation())); - for (auto& [assignedIndices, equationTemplateOp] : templateResults) { - if (assignedIndices.empty()) { + for (auto &[assignedIndices, equationTemplateOp] : templateResults) { + if (assignedIndices.empty()) { + auto clonedOp = mlir::cast( + rewriter.clone(*temporaryClonedOp.getOperation())); + + clonedOp.setOperand(equationTemplateOp.getResult()); + clonedOp.removeIndicesAttr(); + results.push_back(clonedOp); + } else { + for (const MultidimensionalRange &assignedIndicesRange : llvm::make_range( + assignedIndices.rangesBegin(), assignedIndices.rangesEnd())) { auto clonedOp = mlir::cast( rewriter.clone(*temporaryClonedOp.getOperation())); clonedOp.setOperand(equationTemplateOp.getResult()); - clonedOp.removeIndicesAttr(); - results.push_back(clonedOp); - } else { - for (const MultidimensionalRange& assignedIndicesRange : - llvm::make_range(assignedIndices.rangesBegin(), - assignedIndices.rangesEnd())) { - auto clonedOp = mlir::cast( - rewriter.clone(*temporaryClonedOp.getOperation())); - - clonedOp.setOperand(equationTemplateOp.getResult()); - - if (auto explicitIndices = getIndices()) { - MultidimensionalRange explicitRange = - assignedIndicesRange.takeFirstDimensions( - explicitIndices->getValue().rank()); - - clonedOp.setIndicesAttr( - MultidimensionalRangeAttr::get( - rewriter.getContext(), std::move(explicitRange))); - } - results.push_back(clonedOp); + if (auto explicitIndices = getIndices()) { + MultidimensionalRange explicitRange = + assignedIndicesRange.takeFirstDimensions( + explicitIndices->getValue().rank()); + + clonedOp.setIndicesAttr(MultidimensionalRangeAttr::get( + rewriter.getContext(), std::move(explicitRange))); } + + results.push_back(clonedOp); } } - - rewriter.eraseOp(temporaryClonedOp); - return mlir::success(); } + + rewriter.eraseOp(temporaryClonedOp); + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // MatchedEquationInstanceOp -namespace mlir::bmodelica -{ - void MatchedEquationInstanceOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - EquationTemplateOp equationTemplate, - EquationPathAttr path) - { - build(builder, state, equationTemplate.getResult(), nullptr, path); - } - - mlir::LogicalResult MatchedEquationInstanceOp::verify() - { - auto indicesRank = - [&](std::optional ranges) -> size_t { - if (!ranges) { - return 0; - } - - return ranges->getValue().rank(); - }; - - // Check the indices for the explicit inductions. 
- size_t numOfExplicitInductions = getInductionVariables().size(); +namespace mlir::bmodelica { +void MatchedEquationInstanceOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + EquationTemplateOp equationTemplate, + EquationPathAttr path) { + build(builder, state, equationTemplate.getResult(), nullptr, path); +} - if (size_t explicitIndicesRank = indicesRank(getIndices()); - numOfExplicitInductions != explicitIndicesRank) { - return emitOpError() - << "Unexpected rank of iteration indices (expected " - << numOfExplicitInductions << ", got " << explicitIndicesRank << ")"; +mlir::LogicalResult MatchedEquationInstanceOp::verify() { + auto indicesRank = + [&](std::optional ranges) -> size_t { + if (!ranges) { + return 0; } - return mlir::success(); - } + return ranges->getValue().rank(); + }; - EquationTemplateOp MatchedEquationInstanceOp::getTemplate() - { - auto result = getBase().getDefiningOp(); - assert(result != nullptr); - return result; - } + // Check the indices for the explicit inductions. + size_t numOfExplicitInductions = getInductionVariables().size(); - void MatchedEquationInstanceOp::printInline(llvm::raw_ostream& os) - { - getTemplate().printInline(os); + if (size_t explicitIndicesRank = indicesRank(getIndices()); + numOfExplicitInductions != explicitIndicesRank) { + return emitOpError() << "Unexpected rank of iteration indices (expected " + << numOfExplicitInductions << ", got " + << explicitIndicesRank << ")"; } - mlir::ValueRange MatchedEquationInstanceOp::getInductionVariables() - { - return getTemplate().getInductionVariables(); - } + return mlir::success(); +} - IndexSet MatchedEquationInstanceOp::getIterationSpace() - { - if (auto indices = getIndices()) { - return IndexSet(indices->getValue()); - } +EquationTemplateOp MatchedEquationInstanceOp::getTemplate() { + auto result = getBase().getDefiningOp(); + assert(result != nullptr); + return result; +} - return {}; - } +void MatchedEquationInstanceOp::printInline(llvm::raw_ostream &os) { + getTemplate().printInline(os); +} - std::optional MatchedEquationInstanceOp::getMatchedAccess( - mlir::SymbolTableCollection& symbolTableCollection) - { - return getAccessAtPath(symbolTableCollection, getPath().getValue()); - } +mlir::ValueRange MatchedEquationInstanceOp::getInductionVariables() { + return getTemplate().getInductionVariables(); +} - mlir::LogicalResult MatchedEquationInstanceOp::getAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTable) - { - return getTemplate().getAccesses(result, symbolTable); +IndexSet MatchedEquationInstanceOp::getIterationSpace() { + if (auto indices = getIndices()) { + return IndexSet(indices->getValue()); } - mlir::LogicalResult MatchedEquationInstanceOp::getWriteAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - llvm::ArrayRef accesses) - { - return getWriteAccesses( - result, symbolTableCollection, getIterationSpace(), accesses); - } + return {}; +} - mlir::LogicalResult MatchedEquationInstanceOp::getWriteAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - const IndexSet& equationIndices, - llvm::ArrayRef accesses) - { - std::optional matchedAccess = - getMatchedAccess(symbolTableCollection); +std::optional MatchedEquationInstanceOp::getMatchedAccess( + mlir::SymbolTableCollection &symbolTableCollection) { + return getAccessAtPath(symbolTableCollection, getPath().getValue()); +} - if (!matchedAccess) { - return mlir::failure(); - } +mlir::LogicalResult 
MatchedEquationInstanceOp::getAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTable) { + return getTemplate().getAccesses(result, symbolTable); +} - return getTemplate().getWriteAccesses( - result, equationIndices, accesses, *matchedAccess); - } +mlir::LogicalResult MatchedEquationInstanceOp::getWriteAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + llvm::ArrayRef accesses) { + return getWriteAccesses(result, symbolTableCollection, getIterationSpace(), + accesses); +} - mlir::LogicalResult MatchedEquationInstanceOp::getReadAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - llvm::ArrayRef accesses) - { - return getReadAccesses( - result, symbolTableCollection, getIterationSpace(), accesses); +mlir::LogicalResult MatchedEquationInstanceOp::getWriteAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + const IndexSet &equationIndices, llvm::ArrayRef accesses) { + std::optional matchedAccess = + getMatchedAccess(symbolTableCollection); + + if (!matchedAccess) { + return mlir::failure(); } - mlir::LogicalResult MatchedEquationInstanceOp::getReadAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - const IndexSet& equationIndices, - llvm::ArrayRef accesses) - { - std::optional matchedAccess = - getMatchedAccess(symbolTableCollection); + return getTemplate().getWriteAccesses(result, equationIndices, accesses, + *matchedAccess); +} - if (!matchedAccess) { - return mlir::failure(); - } +mlir::LogicalResult MatchedEquationInstanceOp::getReadAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + llvm::ArrayRef accesses) { + return getReadAccesses(result, symbolTableCollection, getIterationSpace(), + accesses); +} - return getTemplate().getReadAccesses( - result, equationIndices, accesses, *matchedAccess); - } +mlir::LogicalResult MatchedEquationInstanceOp::getReadAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + const IndexSet &equationIndices, llvm::ArrayRef accesses) { + std::optional matchedAccess = + getMatchedAccess(symbolTableCollection); - std::optional MatchedEquationInstanceOp::getAccessAtPath( - mlir::SymbolTableCollection& symbolTable, - const EquationPath& path) - { - return getTemplate().getAccessAtPath(symbolTable, path); + if (!matchedAccess) { + return mlir::failure(); } - mlir::LogicalResult MatchedEquationInstanceOp::explicitate( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection) - { - std::optional indices = std::nullopt; + return getTemplate().getReadAccesses(result, equationIndices, accesses, + *matchedAccess); +} - if (auto indicesAttr = getIndices()) { - indices = indicesAttr->getValue(); - } +std::optional MatchedEquationInstanceOp::getAccessAtPath( + mlir::SymbolTableCollection &symbolTable, const EquationPath &path) { + return getTemplate().getAccessAtPath(symbolTable, path); +} - if (mlir::failed(getTemplate().explicitate( - rewriter, symbolTableCollection, indices, getPath().getValue()))) { - return mlir::failure(); - } +mlir::LogicalResult MatchedEquationInstanceOp::explicitate( + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection) { + std::optional indices = std::nullopt; - setPathAttr(EquationPathAttr::get( - getContext(), EquationPath(EquationPath::LEFT, 0))); + if (auto indicesAttr = getIndices()) 
{ + indices = indicesAttr->getValue(); + } - return mlir::success(); + if (mlir::failed(getTemplate().explicitate(rewriter, symbolTableCollection, + indices, getPath().getValue()))) { + return mlir::failure(); } - MatchedEquationInstanceOp MatchedEquationInstanceOp::cloneAndExplicitate( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection) - { - std::optional indices = std::nullopt; + setPathAttr( + EquationPathAttr::get(getContext(), EquationPath(EquationPath::LEFT, 0))); - if (auto indicesAttr = getIndices()) { - indices = indicesAttr->getValue(); - } + return mlir::success(); +} - EquationTemplateOp clonedTemplate = getTemplate().cloneAndExplicitate( - rewriter, symbolTableCollection, indices, getPath().getValue()); +MatchedEquationInstanceOp MatchedEquationInstanceOp::cloneAndExplicitate( + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection) { + std::optional indices = std::nullopt; - if (!clonedTemplate) { - return nullptr; - } + if (auto indicesAttr = getIndices()) { + indices = indicesAttr->getValue(); + } - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(getOperation()); + EquationTemplateOp clonedTemplate = getTemplate().cloneAndExplicitate( + rewriter, symbolTableCollection, indices, getPath().getValue()); - auto result = rewriter.create( - getLoc(), clonedTemplate, - EquationPathAttr::get( - getContext(), - EquationPath(EquationPath::LEFT, 0))); + if (!clonedTemplate) { + return nullptr; + } - if (indices) { - result.setIndicesAttr(MultidimensionalRangeAttr::get( - getContext(), *indices)); - } + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(getOperation()); - return result; + auto result = rewriter.create( + getLoc(), clonedTemplate, + EquationPathAttr::get(getContext(), EquationPath(EquationPath::LEFT, 0))); + + if (indices) { + result.setIndicesAttr( + MultidimensionalRangeAttr::get(getContext(), *indices)); } - mlir::LogicalResult MatchedEquationInstanceOp::cloneWithReplacedAccess( - mlir::RewriterBase& rewriter, - std::optional> equationIndices, - const VariableAccess& access, - EquationTemplateOp replacementEquation, - const VariableAccess& replacementAccess, - llvm::SmallVectorImpl& results) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); + return result; +} - auto cleanTemplatesOnExit = llvm::make_scope_exit([&]() { - llvm::SmallVector templateOps; +mlir::LogicalResult MatchedEquationInstanceOp::cloneWithReplacedAccess( + mlir::RewriterBase &rewriter, + std::optional> equationIndices, + const VariableAccess &access, EquationTemplateOp replacementEquation, + const VariableAccess &replacementAccess, + llvm::SmallVectorImpl &results) { + mlir::OpBuilder::InsertionGuard guard(rewriter); - for (MatchedEquationInstanceOp equationOp : results) { - templateOps.push_back(equationOp.getTemplate()); - } + auto cleanTemplatesOnExit = llvm::make_scope_exit([&]() { + llvm::SmallVector templateOps; - (void) cleanEquationTemplates(rewriter, templateOps); - }); + for (MatchedEquationInstanceOp equationOp : results) { + templateOps.push_back(equationOp.getTemplate()); + } - llvm::SmallVector> templateResults; + (void)cleanEquationTemplates(rewriter, templateOps); + }); - if (mlir::failed(getTemplate().cloneWithReplacedAccess( - rewriter, equationIndices, access, replacementEquation, - replacementAccess, templateResults))) { - return mlir::failure(); - } + llvm::SmallVector> templateResults; + + if 
(mlir::failed(getTemplate().cloneWithReplacedAccess( + rewriter, equationIndices, access, replacementEquation, + replacementAccess, templateResults))) { + return mlir::failure(); + } - rewriter.setInsertionPointAfter(getOperation()); + rewriter.setInsertionPointAfter(getOperation()); - auto temporaryClonedOp = mlir::cast( - rewriter.clone(*getOperation())); + auto temporaryClonedOp = + mlir::cast(rewriter.clone(*getOperation())); - for (auto& [assignedIndices, equationTemplateOp] : templateResults) { - if (assignedIndices.empty()) { + for (auto &[assignedIndices, equationTemplateOp] : templateResults) { + if (assignedIndices.empty()) { + auto clonedOp = mlir::cast( + rewriter.clone(*temporaryClonedOp.getOperation())); + + clonedOp.setOperand(equationTemplateOp.getResult()); + clonedOp.removeIndicesAttr(); + results.push_back(clonedOp); + } else { + for (const MultidimensionalRange &assignedIndicesRange : llvm::make_range( + assignedIndices.rangesBegin(), assignedIndices.rangesEnd())) { auto clonedOp = mlir::cast( rewriter.clone(*temporaryClonedOp.getOperation())); - clonedOp.setOperand(equationTemplateOp.getResult()); - clonedOp.removeIndicesAttr(); - results.push_back(clonedOp); - } else { - for (const MultidimensionalRange& assignedIndicesRange : - llvm::make_range(assignedIndices.rangesBegin(), - assignedIndices.rangesEnd())) { - auto clonedOp = mlir::cast( - rewriter.clone(*temporaryClonedOp.getOperation())); - - clonedOp.setOperand(equationTemplateOp.getResult()); - - if (auto explicitIndices = getIndices()) { - MultidimensionalRange explicitRange = - assignedIndicesRange.takeFirstDimensions( - explicitIndices->getValue().rank()); - - clonedOp.setIndicesAttr( - MultidimensionalRangeAttr::get( - rewriter.getContext(), std::move(explicitRange))); - } + clonedOp.setOperand(equationTemplateOp.getResult()); + + if (auto explicitIndices = getIndices()) { + MultidimensionalRange explicitRange = + assignedIndicesRange.takeFirstDimensions( + explicitIndices->getValue().rank()); - results.push_back(clonedOp); + clonedOp.setIndicesAttr(MultidimensionalRangeAttr::get( + rewriter.getContext(), std::move(explicitRange))); } + + results.push_back(clonedOp); } } - - rewriter.eraseOp(temporaryClonedOp); - return mlir::success(); } + + rewriter.eraseOp(temporaryClonedOp); + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SCCGroupOp -namespace mlir::bmodelica -{ - mlir::RegionKind SCCGroupOp::getRegionKind(unsigned index) - { - return mlir::RegionKind::Graph; - } +namespace mlir::bmodelica { +mlir::RegionKind SCCGroupOp::getRegionKind(unsigned index) { + return mlir::RegionKind::Graph; +} - void SCCGroupOp::collectSCCs(llvm::SmallVectorImpl& SCCs) - { - for (SCCOp scc : getOps()) { - SCCs.push_back(scc); - } +void SCCGroupOp::collectSCCs(llvm::SmallVectorImpl &SCCs) { + for (SCCOp scc : getOps()) { + SCCs.push_back(scc); } } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // SCCOp -namespace -{ - struct EmptySCCPattern - : public mlir::OpRewritePattern - { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - SCCOp op, mlir::PatternRewriter& rewriter) const override - { - if (op.getBody()->empty()) { - rewriter.eraseOp(op); - return mlir::success(); - } +namespace { +struct EmptySCCPattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - return mlir::failure(); + 
mlir::LogicalResult
+  matchAndRewrite(SCCOp op, mlir::PatternRewriter &rewriter) const override {
+    if (op.getBody()->empty()) {
+      rewriter.eraseOp(op);
+      return mlir::success();
     }
-  };
-}

-namespace mlir::bmodelica
-{
-  void SCCOp::getCanonicalizationPatterns(
-      mlir::RewritePatternSet& patterns, mlir::MLIRContext* context)
-  {
-    patterns.add<EmptySCCPattern>(context);
+    return mlir::failure();
   }
+};
+} // namespace

-  mlir::RegionKind SCCOp::getRegionKind(unsigned index)
-  {
-    return mlir::RegionKind::Graph;
-  }
+namespace mlir::bmodelica {
+void SCCOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
+                                        mlir::MLIRContext *context) {
+  patterns.add<EmptySCCPattern>(context);
+}
+
+mlir::RegionKind SCCOp::getRegionKind(unsigned index) {
+  return mlir::RegionKind::Graph;
 }
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // ScheduledEquationInstanceOp

-namespace mlir::bmodelica
-{
-  void ScheduledEquationInstanceOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      EquationTemplateOp equationTemplate,
-      EquationPathAttr path,
-      mlir::ArrayAttr iterationDirections)
-  {
-    build(builder, state, equationTemplate.getResult(), nullptr, path,
-          iterationDirections);
-  }
-
-  mlir::LogicalResult ScheduledEquationInstanceOp::verify()
-  {
-    auto indicesRankFn =
-        [&](std::optional<MultidimensionalRangeAttr> ranges) -> size_t {
-      if (!ranges) {
-        return 0;
-      }
-
-      return ranges->getValue().rank();
-    };
-
-    // Check the indices for the explicit inductions.
-    size_t numOfInductions = getInductionVariables().size();
+namespace mlir::bmodelica {
+void ScheduledEquationInstanceOp::build(mlir::OpBuilder &builder,
+                                        mlir::OperationState &state,
+                                        EquationTemplateOp equationTemplate,
+                                        EquationPathAttr path,
+                                        mlir::ArrayAttr iterationDirections) {
+  build(builder, state, equationTemplate.getResult(), nullptr, path,
+        iterationDirections);
+}

-    if (size_t indicesRank = indicesRankFn(getIndices());
-        numOfInductions != indicesRank) {
-      return emitOpError()
-          << "Unexpected rank of iteration indices (expected "
-          << numOfInductions << ", got " << indicesRank << ")";
+mlir::LogicalResult ScheduledEquationInstanceOp::verify() {
+  auto indicesRankFn =
+      [&](std::optional<MultidimensionalRangeAttr> ranges) -> size_t {
+    if (!ranges) {
+      return 0;
     }

-    // Check the iteration directions.
-    if (size_t numOfIterationDirections = getIterationDirections().size();
-        numOfInductions != numOfIterationDirections) {
-      return emitOpError()
-          << "Unexpected number of iteration directions (expected "
-          << numOfInductions << ", got " << numOfIterationDirections
-          << ")";
-    }
+    return ranges->getValue().rank();
+  };

-    return mlir::success();
-  }
+  // Check the indices for the explicit inductions.
+  size_t numOfInductions = getInductionVariables().size();

-  EquationTemplateOp ScheduledEquationInstanceOp::getTemplate()
-  {
-    auto result = getBase().getDefiningOp<EquationTemplateOp>();
-    assert(result != nullptr);
-    return result;
+  if (size_t indicesRank = indicesRankFn(getIndices());
+      numOfInductions != indicesRank) {
+    return emitOpError() << "Unexpected rank of iteration indices (expected "
+                         << numOfInductions << ", got " << indicesRank << ")";
   }

-  void ScheduledEquationInstanceOp::printInline(llvm::raw_ostream& os)
-  {
-    getTemplate().printInline(os);
+  // Check the iteration directions.
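As an aside on the invariant enforced by the two checks in this verifier: the
explicit indices, the induction variables, and the iteration directions must
all agree on rank. A minimal standalone sketch of the same consistency rule
(function name and scalar types are hypothetical, not part of the patch):

    #include <cstddef>
    #include <optional>

    // True when 'numInductions' induction variables, an optional indices
    // attribute of rank 'indicesRank', and 'numDirections' iteration
    // directions are mutually consistent. A missing indices attribute
    // behaves as rank 0, mirroring indicesRankFn above.
    bool hasConsistentRanks(std::size_t numInductions,
                            std::optional<std::size_t> indicesRank,
                            std::size_t numDirections) {
      return numInductions == indicesRank.value_or(0) &&
             numInductions == numDirections;
    }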
+ if (size_t numOfIterationDirections = getIterationDirections().size(); + numOfInductions != numOfIterationDirections) { + return emitOpError() + << "Unexpected number of iteration directions (expected " + << numOfInductions << ", got " << numOfIterationDirections << ")"; } - mlir::ValueRange ScheduledEquationInstanceOp::getInductionVariables() - { - return getTemplate().getInductionVariables(); - } + return mlir::success(); +} - mlir::LogicalResult ScheduledEquationInstanceOp::getAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTable) - { - return getTemplate().getAccesses(result, symbolTable); - } +EquationTemplateOp ScheduledEquationInstanceOp::getTemplate() { + auto result = getBase().getDefiningOp(); + assert(result != nullptr); + return result; +} - std::optional ScheduledEquationInstanceOp::getAccessAtPath( - mlir::SymbolTableCollection& symbolTable, - const EquationPath& path) - { - return getTemplate().getAccessAtPath(symbolTable, path); - } +void ScheduledEquationInstanceOp::printInline(llvm::raw_ostream &os) { + getTemplate().printInline(os); +} - IndexSet ScheduledEquationInstanceOp::getIterationSpace() - { - if (auto indices = getIndices()) { - return IndexSet(indices->getValue()); - } +mlir::ValueRange ScheduledEquationInstanceOp::getInductionVariables() { + return getTemplate().getInductionVariables(); +} - return {}; - } +mlir::LogicalResult ScheduledEquationInstanceOp::getAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTable) { + return getTemplate().getAccesses(result, symbolTable); +} - std::optional ScheduledEquationInstanceOp::getMatchedAccess( - mlir::SymbolTableCollection& symbolTableCollection) - { - return getAccessAtPath(symbolTableCollection, getPath().getValue()); - } +std::optional ScheduledEquationInstanceOp::getAccessAtPath( + mlir::SymbolTableCollection &symbolTable, const EquationPath &path) { + return getTemplate().getAccessAtPath(symbolTable, path); +} - mlir::LogicalResult ScheduledEquationInstanceOp::getWriteAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - llvm::ArrayRef accesses) - { - return getWriteAccesses( - result, symbolTableCollection, getIterationSpace(), accesses); +IndexSet ScheduledEquationInstanceOp::getIterationSpace() { + if (auto indices = getIndices()) { + return IndexSet(indices->getValue()); } - mlir::LogicalResult ScheduledEquationInstanceOp::getWriteAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - const IndexSet& equationIndices, - llvm::ArrayRef accesses) - { - std::optional matchedAccess = - getMatchedAccess(symbolTableCollection); + return {}; +} - if (!matchedAccess) { - return mlir::failure(); - } +std::optional ScheduledEquationInstanceOp::getMatchedAccess( + mlir::SymbolTableCollection &symbolTableCollection) { + return getAccessAtPath(symbolTableCollection, getPath().getValue()); +} - return getTemplate().getWriteAccesses( - result, equationIndices, accesses, *matchedAccess); - } +mlir::LogicalResult ScheduledEquationInstanceOp::getWriteAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + llvm::ArrayRef accesses) { + return getWriteAccesses(result, symbolTableCollection, getIterationSpace(), + accesses); +} + +mlir::LogicalResult ScheduledEquationInstanceOp::getWriteAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + const IndexSet &equationIndices, llvm::ArrayRef 
accesses) { + std::optional matchedAccess = + getMatchedAccess(symbolTableCollection); - mlir::LogicalResult ScheduledEquationInstanceOp::getReadAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - llvm::ArrayRef accesses) - { - return getReadAccesses( - result, symbolTableCollection, getIterationSpace(), accesses); + if (!matchedAccess) { + return mlir::failure(); } - mlir::LogicalResult ScheduledEquationInstanceOp::getReadAccesses( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - const IndexSet& equationIndices, - llvm::ArrayRef accesses) - { - std::optional matchedAccess = - getMatchedAccess(symbolTableCollection); + return getTemplate().getWriteAccesses(result, equationIndices, accesses, + *matchedAccess); +} - if (!matchedAccess) { - return mlir::failure(); - } +mlir::LogicalResult ScheduledEquationInstanceOp::getReadAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + llvm::ArrayRef accesses) { + return getReadAccesses(result, symbolTableCollection, getIterationSpace(), + accesses); +} - return getTemplate().getReadAccesses( - result, equationIndices, accesses, *matchedAccess); - } +mlir::LogicalResult ScheduledEquationInstanceOp::getReadAccesses( + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, + const IndexSet &equationIndices, llvm::ArrayRef accesses) { + std::optional matchedAccess = + getMatchedAccess(symbolTableCollection); - mlir::LogicalResult ScheduledEquationInstanceOp::explicitate( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection) - { - std::optional indices = std::nullopt; + if (!matchedAccess) { + return mlir::failure(); + } - if (auto indicesAttr = getIndices()) { - indices = indicesAttr->getValue(); - } + return getTemplate().getReadAccesses(result, equationIndices, accesses, + *matchedAccess); +} - if (mlir::failed(getTemplate().explicitate( - rewriter, symbolTableCollection, indices, getPath().getValue()))) { - return mlir::failure(); - } +mlir::LogicalResult ScheduledEquationInstanceOp::explicitate( + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection) { + std::optional indices = std::nullopt; - setPathAttr(EquationPathAttr::get( - getContext(), EquationPath(EquationPath::LEFT, 0))); + if (auto indicesAttr = getIndices()) { + indices = indicesAttr->getValue(); + } - return mlir::success(); + if (mlir::failed(getTemplate().explicitate(rewriter, symbolTableCollection, + indices, getPath().getValue()))) { + return mlir::failure(); } - ScheduledEquationInstanceOp ScheduledEquationInstanceOp::cloneAndExplicitate( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection) - { - std::optional indices = std::nullopt; + setPathAttr( + EquationPathAttr::get(getContext(), EquationPath(EquationPath::LEFT, 0))); - if (auto indicesAttr = getIndices()) { - indices = indicesAttr->getValue(); - } + return mlir::success(); +} - EquationTemplateOp clonedTemplate = getTemplate().cloneAndExplicitate( - rewriter, symbolTableCollection, indices, getPath().getValue()); +ScheduledEquationInstanceOp ScheduledEquationInstanceOp::cloneAndExplicitate( + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection) { + std::optional indices = std::nullopt; - if (!clonedTemplate) { - return nullptr; - } + if (auto indicesAttr = getIndices()) { + indices = indicesAttr->getValue(); + } + + EquationTemplateOp 
clonedTemplate = getTemplate().cloneAndExplicitate( + rewriter, symbolTableCollection, indices, getPath().getValue()); - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(getOperation()); + if (!clonedTemplate) { + return nullptr; + } - auto result = rewriter.create( - getLoc(), clonedTemplate, - EquationPathAttr::get( - getContext(), - EquationPath(EquationPath::LEFT, 0)), - getIterationDirections()); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(getOperation()); - if (indices) { - result.setIndicesAttr(MultidimensionalRangeAttr::get( - getContext(), *indices)); - } + auto result = rewriter.create( + getLoc(), clonedTemplate, + EquationPathAttr::get(getContext(), EquationPath(EquationPath::LEFT, 0)), + getIterationDirections()); - return result; + if (indices) { + result.setIndicesAttr( + MultidimensionalRangeAttr::get(getContext(), *indices)); } - mlir::LogicalResult ScheduledEquationInstanceOp::cloneWithReplacedAccess( - mlir::RewriterBase& rewriter, - std::optional> equationIndices, - const VariableAccess& access, - EquationTemplateOp replacementEquation, - const VariableAccess& replacementAccess, - llvm::SmallVectorImpl& results) - { - mlir::OpBuilder::InsertionGuard guard(rewriter); + return result; +} - auto cleanTemplatesOnExit = llvm::make_scope_exit([&]() { - llvm::SmallVector templateOps; +mlir::LogicalResult ScheduledEquationInstanceOp::cloneWithReplacedAccess( + mlir::RewriterBase &rewriter, + std::optional> equationIndices, + const VariableAccess &access, EquationTemplateOp replacementEquation, + const VariableAccess &replacementAccess, + llvm::SmallVectorImpl &results) { + mlir::OpBuilder::InsertionGuard guard(rewriter); - for (ScheduledEquationInstanceOp equationOp : results) { - templateOps.push_back(equationOp.getTemplate()); - } + auto cleanTemplatesOnExit = llvm::make_scope_exit([&]() { + llvm::SmallVector templateOps; - (void) cleanEquationTemplates(rewriter, templateOps); - }); + for (ScheduledEquationInstanceOp equationOp : results) { + templateOps.push_back(equationOp.getTemplate()); + } - llvm::SmallVector> templateResults; + (void)cleanEquationTemplates(rewriter, templateOps); + }); - if (mlir::failed(getTemplate().cloneWithReplacedAccess( - rewriter, equationIndices, access, replacementEquation, - replacementAccess, templateResults))) { - return mlir::failure(); - } + llvm::SmallVector> templateResults; + + if (mlir::failed(getTemplate().cloneWithReplacedAccess( + rewriter, equationIndices, access, replacementEquation, + replacementAccess, templateResults))) { + return mlir::failure(); + } + + rewriter.setInsertionPointAfter(getOperation()); - rewriter.setInsertionPointAfter(getOperation()); + auto temporaryClonedOp = + mlir::cast(rewriter.clone(*getOperation())); - auto temporaryClonedOp = mlir::cast( - rewriter.clone(*getOperation())); + for (auto &[assignedIndices, equationTemplateOp] : templateResults) { + if (assignedIndices.empty()) { + auto clonedOp = mlir::cast( + rewriter.clone(*temporaryClonedOp.getOperation())); - for (auto& [assignedIndices, equationTemplateOp] : templateResults) { - if (assignedIndices.empty()) { + clonedOp.setOperand(equationTemplateOp.getResult()); + clonedOp.removeIndicesAttr(); + results.push_back(clonedOp); + } else { + for (const MultidimensionalRange &assignedIndicesRange : llvm::make_range( + assignedIndices.rangesBegin(), assignedIndices.rangesEnd())) { auto clonedOp = mlir::cast( rewriter.clone(*temporaryClonedOp.getOperation())); 
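For orientation, the surrounding loop splits the assigned IndexSet into its
contiguous MultidimensionalRanges and emits one cloned instance per range; an
IndexSet such as {[0,3]x[0,1], [7,9]x[0,1]} therefore produces two clones. A
rough sketch of that traversal idiom against the same iterator API (the helper
name is hypothetical):

    // Visit every contiguous range of 'indices' exactly once.
    void forEachRange(
        const IndexSet &indices,
        llvm::function_ref<void(const MultidimensionalRange &)> callback) {
      for (const MultidimensionalRange &range :
           llvm::make_range(indices.rangesBegin(), indices.rangesEnd())) {
        callback(range);
      }
    }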
clonedOp.setOperand(equationTemplateOp.getResult());
-          clonedOp.removeIndicesAttr();
-          results.push_back(clonedOp);
-        } else {
-          for (const MultidimensionalRange& assignedIndicesRange :
-              llvm::make_range(assignedIndices.rangesBegin(),
-                               assignedIndices.rangesEnd())) {
-            auto clonedOp = mlir::cast<ScheduledEquationInstanceOp>(
-                rewriter.clone(*temporaryClonedOp.getOperation()));
-
-            clonedOp.setOperand(equationTemplateOp.getResult());
-
-            if (auto explicitIndices = getIndices()) {
-              MultidimensionalRange explicitRange =
-                  assignedIndicesRange.takeFirstDimensions(
-                      explicitIndices->getValue().rank());
-
-              clonedOp.setIndicesAttr(
-                  MultidimensionalRangeAttr::get(
-                      rewriter.getContext(), std::move(explicitRange)));
-            }
-            results.push_back(clonedOp);
+        if (auto explicitIndices = getIndices()) {
+          MultidimensionalRange explicitRange =
+              assignedIndicesRange.takeFirstDimensions(
+                  explicitIndices->getValue().rank());
+
+          clonedOp.setIndicesAttr(MultidimensionalRangeAttr::get(
+              rewriter.getContext(), std::move(explicitRange)));
         }
+
+        results.push_back(clonedOp);
       }
     }
-
-    rewriter.eraseOp(temporaryClonedOp);
-    return mlir::success();
   }
+
+  rewriter.eraseOp(temporaryClonedOp);
+  return mlir::success();
 }
+} // namespace mlir::bmodelica

 //===---------------------------------------------------------------------===//
 // EquationSideOp

-namespace
-{
-  struct EquationSideTypePropagationPattern
-      : public mlir::OpRewritePattern<EquationSideOp>
-  {
-    using mlir::OpRewritePattern<EquationSideOp>::OpRewritePattern;
-
-    mlir::LogicalResult matchAndRewrite(
-        EquationSideOp op, mlir::PatternRewriter& rewriter) const override
-    {
-      bool different = false;
-      llvm::SmallVector<mlir::Type> newTypes;
-
-      for (size_t i = 0, e = op.getValues().size(); i < e; ++i) {
-        mlir::Type existingType = op.getResult().getType().getType(i);
-        mlir::Type expectedType = op.getValues()[i].getType();
-
-        if (existingType != expectedType) {
-          different = true;
-        }
+namespace {
+struct EquationSideTypePropagationPattern
+    : public mlir::OpRewritePattern<EquationSideOp> {
+  using mlir::OpRewritePattern<EquationSideOp>::OpRewritePattern;

+  mlir::LogicalResult
+  matchAndRewrite(EquationSideOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    bool different = false;
+    llvm::SmallVector<mlir::Type> newTypes;
+
+    for (size_t i = 0, e = op.getValues().size(); i < e; ++i) {
+      mlir::Type existingType = op.getResult().getType().getType(i);
+      mlir::Type expectedType = op.getValues()[i].getType();
+
+      if (existingType != expectedType) {
+        different = true;
       }

-      newTypes.push_back(expectedType);
-    }

-    if (!different) {
+      newTypes.push_back(expectedType);
     }

-      rewriter.replaceOpWithNewOp<EquationSideOp>(op, op.getValues());
-      return mlir::failure();
+    if (!different) {
+      return mlir::failure();
     }
-  };
-}

-namespace mlir::bmodelica
-{
-  mlir::ParseResult EquationSideOp::parse(
-      mlir::OpAsmParser& parser, mlir::OperationState& result)
-  {
-    llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> values;
-    mlir::Type resultType;
-    auto loc = parser.getCurrentLocation();
-
-    if (parser.parseOperandList(values) ||
-        parser.parseColon() ||
-        parser.parseType(resultType)) {
+    // The op has already been replaced at this point, so the pattern must
+    // report success to satisfy the pattern rewriter contract.
+    rewriter.replaceOpWithNewOp<EquationSideOp>(op, op.getValues());
+    return mlir::success();
+  }
+};
+} // namespace

+namespace mlir::bmodelica {
+mlir::ParseResult EquationSideOp::parse(mlir::OpAsmParser &parser,
+                                        mlir::OperationState &result) {
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand>
values; + mlir::Type resultType; + auto loc = parser.getCurrentLocation(); - result.addTypes(resultType); - return mlir::success(); + if (parser.parseOperandList(values) || parser.parseColon() || + parser.parseType(resultType)) { + return mlir::failure(); } - void EquationSideOp::print(mlir::OpAsmPrinter& printer) - { - printer.printOptionalAttrDict(getOperation()->getAttrs()); - printer << " " << getValues() << " : " << getResult().getType(); - } + assert(resultType.isa()); + auto tupleType = resultType.cast(); - void EquationSideOp::getCanonicalizationPatterns( - mlir::RewritePatternSet& patterns, mlir::MLIRContext* context) - { - patterns.add(context); + llvm::SmallVector types(tupleType.begin(), tupleType.end()); + assert(types.size() == values.size()); + + if (parser.resolveOperands(values, types, loc, result.operands)) { + return mlir::failure(); } + + result.addTypes(resultType); + return mlir::success(); +} + +void EquationSideOp::print(mlir::OpAsmPrinter &printer) { + printer.printOptionalAttrDict(getOperation()->getAttrs()); + printer << " " << getValues() << " : " << getResult().getType(); } +void EquationSideOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add(context); +} +} // namespace mlir::bmodelica + //===---------------------------------------------------------------------===// // FunctionOp -namespace mlir::bmodelica -{ - void FunctionOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - llvm::StringRef name) - { - build(builder, state, name, nullptr); - } +namespace mlir::bmodelica { +void FunctionOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name) { + build(builder, state, name, nullptr); +} - llvm::SmallVector FunctionOp::getArgumentTypes() - { - llvm::SmallVector types; +llvm::SmallVector FunctionOp::getArgumentTypes() { + llvm::SmallVector types; - for (VariableOp variableOp : getVariables()) { - VariableType variableType = variableOp.getVariableType(); + for (VariableOp variableOp : getVariables()) { + VariableType variableType = variableOp.getVariableType(); - if (variableType.isInput()) { - types.push_back(variableType.unwrap()); - } + if (variableType.isInput()) { + types.push_back(variableType.unwrap()); } - - return types; } - llvm::SmallVector FunctionOp::getResultTypes() - { - llvm::SmallVector types; + return types; +} + +llvm::SmallVector FunctionOp::getResultTypes() { + llvm::SmallVector types; - for (VariableOp variableOp : getVariables()) { - VariableType variableType = variableOp.getVariableType(); + for (VariableOp variableOp : getVariables()) { + VariableType variableType = variableOp.getVariableType(); - if (variableType.isOutput()) { - types.push_back(variableType.unwrap()); - } + if (variableType.isOutput()) { + types.push_back(variableType.unwrap()); } - - return types; } - mlir::FunctionType FunctionOp::getFunctionType() - { - llvm::SmallVector argTypes; - llvm::SmallVector resultTypes; - - for (VariableOp variableOp : getVariables()) { - VariableType variableType = variableOp.getVariableType(); + return types; +} - if (variableType.isInput()) { - argTypes.push_back(variableType.unwrap()); - } else if (variableType.isOutput()) { - resultTypes.push_back(variableType.unwrap()); - } - } +mlir::FunctionType FunctionOp::getFunctionType() { + llvm::SmallVector argTypes; + llvm::SmallVector resultTypes; - return mlir::FunctionType::get(getContext(), argTypes, resultTypes); - } + for (VariableOp variableOp : getVariables()) { + 
VariableType variableType = variableOp.getVariableType(); - bool FunctionOp::shouldBeInlined() - { - if (!getOperation()->hasAttrOfType("inline")) { - return false; + if (variableType.isInput()) { + argTypes.push_back(variableType.unwrap()); + } else if (variableType.isOutput()) { + resultTypes.push_back(variableType.unwrap()); } + } - auto inlineAttribute = - getOperation()->getAttrOfType("inline"); + return mlir::FunctionType::get(getContext(), argTypes, resultTypes); +} - return inlineAttribute.getValue(); +bool FunctionOp::shouldBeInlined() { + if (!getOperation()->hasAttrOfType("inline")) { + return false; } + + auto inlineAttribute = + getOperation()->getAttrOfType("inline"); + + return inlineAttribute.getValue(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // DerFunctionOp -namespace mlir::bmodelica -{ - mlir::ParseResult DerFunctionOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - mlir::StringAttr nameAttr; - - if (parser.parseSymbolName( - nameAttr, - mlir::SymbolTable::getSymbolAttrName(), - result.attributes) || - parser.parseOptionalAttrDict(result.attributes)) { - return mlir::failure(); - } +namespace mlir::bmodelica { +mlir::ParseResult DerFunctionOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::StringAttr nameAttr; - return mlir::success(); + if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), + result.attributes) || + parser.parseOptionalAttrDict(result.attributes)) { + return mlir::failure(); } - void DerFunctionOp::print(mlir::OpAsmPrinter& printer) - { - printer << " "; - printer.printSymbolName(getSymName()); + return mlir::success(); +} + +void DerFunctionOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getSymName()); - llvm::SmallVector elidedAttrs; - elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName()); + llvm::SmallVector elidedAttrs; + elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName()); - printer.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); - } + printer.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // RawFunctionOp -namespace mlir::bmodelica -{ - RawFunctionOp RawFunctionOp::create( - mlir::Location location, - llvm::StringRef name, - mlir::FunctionType type, - llvm::ArrayRef attrs) - { - mlir::OpBuilder builder(location->getContext()); - mlir::OperationState state(location, getOperationName()); - RawFunctionOp::build(builder, state, name, type, attrs); - return mlir::cast(mlir::Operation::create(state)); - } - - RawFunctionOp RawFunctionOp::create( - mlir::Location location, - llvm::StringRef name, - mlir::FunctionType type, - mlir::Operation::dialect_attr_range attrs) - { - llvm::SmallVector attrRef(attrs); - return create(location, name, type, llvm::ArrayRef(attrRef)); - } - - RawFunctionOp RawFunctionOp::create( - mlir::Location location, - llvm::StringRef name, - mlir::FunctionType type, - llvm::ArrayRef attrs, - llvm::ArrayRef argAttrs) - { - RawFunctionOp func = create(location, name, type, attrs); - func.setAllArgAttrs(argAttrs); - return func; - } - - void RawFunctionOp::build( - mlir::OpBuilder& builder, - mlir::OperationState& state, - llvm::StringRef name, - mlir::FunctionType type, - llvm::ArrayRef attrs, - llvm::ArrayRef argAttrs) - { - state.addAttribute( - mlir::SymbolTable::getSymbolAttrName(), 
- builder.getStringAttr(name)); - - state.addAttribute( - getFunctionTypeAttrName(state.name), - mlir::TypeAttr::get(type)); - - state.attributes.append(attrs.begin(), attrs.end()); - state.addRegion(); - - if (argAttrs.empty()) { - return; - } - - assert(type.getNumInputs() == argAttrs.size()); - - function_interface_impl::addArgAndResultAttrs( - builder, state, argAttrs, std::nullopt, - getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); - } - - mlir::ParseResult RawFunctionOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto buildFuncType = - [](mlir::Builder& builder, - llvm::ArrayRef argTypes, - llvm::ArrayRef results, - mlir::function_interface_impl::VariadicFlag, - std::string&) { - return builder.getFunctionType(argTypes, results); - }; - - return mlir::function_interface_impl::parseFunctionOp( - parser, result, false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); - } - - void RawFunctionOp::print(OpAsmPrinter& printer) - { - mlir::function_interface_impl::printFunctionOp( - printer, *this, false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); - } - - bool RawFunctionOp::shouldBeInlined() - { - if (!getOperation()->hasAttrOfType("inline")) { - return false; - } +namespace mlir::bmodelica { +RawFunctionOp +RawFunctionOp::create(mlir::Location location, llvm::StringRef name, + mlir::FunctionType type, + llvm::ArrayRef attrs) { + mlir::OpBuilder builder(location->getContext()); + mlir::OperationState state(location, getOperationName()); + RawFunctionOp::build(builder, state, name, type, attrs); + return mlir::cast(mlir::Operation::create(state)); +} + +RawFunctionOp RawFunctionOp::create(mlir::Location location, + llvm::StringRef name, + mlir::FunctionType type, + mlir::Operation::dialect_attr_range attrs) { + llvm::SmallVector attrRef(attrs); + return create(location, name, type, llvm::ArrayRef(attrRef)); +} + +RawFunctionOp +RawFunctionOp::create(mlir::Location location, llvm::StringRef name, + mlir::FunctionType type, + llvm::ArrayRef attrs, + llvm::ArrayRef argAttrs) { + RawFunctionOp func = create(location, name, type, attrs); + func.setAllArgAttrs(argAttrs); + return func; +} + +void RawFunctionOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs, + llvm::ArrayRef argAttrs) { + state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); - auto inlineAttribute = - getOperation()->getAttrOfType("inline"); + state.addAttribute(getFunctionTypeAttrName(state.name), + mlir::TypeAttr::get(type)); - return inlineAttribute.getValue(); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) { + return; } - /// Clone the internal blocks from this function into dest and all attributes - /// from this function to dest. - void RawFunctionOp::cloneInto( - RawFunctionOp dest, mlir::IRMapping& mapper) - { - // Add the attributes of this function to dest. 
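A note on the container chosen here: llvm::MapVector preserves insertion
order, and its insert() keeps the first value seen for a key, so attributes
already present on 'dest' are not overwritten by same-named attributes coming
from *this. A self-contained illustration (plain int keys stand in for the
attribute types):

    #include "llvm/ADT/MapVector.h"
    #include <cassert>

    void mapVectorInsertDemo() {
      llvm::MapVector<int, int> attrs;
      attrs.insert({1, 10}); // first insertion wins...
      attrs.insert({1, 20}); // ...so this duplicate key is ignored
      assert(attrs.lookup(1) == 10);
    }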
-    llvm::MapVector<mlir::StringAttr, mlir::Attribute> newAttrMap;
+  assert(type.getNumInputs() == argAttrs.size());

-    for (const auto &attr : dest->getAttrs()) {
-      newAttrMap.insert({attr.getName(), attr.getValue()});
-    }
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, std::nullopt, getArgAttrsAttrName(state.name),
+      getResAttrsAttrName(state.name));
+}

-    for (const auto &attr : (*this)->getAttrs()) {
-      newAttrMap.insert({attr.getName(), attr.getValue()});
-    }
+mlir::ParseResult RawFunctionOp::parse(mlir::OpAsmParser &parser,
+                                       mlir::OperationState &result) {
+  auto buildFuncType =
+      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+         llvm::ArrayRef<mlir::Type> results,
+         mlir::function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, false, getFunctionTypeAttrName(result.name),
+      buildFuncType, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+}
+
+void RawFunctionOp::print(OpAsmPrinter &printer) {
+  mlir::function_interface_impl::printFunctionOp(
+      printer, *this, false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
+      getResAttrsAttrName());
+}

-    auto newAttrs = llvm::to_vector(llvm::map_range(
-        newAttrMap, [](std::pair<mlir::StringAttr, mlir::Attribute> attrPair) {
-          return NamedAttribute(attrPair.first, attrPair.second);
-        }));
+bool RawFunctionOp::shouldBeInlined() {
+  if (!getOperation()->hasAttrOfType<mlir::BoolAttr>("inline")) {
+    return false;
+  }
+
+  auto inlineAttribute =
+      getOperation()->getAttrOfType<mlir::BoolAttr>("inline");
+
+  return inlineAttribute.getValue();
+}

-    dest->setAttrs(mlir::DictionaryAttr::get(getContext(), newAttrs));
+/// Clone the internal blocks from this function into dest and all attributes
+/// from this function to dest.
+void RawFunctionOp::cloneInto(RawFunctionOp dest, mlir::IRMapping &mapper) {
+  // Add the attributes of this function to dest.
+  llvm::MapVector<mlir::StringAttr, mlir::Attribute> newAttrMap;

-    // Clone the body.
-    getBody().cloneInto(&dest.getBody(), mapper);
+  for (const auto &attr : dest->getAttrs()) {
+    newAttrMap.insert({attr.getName(), attr.getValue()});
   }

-  /// Create a deep copy of this function and all of its blocks, remapping
-  /// any operands that use values outside of the function using the map that
-  /// is provided (leaving them alone if no entry is present). Replaces
-  /// references to cloned sub-values with the corresponding value that is
-  /// copied, and adds those mappings to the mapper.
-  RawFunctionOp RawFunctionOp::clone(mlir::IRMapping& mapper)
-  {
-    // Create the new function.
-    RawFunctionOp newFunc = cast<RawFunctionOp>(
-        getOperation()->cloneWithoutRegions());
+  for (const auto &attr : (*this)->getAttrs()) {
+    newAttrMap.insert({attr.getName(), attr.getValue()});
+  }

-    // If the function has a body, then the user might be deleting arguments to
-    // the function by specifying them in the mapper. If so, we don't add the
-    // argument to the input type vector.
-    if (!isExternal()) {
-      mlir::FunctionType oldType = getFunctionType();
+  auto newAttrs = llvm::to_vector(llvm::map_range(
+      newAttrMap, [](std::pair<mlir::StringAttr, mlir::Attribute> attrPair) {
+        return NamedAttribute(attrPair.first, attrPair.second);
+      }));

-      unsigned oldNumArgs = oldType.getNumInputs();
-      llvm::SmallVector<mlir::Type> newInputs;
-      newInputs.reserve(oldNumArgs);
+  dest->setAttrs(mlir::DictionaryAttr::get(getContext(), newAttrs));

-      for (unsigned i = 0; i != oldNumArgs; ++i) {
-        if (!mapper.contains(getArgument(i))) {
-          newInputs.push_back(oldType.getInput(i));
-        }
+  // Clone the body.
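The body-cloning step right below relies on mlir::IRMapping to relate every
original value to its copy. A hedged usage sketch of this cloneInto entry
point ('source', 'target', and the tracked value are hypothetical):

    // Copy 'source' into 'target', then look up the clone of one value.
    mlir::Value cloneBodyAndTrack(RawFunctionOp source, RawFunctionOp target,
                                  mlir::Value original) {
      mlir::IRMapping mapping;
      source.cloneInto(target, mapping);
      // Returns the cloned counterpart if 'original' was mapped.
      return mapping.lookupOrDefault(original);
    }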
+ getBody().cloneInto(&dest.getBody(), mapper); +} + +/// Create a deep copy of this function and all of its blocks, remapping +/// any operands that use values outside of the function using the map that is +/// provided (leaving them alone if no entry is present). Replaces references +/// to cloned sub-values with the corresponding value that is copied, and adds +/// those mappings to the mapper. +RawFunctionOp RawFunctionOp::clone(mlir::IRMapping &mapper) { + // Create the new function. + RawFunctionOp newFunc = + cast(getOperation()->cloneWithoutRegions()); + + // If the function has a body, then the user might be deleting arguments to + // the function by specifying them in the mapper. If so, we don't add the + // argument to the input type vector. + if (!isExternal()) { + mlir::FunctionType oldType = getFunctionType(); + + unsigned oldNumArgs = oldType.getNumInputs(); + llvm::SmallVector newInputs; + newInputs.reserve(oldNumArgs); + + for (unsigned i = 0; i != oldNumArgs; ++i) { + if (!mapper.contains(getArgument(i))) { + newInputs.push_back(oldType.getInput(i)); } + } - // If any of the arguments were dropped, update the type and drop any - // necessary argument attributes. - if (newInputs.size() != oldNumArgs) { - newFunc.setType(mlir::FunctionType::get( - oldType.getContext(), newInputs, oldType.getResults())); + // If any of the arguments were dropped, update the type and drop any + // necessary argument attributes. + if (newInputs.size() != oldNumArgs) { + newFunc.setType(mlir::FunctionType::get(oldType.getContext(), newInputs, + oldType.getResults())); - if (mlir::ArrayAttr argAttrs = getAllArgAttrs()) { - SmallVector newArgAttrs; - newArgAttrs.reserve(newInputs.size()); + if (mlir::ArrayAttr argAttrs = getAllArgAttrs()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(newInputs.size()); - for (unsigned i = 0; i != oldNumArgs; ++i) { - if (!mapper.contains(getArgument(i))) { - newArgAttrs.push_back(argAttrs[i]); - } + for (unsigned i = 0; i != oldNumArgs; ++i) { + if (!mapper.contains(getArgument(i))) { + newArgAttrs.push_back(argAttrs[i]); } - - newFunc.setAllArgAttrs(newArgAttrs); } + + newFunc.setAllArgAttrs(newArgAttrs); } } - - // Clone the current function into the new one and return it. - cloneInto(newFunc, mapper); - return newFunc; } - RawFunctionOp RawFunctionOp::clone() - { - mlir::IRMapping mapper; - return clone(mapper); - } + // Clone the current function into the new one and return it. 
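The argument-dropping behavior implemented above can be exercised by
pre-populating the mapper, as the following hedged sketch shows (the helper
name and the choice of argument 0 are illustrative only):

    // Clone 'func' while replacing its first argument with an external
    // value; the clone ends up with one input fewer.
    RawFunctionOp specializeFirstArg(RawFunctionOp func,
                                     mlir::Value replacement) {
      mlir::IRMapping mapper;
      mapper.map(func.getArgument(0), replacement);
      return func.clone(mapper);
    }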
+ cloneInto(newFunc, mapper); + return newFunc; +} + +RawFunctionOp RawFunctionOp::clone() { + mlir::IRMapping mapper; + return clone(mapper); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // RawReturnOp -namespace mlir::bmodelica -{ - mlir::LogicalResult RawReturnOp::verify() { - auto function = cast((*this)->getParentOp()); +namespace mlir::bmodelica { +mlir::LogicalResult RawReturnOp::verify() { + auto function = cast((*this)->getParentOp()); - // The operand number and types must match the function signature - const auto& results = function.getFunctionType().getResults(); + // The operand number and types must match the function signature + auto results = function.getFunctionType().getResults(); - if (getNumOperands() != results.size()) { - return emitOpError("has ") - << getNumOperands() << " operands, but enclosing function (@" - << function.getName() << ") returns " << results.size(); - } + if (getNumOperands() != results.size()) { + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + } - for (unsigned i = 0, e = results.size(); i != e; ++i) { - if (getOperand(i).getType() != results[i]) { - return emitOpError() - << "type of return operand " << i << " (" - << getOperand(i).getType() - << ") doesn't match function result type (" << results[i] << ")" - << " in function @" << function.getName(); - } + for (unsigned i = 0, e = results.size(); i != e; ++i) { + if (getOperand(i).getType() != results[i]) { + return emitOpError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); } - - return mlir::success(); } + + return mlir::success(); } +} // namespace mlir::bmodelica //===---------------------------------------------------------------------===// // RawVariableOp -namespace mlir::bmodelica +namespace mlir::bmodelica { +/* +mlir::ParseResult RawVariableOp::parse( + mlir::OpAsmParser& parser, mlir::OperationState& result) { - /* - mlir::ParseResult RawVariableOp::parse( - mlir::OpAsmParser& parser, mlir::OperationState& result) - { - auto& builder = parser.getBuilder(); - - llvm::SmallVector dynamicSizes; - mlir::Type variableType; - - if (parser.parseOperandList(dynamicSizes) || - parser.resolveOperands( - dynamicSizes, builder.getIndexType(), result.operands)) { - return mlir::failure(); - } + auto& builder = parser.getBuilder(); - // Dimensions constraints. - llvm::SmallVector dimensionsConstraints; + llvm::SmallVector dynamicSizes; + mlir::Type variableType; - if (mlir::succeeded(parser.parseOptionalLSquare())) { - do { - if (mlir::succeeded( - parser.parseOptionalKeyword(kDimensionConstraintUnbounded))) { - dimensionsConstraints.push_back(kDimensionConstraintUnbounded); - } else { - if (parser.parseKeyword(kDimensionConstraintFixed)) { - return mlir::failure(); - } + if (parser.parseOperandList(dynamicSizes) || + parser.resolveOperands( + dynamicSizes, builder.getIndexType(), result.operands)) { + return mlir::failure(); + } + + // Dimensions constraints. 
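For readers following this (currently commented-out) parser: it accepts an
optional bracketed, comma-separated list of per-dimension constraint keywords.
A standalone sketch of the same optional-list idiom against the generic
AsmParser API, with the "fixed"/"unbounded" spellings assumed rather than
taken from the constants used here:

    mlir::ParseResult
    parseConstraintList(mlir::OpAsmParser &parser,
                        llvm::SmallVectorImpl<llvm::StringRef> &constraints) {
      if (mlir::failed(parser.parseOptionalLSquare())) {
        return mlir::success(); // the whole list is optional
      }

      do {
        if (mlir::succeeded(parser.parseOptionalKeyword("unbounded"))) {
          constraints.push_back("unbounded");
        } else if (mlir::failed(parser.parseKeyword("fixed"))) {
          return mlir::failure();
        } else {
          constraints.push_back("fixed");
        }
      } while (mlir::succeeded(parser.parseOptionalComma()));

      return parser.parseRSquare();
    }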
+ llvm::SmallVector dimensionsConstraints; - dimensionsConstraints.push_back(kDimensionConstraintFixed); + if (mlir::succeeded(parser.parseOptionalLSquare())) { + do { + if (mlir::succeeded( + parser.parseOptionalKeyword(kDimensionConstraintUnbounded))) { + dimensionsConstraints.push_back(kDimensionConstraintUnbounded); + } else { + if (parser.parseKeyword(kDimensionConstraintFixed)) { + return mlir::failure(); } - } while (mlir::succeeded(parser.parseOptionalComma())); - if (parser.parseRSquare()) { - return mlir::failure(); + dimensionsConstraints.push_back(kDimensionConstraintFixed); } - } - - result.attributes.append( - getDimensionsConstraintsAttrName(result.name), - builder.getStrArrayAttr(dimensionsConstraints)); + } while (mlir::succeeded(parser.parseOptionalComma())); - // Attributes. - if (parser.parseOptionalAttrDict(result.attributes)) { + if (parser.parseRSquare()) { return mlir::failure(); } + } - // Variable type. - if (parser.parseColon() || - parser.parseType(variableType)) { - return mlir::failure(); - } + result.attributes.append( + getDimensionsConstraintsAttrName(result.name), + builder.getStrArrayAttr(dimensionsConstraints)); - result.addTypes(variableType); + // Attributes. + if (parser.parseOptionalAttrDict(result.attributes)) { + return mlir::failure(); + } - return mlir::success(); + // Variable type. + if (parser.parseColon() || + parser.parseType(variableType)) { + return mlir::failure(); } - void RawVariableOp::print(mlir::OpAsmPrinter& printer) - { - if (auto dynamicSizes = getDynamicSizes(); !dynamicSizes.empty()) { - printer << " " << dynamicSizes; - } + result.addTypes(variableType); + + return mlir::success(); +} - auto dimConstraints = - getDimensionsConstraints().getAsRange(); +void RawVariableOp::print(mlir::OpAsmPrinter& printer) +{ + if (auto dynamicSizes = getDynamicSizes(); !dynamicSizes.empty()) { + printer << " " << dynamicSizes; + } - if (llvm::any_of(dimConstraints, [](mlir::StringAttr constraint) { - return constraint == kDimensionConstraintFixed; - })) { - printer << " ["; + auto dimConstraints = + getDimensionsConstraints().getAsRange(); - for (const auto& constraint : llvm::enumerate(dimConstraints)) { - if (constraint.index() != 0) { - printer << ", "; - } + if (llvm::any_of(dimConstraints, [](mlir::StringAttr constraint) { + return constraint == kDimensionConstraintFixed; + })) { + printer << " ["; - printer << constraint.value().getValue(); + for (const auto& constraint : llvm::enumerate(dimConstraints)) { + if (constraint.index() != 0) { + printer << ", "; } - printer << "] "; + printer << constraint.value().getValue(); } - llvm::SmallVector elidedAttrs; - elidedAttrs.push_back(getDimensionsConstraintsAttrName()); - - printer.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); - - printer << " : " << getVariable().getType(); + printer << "] "; } - */ - void RawVariableOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - auto variableMemRefType = - getVariable().getType().dyn_cast(); + llvm::SmallVector elidedAttrs; + elidedAttrs.push_back(getDimensionsConstraintsAttrName()); - if (variableMemRefType) { - if (!getHeap() || isDynamicArrayVariable()) { - effects.emplace_back( - mlir::MemoryEffects::Allocate::get(), - getResult(), - mlir::SideEffects::AutomaticAllocationScopeResource::get()); - } else { - effects.emplace_back( - mlir::MemoryEffects::Allocate::get(), - getResult(), - mlir::SideEffects::DefaultResource::get()); - } - } - } + 
printer.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); - /* - VariableType RawVariableOp::getVariableType() - { - mlir::Type variableType = getVariable().getType(); + printer << " : " << getVariable().getType(); +} +*/ - VariabilityProperty variabilityProperty = VariabilityProperty::none; - IOProperty ioProperty = IOProperty::none; +void RawVariableOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + auto variableMemRefType = + getVariable().getType().dyn_cast(); - if (getOutput()) { - ioProperty = IOProperty::output; + if (variableMemRefType) { + if (!getHeap() || isDynamicArrayVariable()) { + effects.emplace_back( + mlir::MemoryEffects::Allocate::get(), getResult(), + mlir::SideEffects::AutomaticAllocationScopeResource::get()); + } else { + effects.emplace_back(mlir::MemoryEffects::Allocate::get(), getResult(), + mlir::SideEffects::DefaultResource::get()); } + } +} - if (auto shapedType = variableType.dyn_cast()) { - return VariableType::get( - shapedType.getShape(), shapedType.getElementType(), - variabilityProperty, ioProperty); - } +/* +VariableType RawVariableOp::getVariableType() +{ + mlir::Type variableType = getVariable().getType(); - return VariableType::get( - std::nullopt, variableType, variabilityProperty, ioProperty); + VariabilityProperty variabilityProperty = VariabilityProperty::none; + IOProperty ioProperty = IOProperty::none; + + if (getOutput()) { + ioProperty = IOProperty::output; } - */ - bool RawVariableOp::isScalarVariable(mlir::Type variableType) - { - auto variableShapedType = variableType.cast(); - return variableShapedType.getShape().empty(); + if (auto shapedType = variableType.dyn_cast()) { + return VariableType::get( + shapedType.getShape(), shapedType.getElementType(), + variabilityProperty, ioProperty); } - bool RawVariableOp::isStaticArrayVariable(mlir::Type variableType) - { - auto variableShapedType = variableType.cast(); + return VariableType::get( + std::nullopt, variableType, variabilityProperty, ioProperty); +} + */ - return !variableShapedType.getShape().empty() && - variableShapedType.hasStaticShape(); - } +bool RawVariableOp::isScalarVariable(mlir::Type variableType) { + auto variableShapedType = variableType.cast(); + return variableShapedType.getShape().empty(); +} - bool RawVariableOp::isDynamicArrayVariable(mlir::Type variableType) - { - auto variableShapedType = variableType.cast(); +bool RawVariableOp::isStaticArrayVariable(mlir::Type variableType) { + auto variableShapedType = variableType.cast(); - return !variableShapedType.getShape().empty() && - !variableShapedType.hasStaticShape(); - } + return !variableShapedType.getShape().empty() && + variableShapedType.hasStaticShape(); +} - bool RawVariableOp::isScalarVariable() - { - return RawVariableOp::isScalarVariable(getVariable().getType()); - } +bool RawVariableOp::isDynamicArrayVariable(mlir::Type variableType) { + auto variableShapedType = variableType.cast(); - bool RawVariableOp::isStaticArrayVariable() - { - return RawVariableOp::isStaticArrayVariable(getVariable().getType()); - } + return !variableShapedType.getShape().empty() && + !variableShapedType.hasStaticShape(); +} - bool RawVariableOp::isDynamicArrayVariable() - { - return RawVariableOp::isDynamicArrayVariable(getVariable().getType()); - } +bool RawVariableOp::isScalarVariable() { + return RawVariableOp::isScalarVariable(getVariable().getType()); +} - bool RawVariableOp::isProtected() - { - return !getOutput(); - } +bool RawVariableOp::isStaticArrayVariable() { + 
return RawVariableOp::isStaticArrayVariable(getVariable().getType()); +} - bool RawVariableOp::isOutput() - { - return getOutput(); - } +bool RawVariableOp::isDynamicArrayVariable() { + return RawVariableOp::isDynamicArrayVariable(getVariable().getType()); } -namespace mlir::bmodelica -{ - void RawVariableDeallocOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - auto variableMemRefType = - getVariable().getType().dyn_cast(); - - if (variableMemRefType) { - effects.emplace_back( - mlir::MemoryEffects::Free::get(), - getVariable(), - mlir::SideEffects::DefaultResource::get()); - } - } +bool RawVariableOp::isProtected() { return !getOutput(); } - bool RawVariableDeallocOp::isScalarVariable() - { - return RawVariableOp::isScalarVariable(getVariable().getType()); - } +bool RawVariableOp::isOutput() { return getOutput(); } +} // namespace mlir::bmodelica - bool RawVariableDeallocOp::isStaticArrayVariable() - { - return RawVariableOp::isStaticArrayVariable(getVariable().getType()); - } +namespace mlir::bmodelica { +void RawVariableDeallocOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + auto variableMemRefType = + getVariable().getType().dyn_cast(); - bool RawVariableDeallocOp::isDynamicArrayVariable() - { - return RawVariableOp::isDynamicArrayVariable(getVariable().getType()); + if (variableMemRefType) { + effects.emplace_back(mlir::MemoryEffects::Free::get(), getVariable(), + mlir::SideEffects::DefaultResource::get()); } } +bool RawVariableDeallocOp::isScalarVariable() { + return RawVariableOp::isScalarVariable(getVariable().getType()); +} + +bool RawVariableDeallocOp::isStaticArrayVariable() { + return RawVariableOp::isStaticArrayVariable(getVariable().getType()); +} + +bool RawVariableDeallocOp::isDynamicArrayVariable() { + return RawVariableOp::isDynamicArrayVariable(getVariable().getType()); +} +} // namespace mlir::bmodelica + //===---------------------------------------------------------------------===// // RawVariableGetOp -namespace mlir::bmodelica -{ - void RawVariableGetOp::getEffects( - mlir::SmallVectorImpl< - mlir::SideEffects::EffectInstance< - mlir::MemoryEffects::Effect>>& effects) - { - effects.emplace_back( - mlir::MemoryEffects::Read::get(), - getVariable(), - mlir::SideEffects::DefaultResource::get()); - } - - mlir::Type RawVariableGetOp::computeResultType(mlir::Type rawVariableType) - { - mlir::Type resultType = rawVariableType; - bool isScalar = RawVariableOp::isScalarVariable(rawVariableType); +namespace mlir::bmodelica { +void RawVariableGetOp::getEffects( + mlir::SmallVectorImpl< + mlir::SideEffects::EffectInstance> + &effects) { + effects.emplace_back(mlir::MemoryEffects::Read::get(), getVariable(), + mlir::SideEffects::DefaultResource::get()); +} - if (isScalar) { - auto shapedType = rawVariableType.cast(); - resultType = shapedType.getElementType(); - } +mlir::Type RawVariableGetOp::computeResultType(mlir::Type rawVariableType) { + mlir::Type resultType = rawVariableType; + bool isScalar = RawVariableOp::isScalarVariable(rawVariableType); - return resultType; + if (isScalar) { + auto shapedType = rawVariableType.cast(); + resultType = shapedType.getElementType(); } - bool RawVariableGetOp::isScalarVariable() - { - return RawVariableOp::isScalarVariable(getVariable().getType()); - } + return resultType; +} - bool RawVariableGetOp::isStaticArrayVariable() - { - return RawVariableOp::isStaticArrayVariable(getVariable().getType()); - } +bool 
RawVariableGetOp::isScalarVariable() {
+  return RawVariableOp::isScalarVariable(getVariable().getType());
+}
 
-  bool RawVariableGetOp::isDynamicArrayVariable()
-  {
-    return RawVariableOp::isDynamicArrayVariable(getVariable().getType());
-  }
+bool RawVariableGetOp::isStaticArrayVariable() {
+  return RawVariableOp::isStaticArrayVariable(getVariable().getType());
 }
 
+bool RawVariableGetOp::isDynamicArrayVariable() {
+  return RawVariableOp::isDynamicArrayVariable(getVariable().getType());
+}
+} // namespace mlir::bmodelica
+
 //===---------------------------------------------------------------------===//
 // RawVariableSetOp
 
-namespace mlir::bmodelica
-{
-  void RawVariableSetOp::getEffects(
-      mlir::SmallVectorImpl<
-          mlir::SideEffects::EffectInstance<
-              mlir::MemoryEffects::Effect>>& effects)
-  {
-    effects.emplace_back(
-        mlir::MemoryEffects::Write::get(),
-        getVariable(),
-        mlir::SideEffects::DefaultResource::get());
-  }
+namespace mlir::bmodelica {
+void RawVariableSetOp::getEffects(
+    mlir::SmallVectorImpl<
+        mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(mlir::MemoryEffects::Write::get(), getVariable(),
+                       mlir::SideEffects::DefaultResource::get());
+}
 
-  bool RawVariableSetOp::isScalarVariable()
-  {
-    return RawVariableOp::isScalarVariable(getVariable().getType());
-  }
+bool RawVariableSetOp::isScalarVariable() {
+  return RawVariableOp::isScalarVariable(getVariable().getType());
+}
 
-  bool RawVariableSetOp::isStaticArrayVariable()
-  {
-    return RawVariableOp::isStaticArrayVariable(getVariable().getType());
-  }
+bool RawVariableSetOp::isStaticArrayVariable() {
+  return RawVariableOp::isStaticArrayVariable(getVariable().getType());
+}
 
-  bool RawVariableSetOp::isDynamicArrayVariable()
-  {
-    return RawVariableOp::isDynamicArrayVariable(getVariable().getType());
-  }
+bool RawVariableSetOp::isDynamicArrayVariable() {
+  return RawVariableOp::isDynamicArrayVariable(getVariable().getType());
 }
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // CallOp
 
-namespace mlir::bmodelica
-{
-  void CallOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      FunctionOp callee,
-      mlir::ValueRange args,
-      std::optional<mlir::ArrayAttr> argNames)
-  {
-    mlir::SymbolRefAttr symbol = getSymbolRefFromRoot(callee);
-    build(builder, state, symbol, callee.getResultTypes(), args, argNames);
-  }
-
-  void CallOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      RawFunctionOp callee,
-      mlir::ValueRange args)
-  {
-    mlir::SymbolRefAttr symbol = getSymbolRefFromRoot(callee);
-    build(builder, state, symbol, callee.getResultTypes(), args);
-  }
-
-  void CallOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      EquationFunctionOp callee,
-      mlir::ValueRange args)
-  {
-    mlir::SymbolRefAttr symbol = getSymbolRefFromRoot(callee);
-    build(builder, state, symbol, callee.getResultTypes(), args);
-  }
-
-  void CallOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      mlir::SymbolRefAttr callee,
-      mlir::TypeRange resultTypes,
-      mlir::ValueRange args,
-      std::optional<mlir::ArrayAttr> argNames)
-  {
-    state.addOperands(args);
-    state.addAttribute(getCalleeAttrName(state.name), callee);
-
-    if (argNames) {
-      state.addAttribute(getArgNamesAttrName(state.name), *argNames);
-    }
-
-    state.addTypes(resultTypes);
-  }
-
-  mlir::LogicalResult CallOp::verifySymbolUses(
-      mlir::SymbolTableCollection& symbolTable)
-  {
-    auto moduleOp = getOperation()->getParentOfType<mlir::ModuleOp>();
-    mlir::Operation* callee = getFunction(moduleOp, symbolTable);
-
-    if (!callee) {
-      // TODO
-      // At the moment verification would fail for derivatives of functions,
-      // because they are declared through an attribute. We should look into
-      // turning that attribute into an operation, so that the symbol becomes
-      // declared within the module, and thus obtainable.
-      return mlir::success();
-    }
-
-    if (mlir::isa<DerFunctionOp>(callee)) {
-      // TODO implement proper verification of DerFunctionOp function type.
-      return mlir::success();
-    }
-
-    // Verify that the operand and result types match the callee.
-    if (auto functionOp = mlir::dyn_cast<FunctionOp>(callee)) {
-      mlir::FunctionType functionType = functionOp.getFunctionType();
-
-      llvm::SmallVector<mlir::StringAttr> inputVariables;
-      llvm::DenseSet<mlir::StringAttr> inputVariablesSet;
-
-      for (VariableOp variableOp : functionOp.getVariables()) {
-        mlir::StringAttr variableName = variableOp.getSymNameAttr();
-
-        if (variableOp.isInput()) {
-          inputVariables.push_back(variableName);
-          inputVariablesSet.insert(variableName);
-        }
-      }
-
-      llvm::DenseSet<mlir::StringAttr> variablesWithDefaultValue;
+namespace mlir::bmodelica {
+void CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   FunctionOp callee, mlir::ValueRange args,
+                   std::optional<mlir::ArrayAttr> argNames) {
+  mlir::SymbolRefAttr symbol = getSymbolRefFromRoot(callee);
+  build(builder, state, symbol, callee.getResultTypes(), args, argNames);
+}
 
+void CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   RawFunctionOp callee, mlir::ValueRange args) {
+  mlir::SymbolRefAttr symbol = getSymbolRefFromRoot(callee);
+  build(builder, state, symbol, callee.getResultTypes(), args);
+}
+
+void CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   EquationFunctionOp callee, mlir::ValueRange args) {
+  mlir::SymbolRefAttr symbol = getSymbolRefFromRoot(callee);
+  build(builder, state, symbol, callee.getResultTypes(), args);
+}
+
+void CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   mlir::SymbolRefAttr callee, mlir::TypeRange resultTypes,
+                   mlir::ValueRange args,
+                   std::optional<mlir::ArrayAttr> argNames) {
+  state.addOperands(args);
+  state.addAttribute(getCalleeAttrName(state.name), callee);
+
+  if (argNames) {
+    state.addAttribute(getArgNamesAttrName(state.name), *argNames);
+  }
+
+  state.addTypes(resultTypes);
+}
+
+mlir::LogicalResult
+CallOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
+  auto moduleOp = getOperation()->getParentOfType<mlir::ModuleOp>();
+  mlir::Operation *callee = getFunction(moduleOp, symbolTable);
+
+  if (!callee) {
+    // TODO
+    // At the moment verification would fail for derivatives of functions,
+    // because they are declared through an attribute. We should look into
+    // turning that attribute into an operation, so that the symbol becomes
+    // declared within the module, and thus obtainable.
+    return mlir::success();
+  }
+
+  if (mlir::isa<DerFunctionOp>(callee)) {
+    // TODO implement proper verification of DerFunctionOp function type.
+    return mlir::success();
+  }
+
+  // Verify that the operand and result types match the callee.
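+  // The check is dispatched on the callee kind: FunctionOp callees are
+  // validated against their declared input variables (taking default values
+  // and named arguments into account), while EquationFunctionOp and
+  // RawFunctionOp callees are compared directly against their function type.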
+  if (auto functionOp = mlir::dyn_cast<FunctionOp>(callee)) {
+    mlir::FunctionType functionType = functionOp.getFunctionType();
+
+    llvm::SmallVector<mlir::StringAttr> inputVariables;
+    llvm::DenseSet<mlir::StringAttr> inputVariablesSet;
 
-      for (DefaultOp defaultOp : functionOp.getDefaultValues()) {
-        variablesWithDefaultValue.insert(defaultOp.getVariableAttr());
+    for (VariableOp variableOp : functionOp.getVariables()) {
+      mlir::StringAttr variableName = variableOp.getSymNameAttr();
+
+      if (variableOp.isInput()) {
+        inputVariables.push_back(variableName);
+        inputVariablesSet.insert(variableName);
       }
+    }
 
-      auto args = getArgs();
+    llvm::DenseSet<mlir::StringAttr> variablesWithDefaultValue;
 
-      if (auto argNames = getArgNames()) {
-        if (argNames->size() != args.size()) {
-          return emitOpError()
-              << "the number of arguments (" << args.size()
-              << ") does not match the number of argument names ("
-              << argNames->size() << ")";
-        }
+    for (DefaultOp defaultOp : functionOp.getDefaultValues()) {
+      variablesWithDefaultValue.insert(defaultOp.getVariableAttr());
+    }
 
-        llvm::DenseSet<mlir::StringAttr> specifiedInputs;
+    auto args = getArgs();
 
-        for (mlir::FlatSymbolRefAttr argName :
-             argNames->getAsRange<mlir::FlatSymbolRefAttr>()) {
-          if (!inputVariablesSet.contains(argName.getAttr())) {
-            return emitOpError() << "unknown argument '"
-                << argName.getValue() << "'";
-          }
+    if (auto argNames = getArgNames()) {
+      if (argNames->size() != args.size()) {
+        return emitOpError()
+               << "the number of arguments (" << args.size()
+               << ") does not match the number of argument names ("
+               << argNames->size() << ")";
+      }
 
-          if (specifiedInputs.contains(argName.getAttr())) {
-            return emitOpError() << "multiple values for argument '"
-                << argName.getValue() << "'";
-          }
+      llvm::DenseSet<mlir::StringAttr> specifiedInputs;
 
-          specifiedInputs.insert(argName.getAttr());
+      for (mlir::FlatSymbolRefAttr argName :
+           argNames->getAsRange<mlir::FlatSymbolRefAttr>()) {
+        if (!inputVariablesSet.contains(argName.getAttr())) {
+          return emitOpError()
+                 << "unknown argument '" << argName.getValue() << "'";
         }
 
-        for (mlir::StringAttr variableName : inputVariables) {
-          if (!variablesWithDefaultValue.contains(variableName) &&
-              !specifiedInputs.contains(variableName)) {
-            return emitOpError() << "missing value for argument '"
-                << variableName.getValue() << "'";
-          }
-        }
-      } else {
-        if (args.size() > inputVariables.size()) {
-          return emitOpError() << "too many arguments specified (expected "
-              << inputVariables.size() << ", got " << args.size() << ")";
+        if (specifiedInputs.contains(argName.getAttr())) {
+          return emitOpError() << "multiple values for argument '"
+                               << argName.getValue() << "'";
         }
 
-        for (mlir::StringAttr variableName :
-             llvm::ArrayRef(inputVariables).drop_front(args.size())) {
-          if (!variablesWithDefaultValue.contains(variableName)) {
-            return emitOpError() << "missing value for argument '"
-                << variableName.getValue() << "'";
-          }
-        }
+        specifiedInputs.insert(argName.getAttr());
       }
 
-      unsigned int expectedResults = functionType.getNumResults();
-      unsigned int actualResults = getNumResults();
-
-      if (expectedResults != actualResults) {
+      for (mlir::StringAttr variableName : inputVariables) {
+        if (!variablesWithDefaultValue.contains(variableName) &&
+            !specifiedInputs.contains(variableName)) {
+          return emitOpError() << "missing value for argument '"
+                               << variableName.getValue() << "'";
+        }
+      }
+    } else {
+      if (args.size() > inputVariables.size()) {
         return emitOpError()
-            << "incorrect number of results for callee (expected "
-            << expectedResults << ", got " << actualResults << ")";
+               << "too many arguments specified (expected "
+               << inputVariables.size() << ", got " << args.size() << ")";
+      }
 
-        return mlir::failure();
+      for (mlir::StringAttr variableName :
+           llvm::ArrayRef(inputVariables).drop_front(args.size())) {
+        if (!variablesWithDefaultValue.contains(variableName)) {
+          return emitOpError() << "missing value for argument '"
+                               << variableName.getValue() << "'";
+        }
       }
+    }
 
-      return mlir::success();
+    unsigned int expectedResults = functionType.getNumResults();
+    unsigned int actualResults = getNumResults();
+
+    if (expectedResults != actualResults) {
+      return emitOpError()
+             << "incorrect number of results for callee (expected "
+             << expectedResults << ", got " << actualResults << ")";
     }
+
+    return mlir::success();
+  }
 
-    if (auto equationFunctionOp = mlir::dyn_cast<EquationFunctionOp>(callee)) {
-      mlir::FunctionType functionType = equationFunctionOp.getFunctionType();
+  if (auto equationFunctionOp = mlir::dyn_cast<EquationFunctionOp>(callee)) {
+    mlir::FunctionType functionType = equationFunctionOp.getFunctionType();
 
-      unsigned int expectedInputs = functionType.getNumInputs();
-      unsigned int actualInputs = getNumOperands();
+    unsigned int expectedInputs = functionType.getNumInputs();
+    unsigned int actualInputs = getNumOperands();
 
-      if (expectedInputs != actualInputs) {
-        return emitOpError()
-            << "incorrect number of operands for callee (expected "
-            << expectedInputs << ", got "
-            << actualInputs << ")";
-      }
+    if (expectedInputs != actualInputs) {
+      return emitOpError()
+             << "incorrect number of operands for callee (expected "
+             << expectedInputs << ", got " << actualInputs << ")";
+    }
 
-      unsigned int expectedResults = functionType.getNumResults();
-      unsigned int actualResults = getNumResults();
+    unsigned int expectedResults = functionType.getNumResults();
+    unsigned int actualResults = getNumResults();
 
-      if (expectedResults != actualResults) {
-        return emitOpError()
-            << "incorrect number of results for callee (expected "
-            << expectedResults << ", got "
-            << actualResults << ")";
-      }
+    if (expectedResults != actualResults) {
+      return emitOpError()
+             << "incorrect number of results for callee (expected "
+             << expectedResults << ", got " << actualResults << ")";
+    }
 
-      return mlir::success();
-    }
+    return mlir::success();
+  }
 
-    if (auto rawFunctionOp = mlir::dyn_cast<RawFunctionOp>(callee)) {
-      mlir::FunctionType functionType = rawFunctionOp.getFunctionType();
+  if (auto rawFunctionOp = mlir::dyn_cast<RawFunctionOp>(callee)) {
+    mlir::FunctionType functionType = rawFunctionOp.getFunctionType();
 
-      unsigned int expectedInputs = functionType.getNumInputs();
-      unsigned int actualInputs = getNumOperands();
+    unsigned int expectedInputs = functionType.getNumInputs();
+    unsigned int actualInputs = getNumOperands();
 
-      if (expectedInputs != actualInputs) {
-        return emitOpError()
-            << "incorrect number of operands for callee (expected "
-            << expectedInputs << ", got "
-            << actualInputs << ")";
-      }
+    if (expectedInputs != actualInputs) {
+      return emitOpError()
+             << "incorrect number of operands for callee (expected "
+             << expectedInputs << ", got " << actualInputs << ")";
+    }
 
-      unsigned int expectedResults = functionType.getNumResults();
-      unsigned int actualResults = getNumResults();
+    unsigned int expectedResults = functionType.getNumResults();
+    unsigned int actualResults = getNumResults();
 
-      if (expectedResults != actualResults) {
-        return emitOpError()
-            << "incorrect number of results for callee (expected "
-            << expectedResults << ", got "
-            << actualResults << ")";
-      }
+    if (expectedResults != actualResults) {
+      return emitOpError()
+             << "incorrect number of results for callee (expected "
+             << expectedResults << ", got " << actualResults << ")";
+    }
 
-      return mlir::success();
-    }
+    return mlir::success();
+  }
 
-    return emitOpError() << "'" << getCallee()
-        << "' does not reference a valid function";
-  }
+  return emitOpError() << "'" << getCallee()
+                       << "' does not reference a valid function";
+}
 
-  void CallOp::getEffects(
-      mlir::SmallVectorImpl<
-          mlir::SideEffects::EffectInstance<
-              mlir::MemoryEffects::Effect>>& effects)
-  {
-    // The callee may have no arguments and no results, but still have side
-    // effects (i.e. an external function writing elsewhere). Thus we need to
-    // consider the call itself as if it is has side effects and prevent the
-    // CSE pass to erase it.
-    effects.emplace_back(
-        mlir::MemoryEffects::Write::get(),
-        mlir::SideEffects::DefaultResource::get());
+void CallOp::getEffects(
+    mlir::SmallVectorImpl<
+        mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
+        &effects) {
+  // The callee may have no arguments and no results, but still have side
+  // effects (e.g. an external function writing elsewhere). Thus we need to
+  // consider the call itself as if it had side effects, and prevent the
+  // CSE pass from erasing it.
+  effects.emplace_back(mlir::MemoryEffects::Write::get(),
+                       mlir::SideEffects::DefaultResource::get());
 
-    for (mlir::Value result : getResults()) {
-      if (auto arrayType = result.getType().dyn_cast<ArrayType>()) {
-        effects.emplace_back(
-            mlir::MemoryEffects::Allocate::get(),
-            result,
-            mlir::SideEffects::DefaultResource::get());
+  for (mlir::Value result : getResults()) {
+    if (auto arrayType = result.getType().dyn_cast<ArrayType>()) {
+      effects.emplace_back(mlir::MemoryEffects::Allocate::get(), result,
+                           mlir::SideEffects::DefaultResource::get());
 
-        effects.emplace_back(
-            mlir::MemoryEffects::Write::get(),
-            result,
-            mlir::SideEffects::DefaultResource::get());
-      }
+      effects.emplace_back(mlir::MemoryEffects::Write::get(), result,
+                           mlir::SideEffects::DefaultResource::get());
     }
   }
+}
 
-  mlir::Operation* CallOp::getFunction(
-      mlir::ModuleOp moduleOp, mlir::SymbolTableCollection& symbolTable)
-  {
-    mlir::SymbolRefAttr callee = getCallee();
-
-    mlir::Operation* result = symbolTable.lookupSymbolIn(
-        moduleOp, callee.getRootReference());
+mlir::Operation *CallOp::getFunction(mlir::ModuleOp moduleOp,
+                                     mlir::SymbolTableCollection &symbolTable) {
+  mlir::SymbolRefAttr callee = getCallee();
 
-    for (mlir::FlatSymbolRefAttr flatSymbolRef : callee.getNestedReferences()) {
-      if (result == nullptr) {
-        return nullptr;
-      }
+  mlir::Operation *result =
+      symbolTable.lookupSymbolIn(moduleOp, callee.getRootReference());
 
-      result = symbolTable.lookupSymbolIn(result, flatSymbolRef.getAttr());
+  for (mlir::FlatSymbolRefAttr flatSymbolRef : callee.getNestedReferences()) {
+    if (result == nullptr) {
+      return nullptr;
     }
 
-    return result;
+    result = symbolTable.lookupSymbolIn(result, flatSymbolRef.getAttr());
   }
+
+  return result;
 }
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // ScheduleOp
 
-namespace mlir::bmodelica
-{
-  void ScheduleOp::collectSCCGroups(
-      llvm::SmallVectorImpl<SCCGroupOp>& SCCGroups)
-  {
-    for (SCCGroupOp sccGroup : getOps<SCCGroupOp>()) {
-      SCCGroups.push_back(sccGroup);
-    }
+namespace mlir::bmodelica {
+void ScheduleOp::collectSCCGroups(
+    llvm::SmallVectorImpl<SCCGroupOp> &SCCGroups) {
+  for (SCCGroupOp sccGroup : getOps<SCCGroupOp>()) {
+    SCCGroups.push_back(sccGroup);
   }
+}
 
-  void ScheduleOp::collectSCCs(llvm::SmallVectorImpl<SCCOp>& SCCs)
-  {
-    for (SCCOp scc : getOps<SCCOp>()) {
-      SCCs.push_back(scc);
-    }
+void ScheduleOp::collectSCCs(llvm::SmallVectorImpl<SCCOp> &SCCs) {
+  for (SCCOp scc : getOps<SCCOp>()) {
+    SCCs.push_back(scc);
   }
 }
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // RunScheduleOp
 
-namespace mlir::bmodelica
-{
-  void RunScheduleOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      ScheduleOp scheduleOp)
-  {
-    auto qualifiedRef = getSymbolRefFromRoot(scheduleOp);
-    build(builder, state, qualifiedRef);
-  }
-
-  mlir::LogicalResult RunScheduleOp::verifySymbolUses(
-      mlir::SymbolTableCollection& symbolTableCollection)
-  {
-    auto moduleOp = getOperation()->getParentOfType<mlir::ModuleOp>();
+namespace mlir::bmodelica {
+void RunScheduleOp::build(mlir::OpBuilder &builder,
+                          mlir::OperationState &state, ScheduleOp scheduleOp) {
+  auto qualifiedRef = getSymbolRefFromRoot(scheduleOp);
+  build(builder, state, qualifiedRef);
+}
 
-    mlir::Operation* symbolOp =
-        resolveSymbol(moduleOp, symbolTableCollection, getSchedule());
+mlir::LogicalResult RunScheduleOp::verifySymbolUses(
+    mlir::SymbolTableCollection &symbolTableCollection) {
+  auto moduleOp = getOperation()->getParentOfType<mlir::ModuleOp>();
 
-    if (!symbolOp) {
-      return emitError() << "symbol " << getSchedule() << " not found";
-    }
+  mlir::Operation *symbolOp =
+      resolveSymbol(moduleOp, symbolTableCollection, getSchedule());
 
-    auto scheduleOp = mlir::dyn_cast<ScheduleOp>(symbolOp);
+  if (!symbolOp) {
+    return emitError() << "symbol " << getSchedule() << " not found";
+  }
 
-    if (!scheduleOp) {
-      return emitError() << "symbol " << getSchedule() << " is not a schedule";
-    }
+  auto scheduleOp = mlir::dyn_cast<ScheduleOp>(symbolOp);
 
-    return mlir::success();
+  if (!scheduleOp) {
+    return emitError() << "symbol " << getSchedule() << " is not a schedule";
   }
 
-  ScheduleOp RunScheduleOp::getScheduleOp()
-  {
-    mlir::SymbolTableCollection symbolTableCollection;
-    return getScheduleOp(symbolTableCollection);
-  }
+  return mlir::success();
+}
+
+ScheduleOp RunScheduleOp::getScheduleOp() {
+  mlir::SymbolTableCollection symbolTableCollection;
+  return getScheduleOp(symbolTableCollection);
+}
 
-  ScheduleOp RunScheduleOp::getScheduleOp(
-      mlir::SymbolTableCollection& symbolTableCollection)
-  {
-    auto moduleOp = getOperation()->getParentOfType<mlir::ModuleOp>();
+ScheduleOp RunScheduleOp::getScheduleOp(
+    mlir::SymbolTableCollection &symbolTableCollection) {
+  auto moduleOp = getOperation()->getParentOfType<mlir::ModuleOp>();
 
-    mlir::Operation* variable =
-        resolveSymbol(moduleOp, symbolTableCollection, getSchedule());
+  mlir::Operation *variable =
+      resolveSymbol(moduleOp, symbolTableCollection, getSchedule());
 
-    return mlir::dyn_cast<ScheduleOp>(variable);
-  }
+  return mlir::dyn_cast<ScheduleOp>(variable);
 }
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // Control flow operations
@@ -12326,251 +11280,223 @@ namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // ForOp
 
-namespace mlir::bmodelica
-{
-  llvm::SmallVector<mlir::Region*> ForOp::getLoopRegions()
-  {
-    llvm::SmallVector<mlir::Region*> result;
-    result.push_back(&getBodyRegion());
-    return result;
-  }
-
-  mlir::Block* ForOp::conditionBlock()
-  {
-    assert(!getConditionRegion().empty());
-    return &getConditionRegion().front();
-  }
+namespace mlir::bmodelica {
+llvm::SmallVector<mlir::Region *> ForOp::getLoopRegions() {
+  llvm::SmallVector<mlir::Region *> result;
+  result.push_back(&getBodyRegion());
+  return result;
+}
 
-  mlir::Block* ForOp::bodyBlock()
-  {
-    assert(!getBodyRegion().empty());
-    return &getBodyRegion().front();
-  }
+mlir::Block *ForOp::conditionBlock() {
+  assert(!getConditionRegion().empty());
+  return &getConditionRegion().front();
+}
 
-  mlir::Block* ForOp::stepBlock()
-  {
-    assert(!getStepRegion().empty());
-    return &getStepRegion().front();
-  }
+mlir::Block *ForOp::bodyBlock() {
+  assert(!getBodyRegion().empty());
+  return &getBodyRegion().front();
+}
 
-  mlir::ParseResult ForOp::parse(
-      mlir::OpAsmParser& parser, mlir::OperationState& result)
-  {
-    mlir::Region* conditionRegion = result.addRegion();
+mlir::Block *ForOp::stepBlock() {
+  assert(!getStepRegion().empty());
+  return &getStepRegion().front();
+}
 
-    if (mlir::succeeded(parser.parseOptionalLParen())) {
-      if (mlir::failed(parser.parseOptionalRParen())) {
-        do {
-          mlir::OpAsmParser::UnresolvedOperand arg;
-          mlir::Type argType;
+mlir::ParseResult ForOp::parse(mlir::OpAsmParser &parser,
+                               mlir::OperationState &result) {
+  mlir::Region *conditionRegion = result.addRegion();
 
-          if (parser.parseOperand(arg) ||
-              parser.parseColonType(argType) ||
-              parser.resolveOperand(arg, argType, result.operands))
-            return mlir::failure();
-        } while (mlir::succeeded(parser.parseOptionalComma()));
-      }
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    if (mlir::failed(parser.parseOptionalRParen())) {
+      do {
+        mlir::OpAsmParser::UnresolvedOperand arg;
+        mlir::Type argType;
 
-      if (parser.parseRParen()) {
-        return mlir::failure();
-      }
+        if (parser.parseOperand(arg) || parser.parseColonType(argType) ||
+            parser.resolveOperand(arg, argType, result.operands))
+          return mlir::failure();
+      } while (mlir::succeeded(parser.parseOptionalComma()));
     }
 
-    if (parser.parseKeyword("condition")) {
+    if (parser.parseRParen()) {
       return mlir::failure();
     }
+  }
 
-    if (parser.parseRegion(*conditionRegion)) {
-      return mlir::failure();
-    }
+  if (parser.parseKeyword("condition")) {
+    return mlir::failure();
+  }
 
-    if (parser.parseKeyword("body")) {
-      return mlir::failure();
-    }
+  if (parser.parseRegion(*conditionRegion)) {
+    return mlir::failure();
+  }
 
-    mlir::Region* bodyRegion = result.addRegion();
+  if (parser.parseKeyword("body")) {
+    return mlir::failure();
+  }
 
-    if (parser.parseRegion(*bodyRegion)) {
-      return mlir::failure();
-    }
+  mlir::Region *bodyRegion = result.addRegion();
 
-    if (parser.parseKeyword("step")) {
-      return mlir::failure();
-    }
+  if (parser.parseRegion(*bodyRegion)) {
+    return mlir::failure();
+  }
 
-    mlir::Region* stepRegion = result.addRegion();
+  if (parser.parseKeyword("step")) {
+    return mlir::failure();
+  }
 
-    if (parser.parseRegion(*stepRegion)) {
-      return mlir::failure();
-    }
+  mlir::Region *stepRegion = result.addRegion();
 
-    if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) {
-      return mlir::failure();
-    }
+  if (parser.parseRegion(*stepRegion)) {
+    return mlir::failure();
+  }
 
-    return mlir::success();
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) {
+    return mlir::failure();
   }
 
-  void ForOp::print(mlir::OpAsmPrinter& printer)
-  {
-    if (auto values = getArgs(); !values.empty()) {
-      printer << "(";
+  return mlir::success();
+}
 
-      for (auto arg : llvm::enumerate(values)) {
-        if (arg.index() != 0) {
-          printer << ", ";
-        }
+void ForOp::print(mlir::OpAsmPrinter &printer) {
+  if (auto values = getArgs(); !values.empty()) {
+    printer << "(";
 
-        printer << arg.value() << " : " << arg.value().getType();
+    for (auto arg : llvm::enumerate(values)) {
+      if (arg.index() != 0) {
+        printer << ", ";
       }
 
-      printer << ")";
+      printer << arg.value() << " : " << arg.value().getType();
     }
 
-    printer << " condition ";
-    printer.printRegion(getConditionRegion(), true);
-    printer << " body ";
-    printer.printRegion(getBodyRegion(), true);
-    printer << " step ";
-    printer.printRegion(getStepRegion(), true);
-    printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
+    printer << ")";
   }
+
+  printer << " condition ";
+  printer.printRegion(getConditionRegion(), true);
+  printer << " body ";
+  printer.printRegion(getBodyRegion(), true);
+  printer << " step ";
+  printer.printRegion(getStepRegion(), true);
+  printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
 }
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // IfOp
 
-namespace mlir::bmodelica
-{
-  void IfOp::build(
-      mlir::OpBuilder& builder,
-      mlir::OperationState& state,
-      mlir::Value condition,
-      bool withElseRegion)
-  {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    state.addOperands(condition);
+namespace mlir::bmodelica {
+void IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                 mlir::Value condition, bool withElseRegion) {
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  state.addOperands(condition);
 
-    // Create the "then" region.
-    mlir::Region* thenRegion = state.addRegion();
-    builder.createBlock(thenRegion);
+  // Create the "then" region.
+  mlir::Region *thenRegion = state.addRegion();
+  builder.createBlock(thenRegion);
 
-    // Create the "else" region.
-    mlir::Region* elseRegion = state.addRegion();
+  // Create the "else" region.
+  mlir::Region *elseRegion = state.addRegion();
 
-    if (withElseRegion) {
-      builder.createBlock(elseRegion);
-    }
+  if (withElseRegion) {
+    builder.createBlock(elseRegion);
   }
+}
 
-  mlir::Block* IfOp::thenBlock()
-  {
-    return &getThenRegion().front();
-  }
+mlir::Block *IfOp::thenBlock() { return &getThenRegion().front(); }
 
-  mlir::Block* IfOp::elseBlock()
-  {
-    return &getElseRegion().front();
-  }
+mlir::Block *IfOp::elseBlock() { return &getElseRegion().front(); }
 
-  mlir::ParseResult IfOp::parse(
-      mlir::OpAsmParser& parser, mlir::OperationState& result)
-  {
-    mlir::OpAsmParser::UnresolvedOperand condition;
-    mlir::Type conditionType;
+mlir::ParseResult IfOp::parse(mlir::OpAsmParser &parser,
+                              mlir::OperationState &result) {
+  mlir::OpAsmParser::UnresolvedOperand condition;
+  mlir::Type conditionType;
 
-    if (parser.parseLParen() ||
-        parser.parseOperand(condition) ||
-        parser.parseColonType(conditionType) ||
-        parser.parseRParen() ||
-        parser.resolveOperand(condition, conditionType, result.operands)) {
-      return mlir::failure();
-    }
+  if (parser.parseLParen() || parser.parseOperand(condition) ||
+      parser.parseColonType(conditionType) || parser.parseRParen() ||
+      parser.resolveOperand(condition, conditionType, result.operands)) {
+    return mlir::failure();
+  }
 
-    if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) {
-      return mlir::failure();
-    }
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) {
+    return mlir::failure();
+  }
 
-    mlir::Region* thenRegion = result.addRegion();
+  mlir::Region *thenRegion = result.addRegion();
 
-    if (parser.parseRegion(*thenRegion)) {
-      return mlir::failure();
-    }
+  if (parser.parseRegion(*thenRegion)) {
+    return mlir::failure();
+  }
 
-    mlir::Region* elseRegion = result.addRegion();
+  mlir::Region *elseRegion = result.addRegion();
 
-    if (mlir::succeeded(parser.parseOptionalKeyword("else"))) {
-      if (parser.parseRegion(*elseRegion)) {
-        return mlir::failure();
-      }
+  if (mlir::succeeded(parser.parseOptionalKeyword("else"))) {
+    if (parser.parseRegion(*elseRegion)) {
+      return mlir::failure();
     }
-
-    return mlir::success();
   }
 
-  void IfOp::print(mlir::OpAsmPrinter& printer)
-  {
-    printer << " (" << getCondition() << " : "
-        << getCondition().getType() << ") ";
+  return mlir::success();
+}
+
+void IfOp::print(mlir::OpAsmPrinter &printer) {
+  printer << " (" << getCondition() << " : " << getCondition().getType()
+          << ") ";
 
-    printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
-    printer.printRegion(getThenRegion());
+  printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
+  printer.printRegion(getThenRegion());
 
-    if (!getElseRegion().empty()) {
-      printer << " else ";
-      printer.printRegion(getElseRegion());
-    }
+  if (!getElseRegion().empty()) {
+    printer << " else ";
+    printer.printRegion(getElseRegion());
   }
 }
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // WhileOp
 
-namespace mlir::bmodelica
-{
-  mlir::ParseResult WhileOp::parse(
-      mlir::OpAsmParser& parser, mlir::OperationState& result)
-  {
-    mlir::Region* conditionRegion = result.addRegion();
-    mlir::Region* bodyRegion = result.addRegion();
-
-    if (parser.parseRegion(*conditionRegion) ||
-        parser.parseKeyword("do") ||
-        parser.parseRegion(*bodyRegion)) {
-      return mlir::failure();
-    }
-
-    if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) {
-      return mlir::failure();
-    }
-
-    if (conditionRegion->empty()) {
-      conditionRegion->emplaceBlock();
-    }
+namespace mlir::bmodelica {
+mlir::ParseResult WhileOp::parse(mlir::OpAsmParser &parser,
+                                 mlir::OperationState &result) {
+  mlir::Region *conditionRegion = result.addRegion();
+  mlir::Region *bodyRegion = result.addRegion();
 
-    if (bodyRegion->empty()) {
-      bodyRegion->emplaceBlock();
-    }
+  if (parser.parseRegion(*conditionRegion) || parser.parseKeyword("do") ||
+      parser.parseRegion(*bodyRegion)) {
+    return mlir::failure();
+  }
 
-    return mlir::success();
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) {
+    return mlir::failure();
   }
 
-  void WhileOp::print(mlir::OpAsmPrinter& printer)
-  {
-    printer << " ";
-    printer.printRegion(getConditionRegion(), false);
-    printer << " do ";
-    printer.printRegion(getBodyRegion(), false);
-    printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
+  if (conditionRegion->empty()) {
+    conditionRegion->emplaceBlock();
   }
 
-  llvm::SmallVector<mlir::Region*> WhileOp::getLoopRegions()
-  {
-    llvm::SmallVector<mlir::Region*> result;
-    result.push_back(&getBodyRegion());
-    return result;
+  if (bodyRegion->empty()) {
+    bodyRegion->emplaceBlock();
   }
+
+  return mlir::success();
+}
+
+void WhileOp::print(mlir::OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printRegion(getConditionRegion(), false);
+  printer << " do ";
+  printer.printRegion(getBodyRegion(), false);
+  printer.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
+}
+
+llvm::SmallVector<mlir::Region *> WhileOp::getLoopRegions() {
+  llvm::SmallVector<mlir::Region *> result;
+  result.push_back(&getBodyRegion());
+  return result;
 }
+} // namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // Utility operations
@@ -12579,30 +11505,28 @@ namespace mlir::bmodelica
 
 //===---------------------------------------------------------------------===//
 // CastOp
 
-namespace mlir::bmodelica
-{
-  mlir::OpFoldResult CastOp::fold(FoldAdaptor adaptor)
-  {
-    auto operand = adaptor.getValue();
-
-    if (!operand) {
-      return {};
-    }
+namespace mlir::bmodelica {
+mlir::OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
+  auto operand = adaptor.getValue();
 
-    auto resultType = getResult().getType();
+  if (!operand) {
+    return {};
+  }
 
-    if (isScalar(operand)) {
-      if (isScalarIntegerLike(operand)) {
-        int64_t value = getScalarIntegerLikeValue(operand);
-        return getAttr(resultType, value);
-      }
+  auto resultType = getResult().getType();
 
-      if (isScalarFloatLike(operand)) {
-        double value = getScalarFloatLikeValue(operand);
-        return getAttr(resultType, value);
-      }
+  if (isScalar(operand)) {
+    if (isScalarIntegerLike(operand)) {
+      int64_t value = getScalarIntegerLikeValue(operand);
+      return getAttr(resultType, value);
    }
 
-    return {};
+    if (isScalarFloatLike(operand)) {
+      double value = getScalarFloatLikeValue(operand);
+      return getAttr(resultType, value);
+    }
  }
+
+  return {};
 }
+} // namespace mlir::bmodelica
diff --git a/lib/Dialect/BaseModelica/IR/Types.cpp b/lib/Dialect/BaseModelica/IR/Types.cpp
index c964c011d..1eeb9bef8 100644
--- a/lib/Dialect/BaseModelica/IR/Types.cpp
+++ b/lib/Dialect/BaseModelica/IR/Types.cpp
@@ -19,638 +19,518 @@ using namespace ::mlir::bmodelica::detail;
 // ModelicaDialect
 //===---------------------------------------------------------------------===//
 
-namespace mlir::bmodelica
-{
-  void BaseModelicaDialect::registerTypes()
-  {
-    addTypes<
-    #define GET_TYPEDEF_LIST
-    #include "marco/Dialect/BaseModelica/IR/BaseModelicaTypes.cpp.inc"
-    >();
-  }
+namespace mlir::bmodelica {
+void BaseModelicaDialect::registerTypes() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "marco/Dialect/BaseModelica/IR/BaseModelicaTypes.cpp.inc"
+      >();
 }
+} // namespace mlir::bmodelica
 
-namespace mlir::bmodelica
-{
-  //===-------------------------------------------------------------------===//
-  // BaseArrayType
-  //===-------------------------------------------------------------------===//
-
-  bool BaseArrayType::classof(mlir::Type type)
-  {
-    return type.isa<ArrayType, UnrankedArrayType>();
-  }
-
-  BaseArrayType::operator mlir::ShapedType() const
-  {
-    return cast<mlir::ShapedType>();
-  }
-
-  bool BaseArrayType::isValidElementType(mlir::Type type)
-  {
-    return type.isIndex() || type.isIntOrFloat() ||
-        type.isa<BooleanType, IntegerType, RealType, RecordType>();
-  }
-
-  mlir::Type BaseArrayType::getElementType() const
-  {
-    return llvm::TypeSwitch<BaseArrayType, mlir::Type>(*this)
-        .Case<ArrayType, UnrankedArrayType>(
-            [](auto type) {
-              return type.getElementType();
-            });
-  }
-
-  bool BaseArrayType::hasRank() const
-  {
-    return !isa<UnrankedArrayType>();
-  }
-
-  llvm::ArrayRef<int64_t> BaseArrayType::getShape() const
-  {
-    return cast<ArrayType>().getShape();
-  }
-
-  mlir::Attribute BaseArrayType::getMemorySpace() const
-  {
-    if (auto rankedArrayTy = dyn_cast<ArrayType>()) {
-      return rankedArrayTy.getMemorySpace();
-    }
-
-    return cast<UnrankedArrayType>().getMemorySpace();
-  }
-
-  BaseArrayType BaseArrayType::cloneWith(
-      std::optional<llvm::ArrayRef<int64_t>> shape,
-      mlir::Type elementType) const
-  {
-    if (isa<UnrankedArrayType>()) {
-      if (!shape) {
-        return UnrankedArrayType::get(elementType, getMemorySpace());
-      }
-
-      ArrayType::Builder builder(*shape, elementType);
-      builder.setMemorySpace(getMemorySpace());
-      return builder;
-    }
-
-    ArrayType::Builder builder(cast<ArrayType>());
-
-    if (shape) {
-      builder.setShape(*shape);
-    }
-
-    builder.setElementType(elementType);
-    return builder;
-  }
+namespace mlir::bmodelica {
+//===-------------------------------------------------------------------===//
+// BaseArrayType
+//===-------------------------------------------------------------------===//
+
+bool BaseArrayType::classof(mlir::Type type) {
+  return type.isa<ArrayType, UnrankedArrayType>();
+}
+
+BaseArrayType::operator mlir::ShapedType() const {
+  return cast<mlir::ShapedType>();
+}
+
+bool BaseArrayType::isValidElementType(mlir::Type type) {
+  return type.isIndex() || type.isIntOrFloat() ||
+         type.isa<BooleanType, IntegerType, RealType, RecordType>();
+}
+
+mlir::Type BaseArrayType::getElementType() const {
+  return llvm::TypeSwitch<BaseArrayType, mlir::Type>(*this)
+      .Case<ArrayType, UnrankedArrayType>(
+          [](auto type) { return type.getElementType(); });
+}
+
+bool BaseArrayType::hasRank() const { return !isa<UnrankedArrayType>(); }
+
+llvm::ArrayRef<int64_t> BaseArrayType::getShape() const {
+  return cast<ArrayType>().getShape();
+}
+
+mlir::Attribute BaseArrayType::getMemorySpace() const {
+  if (auto rankedArrayTy = dyn_cast<ArrayType>()) {
+    return rankedArrayTy.getMemorySpace();
+  }
+
+  return cast<UnrankedArrayType>().getMemorySpace();
+}
+
+BaseArrayType
+BaseArrayType::cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
+                         mlir::Type elementType) const {
+  if (isa<UnrankedArrayType>()) {
+    if (!shape) {
+      return UnrankedArrayType::get(elementType, getMemorySpace());
+    }
+
+    ArrayType::Builder builder(*shape, elementType);
+    builder.setMemorySpace(getMemorySpace());
+    return builder;
+  }
+
+  ArrayType::Builder builder(cast<ArrayType>());
+
+  if (shape) {
+    builder.setShape(*shape);
+  }
+
+  builder.setElementType(elementType);
+  return builder;
+}
+
+//===-------------------------------------------------------------------===//
+// ArrayType
+//===-------------------------------------------------------------------===//
 
-  //===-------------------------------------------------------------------===//
-  // ArrayType
-  //===-------------------------------------------------------------------===//
-
-  mlir::Type ArrayType::parse(mlir::AsmParser& parser)
-  {
-    llvm::SmallVector<int64_t> dimensions;
-
-    mlir::Type elementType;
-    mlir::Attribute memorySpace;
-
-    if (parser.parseLess() ||
-        parser.parseDimensionList(dimensions) ||
-        parser.parseType(elementType)) {
-      return {};
-    }
-
-    if (mlir::succeeded(parser.parseOptionalComma())) {
-      if (parser.parseAttribute(memorySpace)) {
-        return {};
-      }
-    }
-
-    if (parser.parseGreater()) {
-      return {};
-    }
-
-    return ArrayType::get(dimensions, elementType, memorySpace);
-  }
-
-  void ArrayType::print(mlir::AsmPrinter& printer) const
-  {
-    printer << "<";
-
-    for (int64_t dimension : getShape()) {
-      if (dimension == ArrayType::kDynamic) {
-        printer << "?";
-      } else {
-        printer << dimension;
-      }
-
-      printer << "x";
-    }
-
-    printer << getElementType() << ">";
-    // TODO print memory space
-  }
-
-  ArrayType ArrayType::get(
-      llvm::ArrayRef<int64_t> shape,
-      mlir::Type elementType,
-      mlir::Attribute memorySpace)
-  {
-    // Drop default memory space value and replace it with empty attribute.
-    memorySpace = skipDefaultMemorySpace(memorySpace);
-
-    return Base::get(
-        elementType.getContext(),
-        shape,
-        elementType,
-        memorySpace);
-  }
-
-  ArrayType ArrayType::getChecked(
-      llvm::function_ref<mlir::InFlightDiagnostic()> emitErrorFn,
-      llvm::ArrayRef<int64_t> shape,
-      mlir::Type elementType,
-      mlir::Attribute memorySpace)
-  {
-    // Drop default memory space value and replace it with empty attribute.
-    memorySpace = skipDefaultMemorySpace(memorySpace);
-
-    return Base::getChecked(
-        emitErrorFn,
-        elementType.getContext(),
-        shape,
-        elementType,
-        memorySpace);
-  }
-
-  mlir::LogicalResult ArrayType::verify(
-      llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-      llvm::ArrayRef<int64_t> shape,
-      mlir::Type elementType,
-      mlir::Attribute memorySpace)
-  {
-    if (!BaseArrayType::isValidElementType(elementType)) {
-      return emitError() << "invalid array element type";
-    }
-
-    // Negative sizes are not allowed.
-    for (int64_t size : shape) {
-      if (size < 0 && size != ArrayType::kDynamic) {
-        return emitError() << "invalid array size";
-      }
-    }
-
-    if (!isSupportedMemorySpace(memorySpace)) {
-      return emitError() << "unsupported memory space Attribute";
-    }
-
-    return mlir::success();
-  }
-
-  bool ArrayType::isScalar() const
-  {
-    return getRank() == 0;
-  }
-
-  ArrayType ArrayType::slice(unsigned int subscriptsAmount) const
-  {
-    auto shape = getShape();
-    assert(subscriptsAmount <= shape.size() && "Too many subscriptions");
-    llvm::SmallVector<int64_t> resultShape;
-
-    for (size_t i = subscriptsAmount, e = shape.size(); i < e; ++i) {
-      resultShape.push_back(shape[i]);
-    }
-
-    return ArrayType::get(resultShape, getElementType());
-  }
-
-  ArrayType ArrayType::toElementType(mlir::Type elementType) const
-  {
-    return ArrayType::get(getShape(), elementType);
-  }
+mlir::Type ArrayType::parse(mlir::AsmParser &parser) {
+  llvm::SmallVector<int64_t> dimensions;
+
+  mlir::Type elementType;
+  mlir::Attribute memorySpace;
+
+  if (parser.parseLess() || parser.parseDimensionList(dimensions) ||
+      parser.parseType(elementType)) {
+    return {};
+  }
+
+  if (mlir::succeeded(parser.parseOptionalComma())) {
+    if (parser.parseAttribute(memorySpace)) {
+      return {};
+    }
+  }
+
+  if (parser.parseGreater()) {
+    return {};
+  }
+
+  return ArrayType::get(dimensions, elementType, memorySpace);
+}
+
+void ArrayType::print(mlir::AsmPrinter &printer) const {
+  printer << "<";
+
+  for (int64_t dimension : getShape()) {
+    if (dimension == ArrayType::kDynamic) {
+      printer << "?";
+    } else {
+      printer << dimension;
+    }
+
+    printer << "x";
+  }
+
+  printer << getElementType() << ">";
+  // TODO print memory space
+}
+
+ArrayType ArrayType::get(llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+                         mlir::Attribute memorySpace) {
+  // Drop default memory space value and replace it with empty attribute.
+  memorySpace = skipDefaultMemorySpace(memorySpace);
+
+  return Base::get(elementType.getContext(), shape, elementType, memorySpace);
+}
+
+ArrayType ArrayType::getChecked(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitErrorFn,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute memorySpace) {
+  // Drop default memory space value and replace it with empty attribute.
+  memorySpace = skipDefaultMemorySpace(memorySpace);
+
+  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
+                          elementType, memorySpace);
+}
+
+mlir::LogicalResult
+ArrayType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                  llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+                  mlir::Attribute memorySpace) {
+  if (!BaseArrayType::isValidElementType(elementType)) {
+    return emitError() << "invalid array element type";
+  }
+
+  // Negative sizes are not allowed.
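+  // Note that the dynamic-size sentinel (ArrayType::kDynamic) is itself a
+  // negative value, hence its explicit exclusion below.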
+  for (int64_t size : shape) {
+    if (size < 0 && size != ArrayType::kDynamic) {
+      return emitError() << "invalid array size";
+    }
+  }
+
+  if (!isSupportedMemorySpace(memorySpace)) {
+    return emitError() << "unsupported memory space Attribute";
+  }
+
+  return mlir::success();
+}
+
+bool ArrayType::isScalar() const { return getRank() == 0; }
+
+ArrayType ArrayType::slice(unsigned int subscriptsAmount) const {
+  auto shape = getShape();
+  assert(subscriptsAmount <= shape.size() && "Too many subscripts");
+  llvm::SmallVector<int64_t> resultShape;
+
+  for (size_t i = subscriptsAmount, e = shape.size(); i < e; ++i) {
+    resultShape.push_back(shape[i]);
+  }
+
+  return ArrayType::get(resultShape, getElementType());
+}
+
+ArrayType ArrayType::toElementType(mlir::Type elementType) const {
+  return ArrayType::get(getShape(), elementType);
+}
 
-  ArrayType ArrayType::withShape(llvm::ArrayRef<int64_t> shape) const
-  {
-    return ArrayType::get(shape, getElementType());
-  }
+ArrayType ArrayType::withShape(llvm::ArrayRef<int64_t> shape) const {
+  return ArrayType::get(shape, getElementType());
+}
 
-  bool ArrayType::canBeOnStack() const
-  {
-    return hasStaticShape();
-  }
+bool ArrayType::canBeOnStack() const { return hasStaticShape(); }
 
-  //===-------------------------------------------------------------------===//
-  // UnrankedArrayType
-  //===-------------------------------------------------------------------===//
+//===-------------------------------------------------------------------===//
+// UnrankedArrayType
+//===-------------------------------------------------------------------===//
 
-  mlir::Type UnrankedArrayType::parse(mlir::AsmParser& parser)
-  {
-    mlir::Type elementType;
-    mlir::Attribute memorySpace;
-
-    if (parser.parseLess() ||
-        parser.parseType(elementType)) {
-      return {};
-    }
-
-    if (mlir::succeeded(parser.parseOptionalComma())) {
-      if (parser.parseAttribute(memorySpace)) {
-        return {};
-      }
-    }
-
-    if (parser.parseGreater()) {
-      return {};
-    }
-
-    return UnrankedArrayType::get(elementType, memorySpace);
-  }
+mlir::Type UnrankedArrayType::parse(mlir::AsmParser &parser) {
+  mlir::Type elementType;
+  mlir::Attribute memorySpace;
+
+  if (parser.parseLess() || parser.parseType(elementType)) {
+    return {};
+  }
+
+  if (mlir::succeeded(parser.parseOptionalComma())) {
+    if (parser.parseAttribute(memorySpace)) {
+      return {};
+    }
+  }
+
+  if (parser.parseGreater()) {
+    return {};
+  }
+
+  return UnrankedArrayType::get(elementType, memorySpace);
+}
 
-  void UnrankedArrayType::print(mlir::AsmPrinter& printer) const
-  {
-    printer << "<" << getElementType() << ">";
-    // TODO print memory space
-  }
+void UnrankedArrayType::print(mlir::AsmPrinter &printer) const {
+  printer << "<" << getElementType() << ">";
+  // TODO print memory space
+}
 
-  mlir::LogicalResult UnrankedArrayType::verify(
-      llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-      mlir::Type elementType,
-      mlir::Attribute memorySpace)
-  {
-    if (!BaseArrayType::isValidElementType(elementType)) {
-      return emitError() << "invalid array element type";
-    }
-
-    if (!isSupportedMemorySpace(memorySpace)) {
-      return emitError() << "unsupported memory space Attribute";
-    }
-
-    return mlir::success();
-  }
+mlir::LogicalResult UnrankedArrayType::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    mlir::Type elementType, mlir::Attribute memorySpace) {
+  if (!BaseArrayType::isValidElementType(elementType)) {
+    return emitError() << "invalid array element type";
+  }
+
+  if (!isSupportedMemorySpace(memorySpace)) {
+    return emitError() << "unsupported memory space Attribute";
+  }
+
+  return mlir::success();
+}
 
 //===-------------------------------------------------------------------===//
 // RecordType
 //===-------------------------------------------------------------------===//
 
-  mlir::Operation* RecordType::getRecordOp(
-      mlir::SymbolTableCollection& symbolTable,
-      mlir::ModuleOp moduleOp)
-  {
-    mlir::Operation* result = moduleOp.getOperation();
-    result = symbolTable.lookupSymbolIn(result, getName().getRootReference());
-
-    for (mlir::FlatSymbolRefAttr flatSymbolRef :
-         getName().getNestedReferences()) {
-      if (result == nullptr) {
-        return nullptr;
-      }
-
-      result = symbolTable.lookupSymbolIn(result, flatSymbolRef.getAttr());
-    }
-
-    return result;
-  }
+mlir::Operation *
+RecordType::getRecordOp(mlir::SymbolTableCollection &symbolTable,
+                        mlir::ModuleOp moduleOp) {
+  mlir::Operation *result = moduleOp.getOperation();
+  result = symbolTable.lookupSymbolIn(result, getName().getRootReference());
+  mlir::SymbolRefAttr nameAttr = getName();
+
+  for (mlir::FlatSymbolRefAttr flatSymbolRef : nameAttr.getNestedReferences()) {
+    if (result == nullptr) {
+      return nullptr;
+    }
+
+    result = symbolTable.lookupSymbolIn(result, flatSymbolRef.getAttr());
+  }
+
+  return result;
+}
 
 //===-------------------------------------------------------------------===//
 // VariableType
 //===-------------------------------------------------------------------===//
 
-  mlir::Type VariableType::parse(mlir::AsmParser& parser)
-  {
-    llvm::SmallVector<int64_t> dimensions;
-    mlir::Type elementType;
-
-    if (parser.parseLess() ||
-        parser.parseDimensionList(dimensions) ||
-        parser.parseType(elementType)) {
-      return {};
-    }
-
-    VariabilityProperty variabilityProperty = VariabilityProperty::none;
-    IOProperty ioProperty = IOProperty::none;
-
-    while (mlir::succeeded(parser.parseOptionalComma())) {
-      if (mlir::succeeded(parser.parseOptionalKeyword("discrete"))) {
-        variabilityProperty = VariabilityProperty::discrete;
-      } else if (mlir::succeeded(parser.parseOptionalKeyword("parameter"))) {
-        variabilityProperty = VariabilityProperty::parameter;
-      } else if (mlir::succeeded(parser.parseOptionalKeyword("constant"))) {
-        variabilityProperty = VariabilityProperty::constant;
-      } else if (mlir::succeeded(parser.parseOptionalKeyword("input"))) {
-        ioProperty = IOProperty::input;
-      } else if (mlir::succeeded(parser.parseOptionalKeyword("output"))) {
-        ioProperty = IOProperty::output;
-      }
-    }
-
-    if (parser.parseGreater()) {
-      return {};
-    }
-
-    return VariableType::get(
-        dimensions, elementType, variabilityProperty, ioProperty);
-  }
+mlir::Type VariableType::parse(mlir::AsmParser &parser) {
+  llvm::SmallVector<int64_t> dimensions;
+  mlir::Type elementType;
+
+  if (parser.parseLess() || parser.parseDimensionList(dimensions) ||
+      parser.parseType(elementType)) {
+    return {};
+  }
+
+  VariabilityProperty variabilityProperty = VariabilityProperty::none;
+  IOProperty ioProperty = IOProperty::none;
+
+  while (mlir::succeeded(parser.parseOptionalComma())) {
+    if (mlir::succeeded(parser.parseOptionalKeyword("discrete"))) {
+      variabilityProperty = VariabilityProperty::discrete;
+    } else if (mlir::succeeded(parser.parseOptionalKeyword("parameter"))) {
+      variabilityProperty = VariabilityProperty::parameter;
+    } else if (mlir::succeeded(parser.parseOptionalKeyword("constant"))) {
+      variabilityProperty = VariabilityProperty::constant;
+    } else if (mlir::succeeded(parser.parseOptionalKeyword("input"))) {
+      ioProperty = IOProperty::input;
+    } else if (mlir::succeeded(parser.parseOptionalKeyword("output"))) {
+      ioProperty = IOProperty::output;
+    }
+  }
+
+  if (parser.parseGreater()) {
+    return {};
+  }
+
+  return VariableType::get(dimensions, elementType, variabilityProperty,
+                           ioProperty);
+}
 
-  void VariableType::print(mlir::AsmPrinter& printer) const
-  {
-    printer << "<";
-
-    for (int64_t dimension : getShape()) {
-      if (dimension == VariableType::kDynamic) {
-        printer << "?";
-      } else {
-        printer << dimension;
-      }
-
-      printer << "x";
-    }
-
-    printer << getElementType();
-
-    if (isDiscrete()) {
-      printer << ", discrete";
-    } else if (isParameter()) {
-      printer << ", parameter";
-    } else if (isConstant()) {
-      printer << ", constant";
-    }
-
-    if (isInput()) {
-      printer << ", input";
-    } else if (isOutput()) {
-      printer << ", output";
-    }
-
-    printer << ">";
-  }
-
-  VariableType VariableType::get(
-      llvm::ArrayRef<int64_t> shape,
-      mlir::Type elementType,
-      VariabilityProperty variabilityProperty,
-      IOProperty ioProperty,
-      mlir::Attribute memorySpace)
-  {
-    // Drop default memory space value and replace it with empty attribute.
-    memorySpace = skipDefaultMemorySpace(memorySpace);
-
-    return Base::get(
-        elementType.getContext(),
-        shape,
-        elementType,
-        variabilityProperty,
-        ioProperty,
-        memorySpace);
-  }
-
-  VariableType VariableType::getChecked(
-      llvm::function_ref<mlir::InFlightDiagnostic()> emitErrorFn,
-      llvm::ArrayRef<int64_t> shape,
-      mlir::Type elementType,
-      VariabilityProperty variabilityProperty,
-      IOProperty ioProperty,
-      mlir::Attribute memorySpace)
-  {
-    // Drop default memory space value and replace it with empty attribute.
-    memorySpace = skipDefaultMemorySpace(memorySpace);
-
-    return Base::getChecked(
-        emitErrorFn,
-        elementType.getContext(),
-        shape,
-        elementType,
-        variabilityProperty,
-        ioProperty,
-        memorySpace);
-  }
-
-  mlir::LogicalResult VariableType::verify(
-      llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-      llvm::ArrayRef<int64_t> shape,
-      mlir::Type elementType,
-      VariabilityProperty variabilityProperty,
-      IOProperty ioProperty,
-      mlir::Attribute memorySpace)
-  {
-    if (!isValidElementType(elementType)) {
-      return emitError() << "invalid variable element type";
-    }
-
-    // Negative sizes are not allowed.
-    for (int64_t size : shape) {
-      if (size < 0 && size != VariableType::kDynamic) {
-        return emitError() << "invalid variable size";
-      }
-    }
-
-    if (!isSupportedMemorySpace(memorySpace)) {
-      return emitError() << "unsupported memory space Attribute";
-    }
+void VariableType::print(mlir::AsmPrinter &printer) const {
+  printer << "<";
+
+  for (int64_t dimension : getShape()) {
+    if (dimension == VariableType::kDynamic) {
+      printer << "?";
+    } else {
+      printer << dimension;
+    }
+
+    printer << "x";
+  }
+
+  printer << getElementType();
+
+  if (isDiscrete()) {
+    printer << ", discrete";
+  } else if (isParameter()) {
+    printer << ", parameter";
+  } else if (isConstant()) {
+    printer << ", constant";
+  }
+
+  if (isInput()) {
+    printer << ", input";
+  } else if (isOutput()) {
+    printer << ", output";
+  }
+
+  printer << ">";
+}
+
+VariableType VariableType::get(llvm::ArrayRef<int64_t> shape,
+                               mlir::Type elementType,
+                               VariabilityProperty variabilityProperty,
+                               IOProperty ioProperty,
+                               mlir::Attribute memorySpace) {
+  // Drop default memory space value and replace it with empty attribute.
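+  // Canonicalizing the memory space makes type uniquing insensitive to
+  // whether the default space was spelled out explicitly (as 0) or omitted.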
+  memorySpace = skipDefaultMemorySpace(memorySpace);
 
-    return mlir::success();
-  }
+  return Base::get(elementType.getContext(), shape, elementType,
+                   variabilityProperty, ioProperty, memorySpace);
+}
 
-  bool VariableType::hasRank() const
-  {
-    return true;
-  }
+VariableType VariableType::getChecked(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitErrorFn,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    VariabilityProperty variabilityProperty, IOProperty ioProperty,
+    mlir::Attribute memorySpace) {
+  // Drop default memory space value and replace it with empty attribute.
+  memorySpace = skipDefaultMemorySpace(memorySpace);
+
+  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
+                          elementType, variabilityProperty, ioProperty,
+                          memorySpace);
+}
 
-  mlir::ShapedType VariableType::cloneWith(
-      std::optional<llvm::ArrayRef<int64_t>> shape,
-      mlir::Type elementType) const
-  {
-    VariableType::Builder builder(*shape, elementType);
-    builder.setVariabilityProperty(getVariabilityProperty());
-    builder.setVisibilityProperty(getVisibilityProperty());
-    builder.setMemorySpace(getMemorySpace());
-    return builder;
-  }
+mlir::LogicalResult
+VariableType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                     llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+                     VariabilityProperty variabilityProperty,
+                     IOProperty ioProperty, mlir::Attribute memorySpace) {
+  if (!isValidElementType(elementType)) {
+    return emitError() << "invalid variable element type";
+  }
 
-  bool VariableType::isValidElementType(mlir::Type type)
-  {
-    return type.isIntOrIndexOrFloat() ||
-        type.isa<BooleanType, IntegerType, RealType, RecordType>();
-  }
+  // Negative sizes are not allowed.
+  for (int64_t size : shape) {
+    if (size < 0 && size != VariableType::kDynamic) {
+      return emitError() << "invalid variable size";
+    }
+  }
 
-  bool VariableType::isScalar() const
-  {
-    return getRank() == 0;
-  }
+  if (!isSupportedMemorySpace(memorySpace)) {
+    return emitError() << "unsupported memory space Attribute";
+  }
 
-  bool VariableType::isDiscrete() const
-  {
-    return getVariabilityProperty() == VariabilityProperty::discrete;
-  }
+  return mlir::success();
+}
 
-  bool VariableType::isParameter() const
-  {
-    return getVariabilityProperty() == VariabilityProperty::parameter;
-  }
+bool VariableType::hasRank() const { return true; }
 
-  bool VariableType::isConstant() const
-  {
-    return getVariabilityProperty() == VariabilityProperty::constant;
-  }
+mlir::ShapedType
+VariableType::cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
+                        mlir::Type elementType) const {
+  VariableType::Builder builder(*shape, elementType);
+  builder.setVariabilityProperty(getVariabilityProperty());
+  builder.setVisibilityProperty(getVisibilityProperty());
+  builder.setMemorySpace(getMemorySpace());
+  return builder;
+}
 
-  bool VariableType::isReadOnly() const
-  {
-    return isParameter() || isConstant();
-  }
+bool VariableType::isValidElementType(mlir::Type type) {
+  return type.isIntOrIndexOrFloat() ||
+         type.isa<BooleanType, IntegerType, RealType, RecordType>();
+}
 
-  bool VariableType::isInput() const
-  {
-    return getVisibilityProperty() == IOProperty::input;
-  }
+bool VariableType::isScalar() const { return getRank() == 0; }
 
-  bool VariableType::isOutput() const
-  {
-    return getVisibilityProperty() == IOProperty::output;
-  }
+bool VariableType::isDiscrete() const {
+  return getVariabilityProperty() == VariabilityProperty::discrete;
+}
 
-  VariableType VariableType::wrap(
-      mlir::Type type,
-      VariabilityProperty variabilityProperty,
-      IOProperty ioProperty)
-  {
-    if (auto arrayType = type.dyn_cast<ArrayType>()) {
-      return VariableType::get(
-          arrayType.getShape(),
-          arrayType.getElementType(),
-          variabilityProperty,
-          ioProperty,
-          arrayType.getMemorySpace());
-    }
+bool VariableType::isParameter() const {
+  return getVariabilityProperty() == VariabilityProperty::parameter;
+}
 
-    return VariableType::get(
-        std::nullopt, type, variabilityProperty, ioProperty);
-  }
+bool VariableType::isConstant() const {
+  return getVariabilityProperty() == VariabilityProperty::constant;
+}
 
-  ArrayType VariableType::toArrayType() const
-  {
-    return ArrayType::get(getShape(), getElementType(), getMemorySpace());
-  }
+bool VariableType::isReadOnly() const { return isParameter() || isConstant(); }
 
-  mlir::TensorType VariableType::toTensorType() const
-  {
-    return mlir::RankedTensorType::get(getShape(), getElementType());
-  }
+bool VariableType::isInput() const {
+  return getVisibilityProperty() == IOProperty::input;
+}
 
-  mlir::Type VariableType::unwrap() const
-  {
-    if (!isScalar()) {
-      return toTensorType();
-    }
+bool VariableType::isOutput() const {
+  return getVisibilityProperty() == IOProperty::output;
+}
 
-    return getElementType();
-  }
+VariableType VariableType::wrap(mlir::Type type,
+                                VariabilityProperty variabilityProperty,
+                                IOProperty ioProperty) {
+  if (auto arrayType = type.dyn_cast<ArrayType>()) {
+    return VariableType::get(arrayType.getShape(), arrayType.getElementType(),
+                             variabilityProperty, ioProperty,
+                             arrayType.getMemorySpace());
+  }
 
-  VariableType VariableType::withShape(llvm::ArrayRef<int64_t> shape) const
-  {
-    return VariableType::get(
-        shape,
-        getElementType(),
-        getVariabilityProperty(),
-        getVisibilityProperty());
-  }
+  return VariableType::get(std::nullopt, type, variabilityProperty,
+                           ioProperty);
+}
 
-  VariableType VariableType::withType(mlir::Type type) const
-  {
-    return VariableType::get(
-        getShape(),
-        type,
-        getVariabilityProperty(),
-        getVisibilityProperty());
-  }
+ArrayType VariableType::toArrayType() const {
+  return ArrayType::get(getShape(), getElementType(), getMemorySpace());
+}
 
-  VariableType VariableType::withVariabilityProperty(
-      VariabilityProperty variabilityProperty) const
-  {
-    return VariableType::get(
-        getShape(),
-        getElementType(),
-        variabilityProperty,
-        getVisibilityProperty());
-  }
+mlir::TensorType VariableType::toTensorType() const {
+  return mlir::RankedTensorType::get(getShape(), getElementType());
+}
 
-  VariableType VariableType::withoutVariabilityProperty() const
-  {
-    return withVariabilityProperty(VariabilityProperty::none);
-  }
+mlir::Type VariableType::unwrap() const {
+  if (!isScalar()) {
+    return toTensorType();
+  }
 
-  VariableType VariableType::asDiscrete() const
-  {
-    return withVariabilityProperty(VariabilityProperty::discrete);
-  }
+  return getElementType();
+}
 
-  VariableType VariableType::asParameter() const
-  {
-    return withVariabilityProperty(VariabilityProperty::parameter);
-  }
+VariableType VariableType::withShape(llvm::ArrayRef<int64_t> shape) const {
+  return VariableType::get(shape, getElementType(), getVariabilityProperty(),
+                           getVisibilityProperty());
+}
 
-  VariableType VariableType::asConstant() const
-  {
-    return withVariabilityProperty(VariabilityProperty::constant);
-  }
+VariableType VariableType::withType(mlir::Type type) const {
+  return VariableType::get(getShape(), type, getVariabilityProperty(),
+                           getVisibilityProperty());
+}
 
-  VariableType VariableType::withIOProperty(IOProperty ioProperty) const
-  {
-    return VariableType::get(
-        getShape(),
-        getElementType(),
-        getVariabilityProperty(),
-        ioProperty);
-  }
+VariableType VariableType::withVariabilityProperty(
+    VariabilityProperty variabilityProperty) const {
+  return VariableType::get(getShape(), getElementType(), variabilityProperty,
+                           getVisibilityProperty());
+}
 
-  VariableType VariableType::withoutIOProperty() const
-  {
-    return withIOProperty(IOProperty::none);
-  }
+VariableType VariableType::withoutVariabilityProperty() const {
+  return withVariabilityProperty(VariabilityProperty::none);
+}
 
-  VariableType VariableType::asInput() const
-  {
-    return withIOProperty(IOProperty::input);
-  }
+VariableType VariableType::asDiscrete() const {
+  return withVariabilityProperty(VariabilityProperty::discrete);
+}
 
-  VariableType VariableType::asOutput() const
-  {
-    return withIOProperty(IOProperty::output);
-  }
-}
+VariableType VariableType::asParameter() const {
+  return withVariabilityProperty(VariabilityProperty::parameter);
+}
 
-namespace mlir::bmodelica::detail
-{
-  bool isSupportedMemorySpace(mlir::Attribute memorySpace)
-  {
-    // Empty attribute is allowed as default memory space.
-    if (!memorySpace) {
-      return true;
-    }
+VariableType VariableType::asConstant() const {
+  return withVariabilityProperty(VariabilityProperty::constant);
+}
 
-    // Supported built-in attributes.
-    if (memorySpace.isa<
-        mlir::IntegerAttr, mlir::StringAttr, mlir::DictionaryAttr>()) {
-      return true;
-    }
+VariableType VariableType::withIOProperty(IOProperty ioProperty) const {
+  return VariableType::get(getShape(), getElementType(),
+                           getVariabilityProperty(), ioProperty);
+}
 
-    // Allow custom dialect attributes.
-    if (!isa<mlir::BuiltinDialect>(memorySpace.getDialect())) {
-      return true;
-    }
+VariableType VariableType::withoutIOProperty() const {
+  return withIOProperty(IOProperty::none);
+}
+
+VariableType VariableType::asInput() const {
+  return withIOProperty(IOProperty::input);
+}
+
+VariableType VariableType::asOutput() const {
+  return withIOProperty(IOProperty::output);
+}
+} // namespace mlir::bmodelica
+
+namespace mlir::bmodelica::detail {
+bool isSupportedMemorySpace(mlir::Attribute memorySpace) {
+  // Empty attribute is allowed as default memory space.
+  if (!memorySpace) {
+    return true;
+  }
+
+  // Supported built-in attributes.
+  if (memorySpace
+          .isa<mlir::IntegerAttr, mlir::StringAttr, mlir::DictionaryAttr>()) {
+    return true;
+  }
+
+  // Allow custom dialect attributes.
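+  // Attributes from non-builtin dialects are accepted as-is; only builtin
+  // attributes other than the ones listed above are rejected.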
+ if (!isa(memorySpace.getDialect())) { + return true; + } - if (intMemorySpace && intMemorySpace.getValue() == 0) { - return nullptr; - } + return false; +} - return memorySpace; +mlir::Attribute skipDefaultMemorySpace(mlir::Attribute memorySpace) { + mlir::IntegerAttr intMemorySpace = + memorySpace.dyn_cast_or_null(); + + if (intMemorySpace && intMemorySpace.getValue() == 0) { + return nullptr; } + + return memorySpace; } +} // namespace mlir::bmodelica::detail diff --git a/lib/Dialect/BaseModelica/Transforms/EulerForward.cpp b/lib/Dialect/BaseModelica/Transforms/EulerForward.cpp index 4d5d65c23..41eb73d39 100644 --- a/lib/Dialect/BaseModelica/Transforms/EulerForward.cpp +++ b/lib/Dialect/BaseModelica/Transforms/EulerForward.cpp @@ -7,75 +7,60 @@ #define DEBUG_TYPE "euler-forward" -namespace mlir::bmodelica -{ +namespace mlir::bmodelica { #define GEN_PASS_DEF_EULERFORWARDPASS #include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc" -} +} // namespace mlir::bmodelica using namespace ::mlir::bmodelica; -namespace -{ - class EulerForwardPass - : public mlir::bmodelica::impl::EulerForwardPassBase - { - public: - using EulerForwardPassBase::EulerForwardPassBase; - - void runOnOperation() override; - - private: - mlir::LogicalResult processModelOp(ModelOp modelOp); - - mlir::LogicalResult solveMainModel( - mlir::IRRewriter& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - llvm::ArrayRef variables, - llvm::ArrayRef SCCs); - - mlir::LogicalResult createUpdateNonStateVariablesFunction( - mlir::IRRewriter& rewriter, - mlir::ModuleOp moduleOp, - ModelOp modelOp, - llvm::ArrayRef SCCs); - - mlir::LogicalResult createUpdateStateVariablesFunction( - mlir::OpBuilder& builder, - mlir::SymbolTableCollection& symbolTableCollection, - mlir::ModuleOp moduleOp, - ModelOp modelOp, - llvm::ArrayRef variableOps); - - mlir::LogicalResult createRangedStateVariableUpdateBlocks( - mlir::OpBuilder& builder, - mlir::SymbolTableCollection& symbolTableCollection, - mlir::ModuleOp moduleOp, - DynamicOp dynamicOp, - VariableOp stateVariable, - VariableOp derivativeVariable, - GlobalVariableOp timeStepVariable); - - mlir::LogicalResult createMonolithicStateVariableUpdateBlock( - mlir::OpBuilder& builder, - mlir::SymbolTableCollection& symbolTableCollection, - mlir::ModuleOp moduleOp, - DynamicOp dynamicOp, - VariableOp stateVariable, - VariableOp derivativeVariable, - GlobalVariableOp timeStepVariable); - - mlir::LogicalResult cleanModelOp(ModelOp modelOp); - }; -} - -void EulerForwardPass::runOnOperation() -{ - mlir::ModuleOp moduleOp = getOperation(); +namespace { +class EulerForwardPass + : public mlir::bmodelica::impl::EulerForwardPassBase { +public: + using EulerForwardPassBase::EulerForwardPassBase; + + void runOnOperation() override; + +private: + mlir::LogicalResult processModelOp(ModelOp modelOp); + + mlir::LogicalResult + solveMainModel(mlir::IRRewriter &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, + ModelOp modelOp, llvm::ArrayRef variables, + llvm::ArrayRef SCCs); + + mlir::LogicalResult createUpdateNonStateVariablesFunction( + mlir::IRRewriter &rewriter, mlir::ModuleOp moduleOp, ModelOp modelOp, + llvm::ArrayRef SCCs); + + mlir::LogicalResult createUpdateStateVariablesFunction( + mlir::OpBuilder &builder, + mlir::SymbolTableCollection &symbolTableCollection, + mlir::ModuleOp moduleOp, ModelOp modelOp, + llvm::ArrayRef variableOps); + + mlir::LogicalResult createRangedStateVariableUpdateBlocks( + mlir::OpBuilder &builder, + mlir::SymbolTableCollection 
&symbolTableCollection, + mlir::ModuleOp moduleOp, DynamicOp dynamicOp, VariableOp stateVariable, + VariableOp derivativeVariable, GlobalVariableOp timeStepVariable); + + mlir::LogicalResult createMonolithicStateVariableUpdateBlock( + mlir::OpBuilder &builder, + mlir::SymbolTableCollection &symbolTableCollection, + mlir::ModuleOp moduleOp, DynamicOp dynamicOp, VariableOp stateVariable, + VariableOp derivativeVariable, GlobalVariableOp timeStepVariable); + + mlir::LogicalResult cleanModelOp(ModelOp modelOp); +}; +} // namespace + +void EulerForwardPass::runOnOperation() { llvm::SmallVector modelOps; - walkClasses(getOperation(), [&](mlir::Operation* op) { + walkClasses(getOperation(), [&](mlir::Operation *op) { if (auto modelOp = mlir::dyn_cast(op)) { modelOps.push_back(modelOp); } @@ -92,8 +77,7 @@ void EulerForwardPass::runOnOperation() } } -mlir::LogicalResult EulerForwardPass::processModelOp(ModelOp modelOp) -{ +mlir::LogicalResult EulerForwardPass::processModelOp(ModelOp modelOp) { mlir::IRRewriter rewriter(&getContext()); mlir::SymbolTableCollection symbolTableCollection; @@ -104,8 +88,8 @@ mlir::LogicalResult EulerForwardPass::processModelOp(ModelOp modelOp) modelOp.collectVariables(variables); // Solve the 'main' model. - if (mlir::failed(solveMainModel( - rewriter, symbolTableCollection, modelOp, variables, mainSCCs))) { + if (mlir::failed(solveMainModel(rewriter, symbolTableCollection, modelOp, + variables, mainSCCs))) { return mlir::failure(); } @@ -113,16 +97,13 @@ mlir::LogicalResult EulerForwardPass::processModelOp(ModelOp modelOp) } mlir::LogicalResult EulerForwardPass::solveMainModel( - mlir::IRRewriter& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - llvm::ArrayRef variables, - llvm::ArrayRef SCCs) -{ + mlir::IRRewriter &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, ModelOp modelOp, + llvm::ArrayRef variables, llvm::ArrayRef SCCs) { auto moduleOp = modelOp->getParentOfType(); - if (mlir::failed(createUpdateNonStateVariablesFunction( - rewriter, moduleOp, modelOp, SCCs))) { + if (mlir::failed(createUpdateNonStateVariablesFunction(rewriter, moduleOp, + modelOp, SCCs))) { return mlir::failure(); } @@ -135,11 +116,8 @@ mlir::LogicalResult EulerForwardPass::solveMainModel( } mlir::LogicalResult EulerForwardPass::createUpdateNonStateVariablesFunction( - mlir::IRRewriter& rewriter, - mlir::ModuleOp moduleOp, - ModelOp modelOp, - llvm::ArrayRef SCCs) -{ + mlir::IRRewriter &rewriter, mlir::ModuleOp moduleOp, ModelOp modelOp, + llvm::ArrayRef SCCs) { mlir::OpBuilder::InsertionGuard guard(rewriter); // Create the function running the schedule. 
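// For reference: the two functions this pass emits, "updateNonStateVariables"
// and "updateStateVariables", together advance the model by one explicit
// Euler step per driver iteration. A minimal standalone sketch of the calling
// pattern, with the model equations and the driver loop invented purely for
// illustration (the real driver lives in the MARCO runtime, not in this
// patch):
//
//   #include <cstdio>
//
//   struct ModelData {
//     double x = 1.0;  // state variable
//     double dx = 0.0; // der(x)
//     double y = 0.0;  // algebraic (non-state) variable
//   };
//
//   // Counterpart of "updateNonStateVariables": evaluate the algebraic
//   // equations and the derivatives at the current point.
//   void updateNonStateVariables(ModelData &m) {
//     m.y = 2.0 * m.x; // y = 2 * x
//     m.dx = -m.y;     // der(x) = -y
//   }
//
//   // Counterpart of "updateStateVariables": x := x + h * der(x).
//   void updateStateVariables(ModelData &m, double h) { m.x += h * m.dx; }
//
//   int main() {
//     ModelData m;
//     for (int step = 0; step < 10; ++step) {
//       updateNonStateVariables(m);
//       updateStateVariables(m, 0.1);
//     }
//     std::printf("x(1) ~= %f\n", m.x); // integrates dx/dt = -2x, x(0) = 1
//   }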
@@ -149,7 +127,7 @@ mlir::LogicalResult EulerForwardPass::createUpdateNonStateVariablesFunction( modelOp.getLoc(), "updateNonStateVariables", rewriter.getFunctionType(std::nullopt, std::nullopt)); - mlir::Block* entryBlock = functionOp.addEntryBlock(); + mlir::Block *entryBlock = functionOp.addEntryBlock(); rewriter.setInsertionPointToStart(entryBlock); if (!SCCs.empty()) { @@ -179,12 +157,9 @@ mlir::LogicalResult EulerForwardPass::createUpdateNonStateVariablesFunction( } mlir::LogicalResult EulerForwardPass::createUpdateStateVariablesFunction( - mlir::OpBuilder& builder, - mlir::SymbolTableCollection& symbolTableCollection, - mlir::ModuleOp moduleOp, - ModelOp modelOp, - llvm::ArrayRef variableOps) -{ + mlir::OpBuilder &builder, + mlir::SymbolTableCollection &symbolTableCollection, mlir::ModuleOp moduleOp, + ModelOp modelOp, llvm::ArrayRef variableOps) { mlir::OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToEnd(moduleOp.getBody()); @@ -200,21 +175,20 @@ mlir::LogicalResult EulerForwardPass::createUpdateStateVariablesFunction( modelOp.getLoc(), "updateStateVariables", builder.getFunctionType(builder.getF64Type(), std::nullopt)); - mlir::Block* functionBody = functionOp.addEntryBlock(); + mlir::Block *functionBody = functionOp.addEntryBlock(); builder.setInsertionPointToStart(functionBody); mlir::Value timeStepArg = functionOp.getArgument(0); // Set the time step in the global variable. - auto timeStepArray = builder.create( - timeStepArg.getLoc(), timeStepVariable); + auto timeStepArray = builder.create(timeStepArg.getLoc(), + timeStepVariable); - builder.create( - timeStepArg.getLoc(), timeStepArg, timeStepArray, std::nullopt); + builder.create(timeStepArg.getLoc(), timeStepArg, timeStepArray, + std::nullopt); // Compute the list of state and derivative variables. - const DerivativesMap& derivativesMap = - modelOp.getProperties().derivativesMap; + const DerivativesMap &derivativesMap = modelOp.getProperties().derivativesMap; // The two lists are kept in sync. llvm::SmallVector stateVariables; @@ -222,10 +196,10 @@ mlir::LogicalResult EulerForwardPass::createUpdateStateVariablesFunction( for (VariableOp variableOp : variableOps) { if (auto derivativeName = derivativesMap.getDerivative( - mlir::FlatSymbolRefAttr::get(variableOp.getSymNameAttr()))) { + mlir::FlatSymbolRefAttr::get(variableOp.getSymNameAttr()))) { auto derivativeVariableOp = - symbolTableCollection.lookupSymbolIn( - modelOp, *derivativeName); + symbolTableCollection.lookupSymbolIn(modelOp, + *derivativeName); stateVariables.push_back(variableOp); derivativeVariables.push_back(derivativeVariableOp); @@ -238,8 +212,8 @@ mlir::LogicalResult EulerForwardPass::createUpdateStateVariablesFunction( // Create the schedule. 
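// Note on the pairing computed above: `stateVariables` and
// `derivativeVariables` stay index-aligned, so position i of one vector
// matches position i of the other. The same idea with plain containers and
// invented variable names:
//
//   #include <map>
//   #include <string>
//   #include <vector>
//
//   std::map<std::string, std::string> derivativesMap{{"x", "der_x"}};
//   std::vector<std::string> states, derivatives;
//   for (const std::string &var : {std::string("x"), std::string("y")}) {
//     auto it = derivativesMap.find(var);
//     if (it != derivativesMap.end()) {
//       states.push_back(it->first);       // "x": a state variable
//       derivatives.push_back(it->second); // "der_x", kept in sync
//     }
//   }
//
// Variables with no entry in the map, like "y" here, are algebraic and are
// handled by the non-state update function instead.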
builder.setInsertionPointToEnd(modelOp.getBody()); - auto scheduleOp = builder.create( - modelOp.getLoc(), "schedule_state_variables"); + auto scheduleOp = builder.create(modelOp.getLoc(), + "schedule_state_variables"); symbolTableCollection.getSymbolTable(modelOp).insert(scheduleOp); @@ -250,7 +224,7 @@ mlir::LogicalResult EulerForwardPass::createUpdateStateVariablesFunction( builder.createBlock(&dynamicOp.getBodyRegion()); builder.setInsertionPointToStart(dynamicOp.getBody()); - for (const auto& [stateVariable, derivativeVariable] : + for (const auto &[stateVariable, derivativeVariable] : llvm::zip(stateVariables, derivativeVariables)) { if (rangedStateUpdateFunctions) { if (mlir::failed(createRangedStateVariableUpdateBlocks( @@ -279,13 +253,10 @@ mlir::LogicalResult EulerForwardPass::createUpdateStateVariablesFunction( } static void getStateUpdateBlockVarWriteReadInfo( - mlir::OpBuilder& builder, - VariableOp stateVariable, - VariableOp derivativeVariable, - MultidimensionalRangeAttr ranges, - llvm::SmallVectorImpl& writtenVariables, - llvm::SmallVectorImpl& readVariables) -{ + mlir::OpBuilder &builder, VariableOp stateVariable, + VariableOp derivativeVariable, MultidimensionalRangeAttr ranges, + llvm::SmallVectorImpl &writtenVariables, + llvm::SmallVectorImpl &readVariables) { IndexSet indices; if (ranges) { @@ -293,46 +264,36 @@ static void getStateUpdateBlockVarWriteReadInfo( } writtenVariables.emplace_back( - mlir::SymbolRefAttr::get(stateVariable.getSymNameAttr()), - indices); + mlir::SymbolRefAttr::get(stateVariable.getSymNameAttr()), indices); readVariables.emplace_back( - mlir::SymbolRefAttr::get(derivativeVariable.getSymNameAttr()), - indices); + mlir::SymbolRefAttr::get(derivativeVariable.getSymNameAttr()), indices); } -static void createStateUpdateFunctionCall( - mlir::OpBuilder& builder, - VariableOp stateVariable, - VariableOp derivativeVariable, - MultidimensionalRangeAttr ranges, - EquationFunctionOp equationFuncOp) -{ - auto blockOp = builder.create( - stateVariable.getLoc(), true); - - getStateUpdateBlockVarWriteReadInfo( - builder, stateVariable, derivativeVariable, ranges, - blockOp.getProperties().writtenVariables, - blockOp.getProperties().readVariables); +static void createStateUpdateFunctionCall(mlir::OpBuilder &builder, + VariableOp stateVariable, + VariableOp derivativeVariable, + MultidimensionalRangeAttr ranges, + EquationFunctionOp equationFuncOp) { + auto blockOp = builder.create(stateVariable.getLoc(), true); + + getStateUpdateBlockVarWriteReadInfo(builder, stateVariable, + derivativeVariable, ranges, + blockOp.getProperties().writtenVariables, + blockOp.getProperties().readVariables); builder.createBlock(&blockOp.getBodyRegion()); builder.setInsertionPointToStart(blockOp.getBody()); - builder.create( - equationFuncOp.getLoc(), equationFuncOp.getSymName(), - ranges, true); + builder.create(equationFuncOp.getLoc(), + equationFuncOp.getSymName(), ranges, true); } mlir::LogicalResult EulerForwardPass::createRangedStateVariableUpdateBlocks( - mlir::OpBuilder& builder, - mlir::SymbolTableCollection& symbolTableCollection, - mlir::ModuleOp moduleOp, - DynamicOp dynamicOp, - VariableOp stateVariable, - VariableOp derivativeVariable, - GlobalVariableOp timeStepVariable) -{ + mlir::OpBuilder &builder, + mlir::SymbolTableCollection &symbolTableCollection, mlir::ModuleOp moduleOp, + DynamicOp dynamicOp, VariableOp stateVariable, + VariableOp derivativeVariable, GlobalVariableOp timeStepVariable) { mlir::OpBuilder::InsertionGuard guard(builder); // Create the equation 
function. @@ -342,30 +303,26 @@ mlir::LogicalResult EulerForwardPass::createRangedStateVariableUpdateBlocks( auto equationFuncOp = builder.create( stateVariable.getLoc(), - "euler_state_update_" + stateVariable.getSymName().str(), - variableRank); + "euler_state_update_" + stateVariable.getSymName().str(), variableRank); symbolTableCollection.getSymbolTable(moduleOp).insert(equationFuncOp); - mlir::Block* equationFuncBody = equationFuncOp.addEntryBlock(); + mlir::Block *equationFuncBody = equationFuncOp.addEntryBlock(); builder.setInsertionPointToStart(equationFuncBody); mlir::Value timeStep = builder.create( stateVariable.getLoc(), timeStepVariable); - timeStep = builder.create( - timeStep.getLoc(), timeStep, std::nullopt); + timeStep = builder.create(timeStep.getLoc(), timeStep, std::nullopt); - auto getNewScalarStateFn = - [&](mlir::OpBuilder& nestedBuilder, - mlir::Location nestedLoc, - mlir::Value scalarState, - mlir::Value scalarDerivative) -> mlir::Value { + auto getNewScalarStateFn = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::Value scalarState, + mlir::Value scalarDerivative) -> mlir::Value { mlir::Value result = nestedBuilder.create( - nestedLoc, scalarDerivative.getType(), scalarDerivative, - timeStep); + nestedLoc, scalarDerivative.getType(), scalarDerivative, timeStep); - result = nestedBuilder.create( - nestedLoc, scalarState.getType(), scalarState, result); + result = nestedBuilder.create(nestedLoc, scalarState.getType(), + scalarState, result); return result; }; @@ -381,13 +338,12 @@ mlir::LogicalResult EulerForwardPass::createRangedStateVariableUpdateBlocks( mlir::Value updatedState = getNewScalarStateFn( builder, equationFuncOp.getLoc(), state, derivative); - builder.create( - equationFuncOp.getLoc(), stateVariable, updatedState); + builder.create(equationFuncOp.getLoc(), + stateVariable, updatedState); } else { // Array variable. mlir::Value state = builder.create( - equationFuncOp.getLoc(), - stateVariable.getVariableType().toArrayType(), + equationFuncOp.getLoc(), stateVariable.getVariableType().toArrayType(), getSymbolRefFromRoot(stateVariable)); mlir::Value derivative = builder.create( @@ -406,19 +362,19 @@ mlir::LogicalResult EulerForwardPass::createRangedStateVariableUpdateBlocks( mlir::affine::buildAffineLoopNest( builder, equationFuncOp.getLoc(), lowerBounds, upperBounds, steps, - [&](mlir::OpBuilder& nestedBuilder, mlir::Location nestedLoc, + [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange indices) { - mlir::Value scalarState = nestedBuilder.create( - nestedLoc, state, indices); + mlir::Value scalarState = + nestedBuilder.create(nestedLoc, state, indices); - mlir::Value scalarDerivative = nestedBuilder.create( - nestedLoc, derivative, indices); + mlir::Value scalarDerivative = + nestedBuilder.create(nestedLoc, derivative, indices); mlir::Value updatedScalarState = getNewScalarStateFn( nestedBuilder, nestedLoc, scalarState, scalarDerivative); - nestedBuilder.create( - nestedLoc, updatedScalarState, state, indices); + nestedBuilder.create(nestedLoc, updatedScalarState, state, + indices); }); } @@ -428,21 +384,18 @@ mlir::LogicalResult EulerForwardPass::createRangedStateVariableUpdateBlocks( // Create the schedule blocks and the calls to the equation function. 
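// The equation function generated above applies, to every scalar element of
// the state, the explicit Euler update x := x + h * der(x). The array case
// implemented by the affine loop nest corresponds roughly to this sketch,
// where x, dx, n and h stand for the state array, the derivative array, the
// iteration space and the "timeStep" global variable:
//
//   void eulerStateUpdate(double *x, const double *dx, size_t n, double h) {
//     for (size_t i = 0; i < n; ++i) {
//       x[i] = x[i] + h * dx[i];
//     }
//   }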
builder.setInsertionPointToEnd(dynamicOp.getBody()); - IndexSet indices = - stateVariable.getIndices().getCanonicalRepresentation(); + IndexSet indices = stateVariable.getIndices().getCanonicalRepresentation(); if (indices.empty()) { - createStateUpdateFunctionCall( - builder, stateVariable, derivativeVariable, nullptr, - equationFuncOp); + createStateUpdateFunctionCall(builder, stateVariable, derivativeVariable, + nullptr, equationFuncOp); } else { - for (const MultidimensionalRange& range : llvm::make_range( - indices.rangesBegin(), indices.rangesEnd())) { + for (const MultidimensionalRange &range : + llvm::make_range(indices.rangesBegin(), indices.rangesEnd())) { createStateUpdateFunctionCall( builder, stateVariable, derivativeVariable, - MultidimensionalRangeAttr::get(&getContext(), range), - equationFuncOp); + MultidimensionalRangeAttr::get(&getContext(), range), equationFuncOp); } } @@ -450,14 +403,10 @@ mlir::LogicalResult EulerForwardPass::createRangedStateVariableUpdateBlocks( } mlir::LogicalResult EulerForwardPass::createMonolithicStateVariableUpdateBlock( - mlir::OpBuilder& builder, - mlir::SymbolTableCollection& symbolTableCollection, - mlir::ModuleOp moduleOp, - DynamicOp dynamicOp, - VariableOp stateVariable, - VariableOp derivativeVariable, - GlobalVariableOp timeStepVariable) -{ + mlir::OpBuilder &builder, + mlir::SymbolTableCollection &symbolTableCollection, mlir::ModuleOp moduleOp, + DynamicOp dynamicOp, VariableOp stateVariable, + VariableOp derivativeVariable, GlobalVariableOp timeStepVariable) { mlir::OpBuilder::InsertionGuard guard(builder); // Create the equation function. @@ -471,29 +420,26 @@ mlir::LogicalResult EulerForwardPass::createMonolithicStateVariableUpdateBlock( symbolTableCollection.getSymbolTable(moduleOp).insert(funcOp); - mlir::Block* funcBody = funcOp.addEntryBlock(); + mlir::Block *funcBody = funcOp.addEntryBlock(); builder.setInsertionPointToStart(funcBody); mlir::Value timeStep = builder.create( stateVariable.getLoc(), timeStepVariable); - timeStep = builder.create( - timeStep.getLoc(), timeStep, std::nullopt); + timeStep = builder.create(timeStep.getLoc(), timeStep, std::nullopt); - mlir::Value state = builder.create( - funcOp.getLoc(), stateVariable); + mlir::Value state = + builder.create(funcOp.getLoc(), stateVariable); mlir::Value derivative = builder.create( funcOp.getLoc(), derivativeVariable); - mlir::Value mulOp = builder.create( - funcOp.getLoc(), timeStep, derivative); + mlir::Value mulOp = + builder.create(funcOp.getLoc(), timeStep, derivative); - mlir::Value addOp = builder.create( - funcOp.getLoc(), state, mulOp); + mlir::Value addOp = builder.create(funcOp.getLoc(), state, mulOp); - builder.create( - funcOp.getLoc(), stateVariable, addOp); + builder.create(funcOp.getLoc(), stateVariable, addOp); builder.setInsertionPointToEnd(funcBody); builder.create(funcOp.getLoc()); @@ -503,8 +449,7 @@ mlir::LogicalResult EulerForwardPass::createMonolithicStateVariableUpdateBlock( int64_t variableRank = variableType.getRank(); - auto blockOp = builder.create( - stateVariable.getLoc(), true); + auto blockOp = builder.create(stateVariable.getLoc(), true); if (variableRank == 0) { getStateUpdateBlockVarWriteReadInfo( @@ -522,8 +467,8 @@ mlir::LogicalResult EulerForwardPass::createMonolithicStateVariableUpdateBlock( getStateUpdateBlockVarWriteReadInfo( builder, stateVariable, derivativeVariable, - MultidimensionalRangeAttr::get( - builder.getContext(), MultidimensionalRange(ranges)), + MultidimensionalRangeAttr::get(builder.getContext(), + 
MultidimensionalRange(ranges)), blockOp.getProperties().writtenVariables, blockOp.getProperties().readVariables); } @@ -536,23 +481,19 @@ mlir::LogicalResult EulerForwardPass::createMonolithicStateVariableUpdateBlock( return mlir::success(); } -mlir::LogicalResult EulerForwardPass::cleanModelOp(ModelOp modelOp) -{ +mlir::LogicalResult EulerForwardPass::cleanModelOp(ModelOp modelOp) { mlir::RewritePatternSet patterns(&getContext()); ModelOp::getCleaningPatterns(patterns, &getContext()); return mlir::applyPatternsAndFoldGreedily(modelOp, std::move(patterns)); } -namespace mlir::bmodelica -{ - std::unique_ptr createEulerForwardPass() - { - return std::make_unique(); - } +namespace mlir::bmodelica { +std::unique_ptr createEulerForwardPass() { + return std::make_unique(); +} - std::unique_ptr createEulerForwardPass( - const EulerForwardPassOptions& options) - { - return std::make_unique(options); - } +std::unique_ptr +createEulerForwardPass(const EulerForwardPassOptions &options) { + return std::make_unique(options); } +} // namespace mlir::bmodelica diff --git a/lib/Dialect/BaseModelica/Transforms/FunctionInlining.cpp b/lib/Dialect/BaseModelica/Transforms/FunctionInlining.cpp index 292fcde7d..123d97c24 100644 --- a/lib/Dialect/BaseModelica/Transforms/FunctionInlining.cpp +++ b/lib/Dialect/BaseModelica/Transforms/FunctionInlining.cpp @@ -7,262 +7,212 @@ #include "llvm/ADT/SCCIterator.h" #include -namespace mlir::bmodelica -{ +namespace mlir::bmodelica { #define GEN_PASS_DEF_FUNCTIONINLININGPASS #include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc" -} +} // namespace mlir::bmodelica using namespace ::mlir::bmodelica; -namespace -{ - class CallGraph - { - public: - /// A node of the graph. - struct Node - { - Node() : graph(nullptr), op(nullptr) - { - } +namespace { +class CallGraph { +public: + /// A node of the graph. + struct Node { + Node() : graph(nullptr), op(nullptr) {} - Node(const CallGraph* graph, mlir::Operation* op) - : graph(graph), op(op) - { - } + Node(const CallGraph *graph, mlir::Operation *op) : graph(graph), op(op) {} - bool operator==(const Node& other) const - { - return graph == other.graph && op == other.op; - } + bool operator==(const Node &other) const { + return graph == other.graph && op == other.op; + } - bool operator!=(const Node& other) const - { - return graph != other.graph || op != other.op; - } + bool operator!=(const Node &other) const { + return graph != other.graph || op != other.op; + } - bool operator<(const Node& other) const - { - if (op == nullptr) { - return true; - } + bool operator<(const Node &other) const { + if (op == nullptr) { + return true; + } - if (other.op == nullptr) { - return false; - } + if (other.op == nullptr) { + return false; + } - return op < other.op; - } + return op < other.op; + } - const CallGraph* graph; - mlir::Operation* op; - }; + const CallGraph *graph; + mlir::Operation *op; + }; - CallGraph() - { - // Entry node, which is connected to every other node. - nodes.emplace_back(this, nullptr); + CallGraph() { + // Entry node, which is connected to every other node. + nodes.emplace_back(this, nullptr); - // Ensure that the set of children for the entry node exists, even in - // case of no other nodes. - arcs[getEntryNode().op] = {}; - } + // Ensure that the set of children for the entry node exists, even in + // case of no other nodes. 
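// The synthetic entry node is what lets LLVM's graph iterators (such as
// llvm::scc_begin, used below by getInlinableFunctions) reach every function
// from a single root, even when the call graph is disconnected. Sketched with
// plain containers, with Op* standing in for mlir::Operation*:
//
//   std::map<Op *, std::set<Node>> arcs; // adjacency sets
//   arcs[nullptr] = {};                  // entry node, keyed by nullptr
//   // addNode(f): arcs[f] = {}; arcs[nullptr].insert(f);
//
// Functions that later end up in a non-trivial SCC, i.e. directly or mutually
// recursive ones, are filtered out by getInlinableFunctions.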
+    arcs[getEntryNode().op] = {};
+  }

-      ~CallGraph() = default;
+  ~CallGraph() = default;

-      void addNode(FunctionOp functionOp)
-      {
-        assert(functionOp != nullptr);
-        Node node(this, functionOp.getOperation());
-        nodes.push_back(node);
-        nodesByOp[functionOp.getOperation()] = node;
+  void addNode(FunctionOp functionOp) {
+    assert(functionOp != nullptr);
+    Node node(this, functionOp.getOperation());
+    nodes.push_back(node);
+    nodesByOp[functionOp.getOperation()] = node;

-        // Ensure that the set of children for the node exists.
-        arcs[node.op] = {};
+    // Ensure that the set of children for the node exists.
+    arcs[node.op] = {};

-        // Connect the entry node.
-        arcs[getEntryNode().op].insert(node);
-      }
+    // Connect the entry node.
+    arcs[getEntryNode().op].insert(node);
+  }

-      bool hasNode(FunctionOp functionOp) const
-      {
-        assert(functionOp != nullptr);
-        return nodesByOp.find(functionOp.getOperation()) != nodesByOp.end();
-      }
+  bool hasNode(FunctionOp functionOp) const {
+    assert(functionOp != nullptr);
+    return nodesByOp.find(functionOp.getOperation()) != nodesByOp.end();
+  }

-      void addEdge(FunctionOp caller, FunctionOp callee)
-      {
-        assert(caller != nullptr);
-        assert(callee != nullptr);
-        assert(arcs.find(caller.getOperation()) != arcs.end());
-        assert(nodesByOp.find(callee.getOperation()) != nodesByOp.end());
-        arcs[caller.getOperation()].insert(nodesByOp[callee.getOperation()]);
-      }
+  void addEdge(FunctionOp caller, FunctionOp callee) {
+    assert(caller != nullptr);
+    assert(callee != nullptr);
+    assert(arcs.find(caller.getOperation()) != arcs.end());
+    assert(nodesByOp.find(callee.getOperation()) != nodesByOp.end());
+    arcs[caller.getOperation()].insert(nodesByOp[callee.getOperation()]);
+  }

-      /// Get the number of nodes.
-      /// Note that the nodes consist of the inserted operations together with
-      /// the entry node.
-      size_t getNumOfNodes() const
-      {
-        return nodes.size();
-      }
+  /// Get the number of nodes.
+  /// Note that the nodes consist of the inserted operations together with
+  /// the entry node.
+  size_t getNumOfNodes() const { return nodes.size(); }

-      /// Get the entry node of the graph.
-      const Node& getEntryNode() const
-      {
-        return nodes[0];
-      }
+  /// Get the entry node of the graph.
+ const Node &getEntryNode() const { return nodes[0]; } - /// @name Iterators for the children of a node - /// { + /// @name Iterators for the children of a node + /// { - std::set::const_iterator childrenBegin(Node node) const - { - auto it = arcs.find(node.op); - assert(it != arcs.end()); - const auto& children = it->second; - return children.begin(); - } + std::set::const_iterator childrenBegin(Node node) const { + auto it = arcs.find(node.op); + assert(it != arcs.end()); + const auto &children = it->second; + return children.begin(); + } - std::set::const_iterator childrenEnd(Node node) const - { - auto it = arcs.find(node.op); - assert(it != arcs.end()); - const auto& children = it->second; - return children.end(); - } + std::set::const_iterator childrenEnd(Node node) const { + auto it = arcs.find(node.op); + assert(it != arcs.end()); + const auto &children = it->second; + return children.end(); + } - /// } - /// @name Iterators for the nodes - /// { + /// } + /// @name Iterators for the nodes + /// { - llvm::SmallVector::const_iterator nodesBegin() const - { - return nodes.begin(); - } + llvm::SmallVector::const_iterator nodesBegin() const { + return nodes.begin(); + } - llvm::SmallVector::const_iterator nodesEnd() const - { - return nodes.end(); - } + llvm::SmallVector::const_iterator nodesEnd() const { + return nodes.end(); + } - /// } + /// } - llvm::DenseSet getInlinableFunctions() const; + llvm::DenseSet getInlinableFunctions() const; - private: - llvm::DenseMap> arcs; - llvm::SmallVector nodes; - llvm::DenseMap nodesByOp; - }; -} +private: + llvm::DenseMap> arcs; + llvm::SmallVector nodes; + llvm::DenseMap nodesByOp; +}; +} // namespace -namespace llvm -{ - template<> - struct DenseMapInfo<::CallGraph::Node> - { - static inline ::CallGraph::Node getEmptyKey() - { - return {nullptr, nullptr}; - } +namespace llvm { +template <> +struct DenseMapInfo<::CallGraph::Node> { + static inline ::CallGraph::Node getEmptyKey() { return {nullptr, nullptr}; } - static inline ::CallGraph::Node getTombstoneKey() - { - return {nullptr, nullptr}; - } + static inline ::CallGraph::Node getTombstoneKey() { + return {nullptr, nullptr}; + } - static unsigned getHashValue(const ::CallGraph::Node& val) - { - return std::hash{}(val.op); - } + static unsigned getHashValue(const ::CallGraph::Node &val) { + return std::hash{}(val.op); + } - static bool isEqual( - const ::CallGraph::Node& lhs, - const ::CallGraph::Node& rhs) - { - return lhs.graph == rhs.graph && lhs.op == rhs.op; - } - }; + static bool isEqual(const ::CallGraph::Node &lhs, + const ::CallGraph::Node &rhs) { + return lhs.graph == rhs.graph && lhs.op == rhs.op; + } +}; - template<> - struct GraphTraits - { - using GraphType = const ::CallGraph*; - using NodeRef = ::CallGraph::Node; +template <> +struct GraphTraits { + using GraphType = const ::CallGraph *; + using NodeRef = ::CallGraph::Node; - using ChildIteratorType = - std::set<::CallGraph::Node>::const_iterator; + using ChildIteratorType = std::set<::CallGraph::Node>::const_iterator; - static NodeRef getEntryNode(const GraphType& graph) - { - return graph->getEntryNode(); - } + static NodeRef getEntryNode(const GraphType &graph) { + return graph->getEntryNode(); + } - static ChildIteratorType child_begin(NodeRef node) - { - return node.graph->childrenBegin(node); - } + static ChildIteratorType child_begin(NodeRef node) { + return node.graph->childrenBegin(node); + } - static ChildIteratorType child_end(NodeRef node) - { - return node.graph->childrenEnd(node); - } + static ChildIteratorType 
child_end(NodeRef node) { + return node.graph->childrenEnd(node); + } - using nodes_iterator = llvm::SmallVector::const_iterator; + using nodes_iterator = llvm::SmallVector::const_iterator; - static nodes_iterator nodes_begin(GraphType* graph) - { - return (*graph)->nodesBegin(); - } + static nodes_iterator nodes_begin(GraphType *graph) { + return (*graph)->nodesBegin(); + } - static nodes_iterator nodes_end(GraphType* graph) - { - return (*graph)->nodesEnd(); - } + static nodes_iterator nodes_end(GraphType *graph) { + return (*graph)->nodesEnd(); + } - // There is no need for a dedicated class for the arcs. - using EdgeRef = ::CallGraph::Node; + // There is no need for a dedicated class for the arcs. + using EdgeRef = ::CallGraph::Node; - using ChildEdgeIteratorType = - std::set<::CallGraph::Node>::const_iterator; + using ChildEdgeIteratorType = std::set<::CallGraph::Node>::const_iterator; - static ChildEdgeIteratorType child_edge_begin(NodeRef node) - { - return node.graph->childrenBegin(node); - } + static ChildEdgeIteratorType child_edge_begin(NodeRef node) { + return node.graph->childrenBegin(node); + } - static ChildEdgeIteratorType child_edge_end(NodeRef node) - { - return node.graph->childrenEnd(node); - } + static ChildEdgeIteratorType child_edge_end(NodeRef node) { + return node.graph->childrenEnd(node); + } - static NodeRef edge_dest(EdgeRef edge) - { - return edge; - } + static NodeRef edge_dest(EdgeRef edge) { return edge; } - static unsigned int size(GraphType* graph) - { - return (*graph)->getNumOfNodes(); - } - }; -} + static unsigned int size(GraphType *graph) { + return (*graph)->getNumOfNodes(); + } +}; +} // namespace llvm -llvm::DenseSet CallGraph::getInlinableFunctions() const -{ +llvm::DenseSet CallGraph::getInlinableFunctions() const { llvm::DenseSet result; auto beginIt = llvm::scc_begin(this); - auto endIt = llvm::scc_end(this); + auto endIt = llvm::scc_end(this); for (auto it = beginIt; it != endIt; ++it) { if (it.hasCycle()) { continue; } - for (const Node& node : *it) { + for (const Node &node : *it) { if (node != getEntryNode()) { result.insert(mlir::cast(node.op)); } @@ -272,8 +222,7 @@ llvm::DenseSet CallGraph::getInlinableFunctions() const return result; } -static bool canBeInlined(FunctionOp functionOp) -{ +static bool canBeInlined(FunctionOp functionOp) { // The function must be explicitly marked as inlinable. if (!functionOp.shouldBeInlined()) { return false; @@ -287,12 +236,12 @@ static bool canBeInlined(FunctionOp functionOp) // Check that operations inside the algorithm section have no regions with // side effects. 
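// The scan below is an explicit-worklist traversal: it is seeded with the
// operations nested inside the algorithm's regions, then keeps popping and
// descending until a side-effecting operation is found or the list drains.
// The shape of the pattern, with `Op`, `hasSideEffects` and the helper
// functions as illustrative placeholders:
//
//   std::vector<Op *> worklist = topLevelNestedOps(algorithm);
//   while (!worklist.empty()) {
//     Op *op = worklist.back();
//     worklist.pop_back();
//     if (hasSideEffects(op))
//       return false;                      // the function cannot be inlined
//     appendNestedRegionOps(op, worklist); // descend into nested regions
//   }
//   return true;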
- llvm::SmallVector nestedOps; + llvm::SmallVector nestedOps; for (AlgorithmOp algorithmOp : functionOp.getOps()) { - for (auto& nestedOp : algorithmOp.getOps()) { - for (auto& nestedRegion : nestedOp.getRegions()) { - for (auto& nestedRegionOp : nestedRegion.getOps()) { + for (auto &nestedOp : algorithmOp.getOps()) { + for (auto &nestedRegion : nestedOp.getRegions()) { + for (auto &nestedRegionOp : nestedRegion.getOps()) { nestedOps.push_back(&nestedRegionOp); } } @@ -300,14 +249,14 @@ static bool canBeInlined(FunctionOp functionOp) } while (!nestedOps.empty()) { - mlir::Operation* nestedOp = nestedOps.pop_back_val(); + mlir::Operation *nestedOp = nestedOps.pop_back_val(); if (mlir::isa(nestedOp)) { return false; } - for (auto& nestedRegion : nestedOp->getRegions()) { - for (auto& nestedRegionOp : nestedRegion.getOps()) { + for (auto &nestedRegion : nestedOp->getRegions()) { + for (auto &nestedRegionOp : nestedRegion.getOps()) { nestedOps.push_back(&nestedRegionOp); } } @@ -316,533 +265,460 @@ static bool canBeInlined(FunctionOp functionOp) return true; } -namespace -{ - class VariablesDependencyGraph - { - public: - /// A node of the graph. - struct Node - { - Node() : graph(nullptr), variable(nullptr) - { - } - - Node(const VariablesDependencyGraph* graph, mlir::Operation* variable) - : graph(graph), variable(variable) - { - } - - bool operator==(const Node& other) const - { - return graph == other.graph && variable == other.variable; - } +namespace { +class VariablesDependencyGraph { +public: + /// A node of the graph. + struct Node { + Node() : graph(nullptr), variable(nullptr) {} - bool operator!=(const Node& other) const - { - return graph != other.graph || variable != other.variable; - } + Node(const VariablesDependencyGraph *graph, mlir::Operation *variable) + : graph(graph), variable(variable) {} - bool operator<(const Node& other) const - { - if (variable == nullptr) { - return true; - } + bool operator==(const Node &other) const { + return graph == other.graph && variable == other.variable; + } - if (other.variable == nullptr) { - return false; - } + bool operator!=(const Node &other) const { + return graph != other.graph || variable != other.variable; + } - return variable < other.variable; - } + bool operator<(const Node &other) const { + if (variable == nullptr) { + return true; + } - const VariablesDependencyGraph* graph; - mlir::Operation* variable; - }; + if (other.variable == nullptr) { + return false; + } - VariablesDependencyGraph() - { - // Entry node, which is connected to every other node. - nodes.emplace_back(this, nullptr); + return variable < other.variable; + } - // Ensure that the set of children for the entry node exists, even in - // case of no other nodes. - arcs[getEntryNode().variable] = {}; - } + const VariablesDependencyGraph *graph; + mlir::Operation *variable; + }; - virtual ~VariablesDependencyGraph() = default; + VariablesDependencyGraph() { + // Entry node, which is connected to every other node. + nodes.emplace_back(this, nullptr); - /// Add a group of variables to the graph and optionally enforce their - /// relative order to be preserved. - void addVariables(llvm::ArrayRef variables) - { - for (VariableOp variable : variables) { - assert(variable != nullptr); + // Ensure that the set of children for the entry node exists, even in + // case of no other nodes. 
+ arcs[getEntryNode().variable] = {}; + } - Node node(this, variable.getOperation()); - nodes.push_back(node); - nodesByName[variable.getSymName()] = node; - arcs[node.variable] = {}; - } - } + virtual ~VariablesDependencyGraph() = default; - /// Discover the dependencies of the variables that have been added to - /// the graph. - void discoverDependencies() - { - for (const Node& node : nodes) { - if (node == getEntryNode()) { - continue; - } - - // Connect the entry node to the current one. - arcs[getEntryNode().variable].insert(node); - - // Connect the current node to the other nodes to which it depends. - mlir::Operation* variable = node.variable; - auto variableOp = mlir::cast(variable); - - for (llvm::StringRef dependency : getDependencies(variableOp)) { - auto& children = arcs[nodesByName[dependency].variable]; - children.insert(node); - } - } - } + /// Add a group of variables to the graph and optionally enforce their + /// relative order to be preserved. + void addVariables(llvm::ArrayRef variables) { + for (VariableOp variable : variables) { + assert(variable != nullptr); - /// Get the number of nodes. - /// Note that nodes consists in the inserted variables together with the - /// entry node. - size_t getNumOfNodes() const - { - return nodes.size(); - } + Node node(this, variable.getOperation()); + nodes.push_back(node); + nodesByName[variable.getSymName()] = node; + arcs[node.variable] = {}; + } + } - /// Get the entry node of the graph. - const Node& getEntryNode() const - { - return nodes[0]; + /// Discover the dependencies of the variables that have been added to + /// the graph. + void discoverDependencies() { + for (const Node &node : nodes) { + if (node == getEntryNode()) { + continue; } - /// @name Iterators for the children of a node - /// { + // Connect the entry node to the current one. + arcs[getEntryNode().variable].insert(node); - std::set::const_iterator childrenBegin(Node node) const - { - auto it = arcs.find(node.variable); - assert(it != arcs.end()); - const auto& children = it->second; - return children.begin(); - } + // Connect the current node to the other nodes to which it depends. + mlir::Operation *variable = node.variable; + auto variableOp = mlir::cast(variable); - std::set::const_iterator childrenEnd(Node node) const - { - auto it = arcs.find(node.variable); - assert(it != arcs.end()); - const auto& children = it->second; - return children.end(); + for (llvm::StringRef dependency : getDependencies(variableOp)) { + auto &children = arcs[nodesByName[dependency].variable]; + children.insert(node); } + } + } - /// } - /// @name Iterators for the nodes - /// { + /// Get the number of nodes. + /// Note that nodes consists in the inserted variables together with the + /// entry node. + size_t getNumOfNodes() const { return nodes.size(); } - llvm::SmallVector::const_iterator nodesBegin() const - { - return nodes.begin(); - } + /// Get the entry node of the graph. + const Node &getEntryNode() const { return nodes[0]; } - llvm::SmallVector::const_iterator nodesEnd() const - { - return nodes.end(); - } + /// @name Iterators for the children of a node + /// { - /// } + std::set::const_iterator childrenBegin(Node node) const { + auto it = arcs.find(node.variable); + assert(it != arcs.end()); + const auto &children = it->second; + return children.begin(); + } - /// Check if the graph contains cycles. 
- bool hasCycles() const; + std::set::const_iterator childrenEnd(Node node) const { + auto it = arcs.find(node.variable); + assert(it != arcs.end()); + const auto &children = it->second; + return children.end(); + } - /// Perform a post-order visit of the graph and get the ordered - /// variables. - llvm::SmallVector postOrder() const; + /// } + /// @name Iterators for the nodes + /// { - /// Perform a reverse post-order visit of the graph and get the ordered - /// variables. - llvm::SmallVector reversePostOrder() const; + llvm::SmallVector::const_iterator nodesBegin() const { + return nodes.begin(); + } - protected: - virtual std::set getDependencies( - VariableOp variable) = 0; + llvm::SmallVector::const_iterator nodesEnd() const { + return nodes.end(); + } - private: - llvm::DenseMap> arcs; - llvm::SmallVector nodes; - llvm::DenseMap nodesByName; - }; + /// } - /// Directed graph representing the dependencies among the variables with - /// respect to the usage of variables for the computation of the default - /// value. - class DefaultValuesGraph : public VariablesDependencyGraph - { - public: - explicit DefaultValuesGraph(const llvm::StringMap& defaultOps) - : defaultOps(&defaultOps) - { - } + /// Check if the graph contains cycles. + bool hasCycles() const; - protected: - std::set getDependencies(VariableOp variable) override; + /// Perform a post-order visit of the graph and get the ordered + /// variables. + llvm::SmallVector postOrder() const; - private: - const llvm::StringMap* defaultOps; - }; -} + /// Perform a reverse post-order visit of the graph and get the ordered + /// variables. + llvm::SmallVector reversePostOrder() const; -namespace llvm -{ - template<> - struct DenseMapInfo<::VariablesDependencyGraph::Node> - { - static inline ::VariablesDependencyGraph::Node getEmptyKey() - { - return {nullptr, nullptr}; - } +protected: + virtual std::set getDependencies(VariableOp variable) = 0; - static inline ::VariablesDependencyGraph::Node getTombstoneKey() - { - return {nullptr, nullptr}; - } +private: + llvm::DenseMap> arcs; + llvm::SmallVector nodes; + llvm::DenseMap nodesByName; +}; - static unsigned getHashValue(const ::VariablesDependencyGraph::Node& val) - { - return std::hash{}(val.variable); - } +/// Directed graph representing the dependencies among the variables with +/// respect to the usage of variables for the computation of the default +/// value. 
+class DefaultValuesGraph : public VariablesDependencyGraph { +public: + explicit DefaultValuesGraph(const llvm::StringMap &defaultOps) + : defaultOps(&defaultOps) {} - static bool isEqual( - const ::VariablesDependencyGraph::Node& lhs, - const ::VariablesDependencyGraph::Node& rhs) - { - return lhs.graph == rhs.graph && lhs.variable == rhs.variable; - } - }; +protected: + std::set getDependencies(VariableOp variable) override; - template<> - struct GraphTraits - { - using GraphType = const ::VariablesDependencyGraph*; - using NodeRef = ::VariablesDependencyGraph::Node; +private: + const llvm::StringMap *defaultOps; +}; +} // namespace - using ChildIteratorType = - std::set<::VariablesDependencyGraph::Node>::const_iterator; +namespace llvm { +template <> +struct DenseMapInfo<::VariablesDependencyGraph::Node> { + static inline ::VariablesDependencyGraph::Node getEmptyKey() { + return {nullptr, nullptr}; + } - static NodeRef getEntryNode(const GraphType& graph) - { - return graph->getEntryNode(); - } + static inline ::VariablesDependencyGraph::Node getTombstoneKey() { + return {nullptr, nullptr}; + } - static ChildIteratorType child_begin(NodeRef node) - { - return node.graph->childrenBegin(node); - } + static unsigned getHashValue(const ::VariablesDependencyGraph::Node &val) { + return std::hash{}(val.variable); + } - static ChildIteratorType child_end(NodeRef node) - { - return node.graph->childrenEnd(node); - } + static bool isEqual(const ::VariablesDependencyGraph::Node &lhs, + const ::VariablesDependencyGraph::Node &rhs) { + return lhs.graph == rhs.graph && lhs.variable == rhs.variable; + } +}; - using nodes_iterator = llvm::SmallVector::const_iterator; +template <> +struct GraphTraits { + using GraphType = const ::VariablesDependencyGraph *; + using NodeRef = ::VariablesDependencyGraph::Node; - static nodes_iterator nodes_begin(GraphType* graph) - { - return (*graph)->nodesBegin(); - } + using ChildIteratorType = + std::set<::VariablesDependencyGraph::Node>::const_iterator; - static nodes_iterator nodes_end(GraphType* graph) - { - return (*graph)->nodesEnd(); - } + static NodeRef getEntryNode(const GraphType &graph) { + return graph->getEntryNode(); + } - // There is no need for a dedicated class for the arcs. - using EdgeRef = ::VariablesDependencyGraph::Node; + static ChildIteratorType child_begin(NodeRef node) { + return node.graph->childrenBegin(node); + } - using ChildEdgeIteratorType = - std::set<::VariablesDependencyGraph::Node>::const_iterator; + static ChildIteratorType child_end(NodeRef node) { + return node.graph->childrenEnd(node); + } - static ChildEdgeIteratorType child_edge_begin(NodeRef node) - { - return node.graph->childrenBegin(node); - } + using nodes_iterator = llvm::SmallVector::const_iterator; - static ChildEdgeIteratorType child_edge_end(NodeRef node) - { - return node.graph->childrenEnd(node); - } + static nodes_iterator nodes_begin(GraphType *graph) { + return (*graph)->nodesBegin(); + } - static NodeRef edge_dest(EdgeRef edge) - { - return edge; - } + static nodes_iterator nodes_end(GraphType *graph) { + return (*graph)->nodesEnd(); + } - static unsigned int size(GraphType* graph) - { - return (*graph)->getNumOfNodes(); - } - }; -} + // There is no need for a dedicated class for the arcs. 
+ using EdgeRef = ::VariablesDependencyGraph::Node; -namespace -{ - bool VariablesDependencyGraph::hasCycles() const - { - auto beginIt = llvm::scc_begin(this); - auto endIt = llvm::scc_end(this); + using ChildEdgeIteratorType = + std::set<::VariablesDependencyGraph::Node>::const_iterator; - for (auto it = beginIt; it != endIt; ++it) { - if (it.hasCycle()) { - return true; - } - } + static ChildEdgeIteratorType child_edge_begin(NodeRef node) { + return node.graph->childrenBegin(node); + } - return false; + static ChildEdgeIteratorType child_edge_end(NodeRef node) { + return node.graph->childrenEnd(node); } - llvm::SmallVector VariablesDependencyGraph::postOrder() const - { - assert(!hasCycles()); + static NodeRef edge_dest(EdgeRef edge) { return edge; } - llvm::SmallVector result; - std::set set; + static unsigned int size(GraphType *graph) { + return (*graph)->getNumOfNodes(); + } +}; +} // namespace llvm - for (const auto& node : llvm::post_order_ext(this, set)) { - if (node != getEntryNode()) { - result.push_back(mlir::cast(node.variable)); - } - } +namespace { +[[maybe_unused]] bool VariablesDependencyGraph::hasCycles() const { + auto beginIt = llvm::scc_begin(this); + auto endIt = llvm::scc_end(this); - return result; + for (auto it = beginIt; it != endIt; ++it) { + if (it.hasCycle()) { + return true; + } } - llvm::SmallVector VariablesDependencyGraph::reversePostOrder() const - { - auto result = postOrder(); - std::reverse(result.begin(), result.end()); - return result; - } + return false; +} - std::set DefaultValuesGraph::getDependencies( - VariableOp variable) - { - std::set dependencies; - auto defaultOpIt = defaultOps->find(variable.getSymName()); +llvm::SmallVector VariablesDependencyGraph::postOrder() const { + assert(!hasCycles()); - if (defaultOpIt != defaultOps->end()) { - DefaultOp defaultOp = defaultOpIt->getValue(); + llvm::SmallVector result; + std::set set; - defaultOp->walk([&](VariableGetOp getOp) { - dependencies.insert(getOp.getVariable()); - }); + for (const auto &node : llvm::post_order_ext(this, set)) { + if (node != getEntryNode()) { + result.push_back(mlir::cast(node.variable)); } - - return dependencies; } + + return result; } -namespace -{ - class DefaultOpComputationOrderings - { - public: - llvm::ArrayRef get(FunctionOp functionOp) const - { - auto it = orderings.find(functionOp); +llvm::SmallVector +VariablesDependencyGraph::reversePostOrder() const { + auto result = postOrder(); + std::reverse(result.begin(), result.end()); + return result; +} - // If the assertion doesn't hold, then verification is wrong. 
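// Worked example of the ordering computed by reversePostOrder above, where an
// edge a -> b means "the default value of b reads a":
//
//   graph:  entry -> a -> b -> c   (e.g. b has "default b = a + 1")
//   post-order visit:    [c, b, a]
//   reverse post-order:  [a, b, c]   <- safe materialization order
//
// Every variable is therefore emitted after everything its default value
// depends on; postOrder() asserts !hasCycles(), since a cyclic default-value
// dependency would admit no such order.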
- assert(it != orderings.end()); +std::set +DefaultValuesGraph::getDependencies(VariableOp variable) { + std::set dependencies; + auto defaultOpIt = defaultOps->find(variable.getSymName()); - return it->getSecond(); - } + if (defaultOpIt != defaultOps->end()) { + DefaultOp defaultOp = defaultOpIt->getValue(); - void set( - FunctionOp functionOp, - llvm::ArrayRef variablesOrder) - { - for (VariableOp variableOp : variablesOrder) { - orderings[functionOp].push_back(variableOp); - } - } + defaultOp->walk( + [&](VariableGetOp getOp) { dependencies.insert(getOp.getVariable()); }); + } - private: - llvm::DenseMap> orderings; - }; + return dependencies; } +} // namespace -class FunctionInliner : public mlir::OpRewritePattern -{ - public: - FunctionInliner( - mlir::MLIRContext* context, - mlir::SymbolTableCollection& symbolTable, - const llvm::DenseSet& inlinableFunctions, - const DefaultOpComputationOrderings& orderings) - : mlir::OpRewritePattern(context), - symbolTable(&symbolTable), - inlinableFunctions(&inlinableFunctions), - orderings(&orderings) - { - } +namespace { +class DefaultOpComputationOrderings { +public: + llvm::ArrayRef get(FunctionOp functionOp) const { + auto it = orderings.find(functionOp); - mlir::LogicalResult matchAndRewrite( - CallOp op, mlir::PatternRewriter& rewriter) const override - { - auto moduleOp = op->getParentOfType(); + // If the assertion doesn't hold, then verification is wrong. + assert(it != orderings.end()); - FunctionOp callee = mlir::cast( - op.getFunction(moduleOp, *symbolTable)); + return it->getSecond(); + } - if (!inlinableFunctions->contains(callee)) { - return mlir::failure(); - } + void set(FunctionOp functionOp, llvm::ArrayRef variablesOrder) { + for (VariableOp variableOp : variablesOrder) { + orderings[functionOp].push_back(variableOp); + } + } - mlir::IRMapping mapping; - llvm::StringMap varMapping; +private: + llvm::DenseMap> orderings; +}; +} // namespace - // Map the operations providing the default values for the variables. - llvm::DenseMap defaultOps; +class FunctionInliner : public mlir::OpRewritePattern { +public: + FunctionInliner(mlir::MLIRContext *context, + mlir::SymbolTableCollection &symbolTable, + const llvm::DenseSet &inlinableFunctions, + const DefaultOpComputationOrderings &orderings) + : mlir::OpRewritePattern(context), symbolTable(&symbolTable), + inlinableFunctions(&inlinableFunctions), orderings(&orderings) {} - for (DefaultOp defaultOp : callee.getDefaultValues()) { - VariableOp variableOp = symbolTable->lookupSymbolIn( - callee, defaultOp.getVariableAttr()); + mlir::LogicalResult + matchAndRewrite(CallOp op, mlir::PatternRewriter &rewriter) const override { + auto moduleOp = op->getParentOfType(); - defaultOps[variableOp] = defaultOp; - } + FunctionOp callee = + mlir::cast(op.getFunction(moduleOp, *symbolTable)); - // Set the default values for variables. - for (VariableOp variableOp : orderings->get(callee)) { - auto defaultOpIt = defaultOps.find(variableOp); + if (!inlinableFunctions->contains(callee)) { + return mlir::failure(); + } - if (defaultOpIt == defaultOps.end()) { - continue; - } + mlir::IRMapping mapping; + llvm::StringMap varMapping; - DefaultOp defaultOp = defaultOpIt->getSecond(); + // Map the operations providing the default values for the variables. 
+ llvm::DenseMap defaultOps; - for (auto& nestedOp : defaultOp.getOps()) { - if (auto yieldOp = mlir::dyn_cast(nestedOp)) { - assert(yieldOp.getValues().size() == 1); + for (DefaultOp defaultOp : callee.getDefaultValues()) { + VariableOp variableOp = symbolTable->lookupSymbolIn( + callee, defaultOp.getVariableAttr()); - varMapping[variableOp.getSymName()] = - mapping.lookup(yieldOp.getValues()[0]); - } else { - rewriter.clone(nestedOp, mapping); - } - } + defaultOps[variableOp] = defaultOp; + } + + // Set the default values for variables. + for (VariableOp variableOp : orderings->get(callee)) { + auto defaultOpIt = defaultOps.find(variableOp); + + if (defaultOpIt == defaultOps.end()) { + continue; } - // Map the call arguments to the function input variables. - llvm::SmallVector inputVariables; + DefaultOp defaultOp = defaultOpIt->getSecond(); - for (VariableOp variableOp : callee.getVariables()) { - if (variableOp.isInput()) { - inputVariables.push_back(variableOp); + for (auto &nestedOp : defaultOp.getOps()) { + if (auto yieldOp = mlir::dyn_cast(nestedOp)) { + assert(yieldOp.getValues().size() == 1); + + varMapping[variableOp.getSymName()] = + mapping.lookup(yieldOp.getValues()[0]); + } else { + rewriter.clone(nestedOp, mapping); } } + } - assert(op.getArgs().size() <= inputVariables.size()); + // Map the call arguments to the function input variables. + llvm::SmallVector inputVariables; - for (const auto& callArg : llvm::enumerate(op.getArgs())) { - VariableOp variableOp = inputVariables[callArg.index()]; - varMapping[variableOp.getSymName()] = callArg.value(); + for (VariableOp variableOp : callee.getVariables()) { + if (variableOp.isInput()) { + inputVariables.push_back(variableOp); } + } - // Check that all the input variables have a value. - assert(llvm::all_of(inputVariables, [&](VariableOp variableOp) { - return varMapping.find(variableOp.getSymName()) != varMapping.end(); - })); + assert(op.getArgs().size() <= inputVariables.size()); - // Clone the function body. - for (AlgorithmOp algorithmOp : callee.getOps()) { - for (auto& originalOp : algorithmOp.getOps()) { - cloneBodyOp(rewriter, mapping, varMapping, &originalOp); - } - } + for (const auto &callArg : llvm::enumerate(op.getArgs())) { + VariableOp variableOp = inputVariables[callArg.index()]; + varMapping[variableOp.getSymName()] = callArg.value(); + } - // Determine the result values. - llvm::SmallVector outputVariables; + // Check that all the input variables have a value. + assert(llvm::all_of(inputVariables, [&](VariableOp variableOp) { + return varMapping.find(variableOp.getSymName()) != varMapping.end(); + })); - for (VariableOp variableOp : callee.getVariables()) { - if (variableOp.isOutput()) { - outputVariables.push_back(variableOp); - } + // Clone the function body. + for (AlgorithmOp algorithmOp : callee.getOps()) { + for (auto &originalOp : algorithmOp.getOps()) { + cloneBodyOp(rewriter, mapping, varMapping, &originalOp); } + } - assert(op.getResults().size() == outputVariables.size()); - llvm::SmallVector newResults; + // Determine the result values. 
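// The rewrite above is name-based: every variable of the callee is tracked in
// `varMapping` under its symbol name, so a VariableGetOp becomes a map lookup
// and a VariableSetOp becomes a map update (see cloneBodyOp below). In
// miniature, inlining "function f(input u; output y) { y := 2 * u; }" at the
// call site "f(3.0)" behaves like:
//
//   std::map<std::string, double> varMapping;
//   varMapping["u"] = 3.0;            // call argument binds the input
//   double t = 2.0 * varMapping["u"]; // cloned body: get of "u"
//   varMapping["y"] = t;              // cloned body: set of "y"
//   double result = varMapping["y"];  // call result: last write to "y"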
+ llvm::SmallVector outputVariables; - for (VariableOp variableOp : outputVariables) { - newResults.push_back(varMapping.lookup(variableOp.getSymName())); + for (VariableOp variableOp : callee.getVariables()) { + if (variableOp.isOutput()) { + outputVariables.push_back(variableOp); } + } + + assert(op.getResults().size() == outputVariables.size()); + llvm::SmallVector newResults; - rewriter.replaceOp(op, newResults); - return mlir::success(); + for (VariableOp variableOp : outputVariables) { + newResults.push_back(varMapping.lookup(variableOp.getSymName())); } - private: - void cloneBodyOp( - mlir::OpBuilder& builder, - mlir::IRMapping& mapping, - llvm::StringMap& varMapping, - mlir::Operation* op) const - { - if (auto variableGetOp = mlir::dyn_cast(op)) { - mlir::Value& mappedValue = varMapping[variableGetOp.getVariable()]; - mapping.map(variableGetOp.getResult(), mappedValue); - return; - } + rewriter.replaceOp(op, newResults); + return mlir::success(); + } - if (auto variableSetOp = mlir::dyn_cast(op)) { - varMapping[variableSetOp.getVariable()] = - mapping.lookup(variableSetOp.getValue()); +private: + void cloneBodyOp(mlir::OpBuilder &builder, mlir::IRMapping &mapping, + llvm::StringMap &varMapping, + mlir::Operation *op) const { + if (auto variableGetOp = mlir::dyn_cast(op)) { + mlir::Value &mappedValue = varMapping[variableGetOp.getVariable()]; + mapping.map(variableGetOp.getResult(), mappedValue); + return; + } - return; - } + if (auto variableSetOp = mlir::dyn_cast(op)) { + varMapping[variableSetOp.getVariable()] = + mapping.lookup(variableSetOp.getValue()); - mlir::Operation* clonedOp = builder.clone(*op, mapping); - mapping.map(op->getResults(), clonedOp->getResults()); + return; } - private: - mlir::SymbolTableCollection* symbolTable; - const llvm::DenseSet* inlinableFunctions; - const DefaultOpComputationOrderings* orderings; + mlir::Operation *clonedOp = builder.clone(*op, mapping); + mapping.map(op->getResults(), clonedOp->getResults()); + } + +private: + mlir::SymbolTableCollection *symbolTable; + const llvm::DenseSet *inlinableFunctions; + const DefaultOpComputationOrderings *orderings; }; -namespace -{ - class FunctionInliningPass - : public mlir::bmodelica::impl::FunctionInliningPassBase< - FunctionInliningPass> - { - public: - using FunctionInliningPassBase::FunctionInliningPassBase; - - void runOnOperation() override; - - private: - void collectGraphNodes( - CallGraph& callGraph, - DefaultOpComputationOrderings& orderings, - mlir::Operation* op) const; - - void collectGraphEdges( - CallGraph& callGraph, - mlir::SymbolTableCollection& symbolTable, - mlir::ModuleOp moduleOp, - mlir::Operation* op) const; - }; -} +namespace { +class FunctionInliningPass + : public mlir::bmodelica::impl::FunctionInliningPassBase< + FunctionInliningPass> { +public: + using FunctionInliningPassBase::FunctionInliningPassBase; + + void runOnOperation() override; + +private: + void collectGraphNodes(CallGraph &callGraph, + DefaultOpComputationOrderings &orderings, + mlir::Operation *op) const; -void FunctionInliningPass::runOnOperation() -{ + void collectGraphEdges(CallGraph &callGraph, + mlir::SymbolTableCollection &symbolTable, + mlir::ModuleOp moduleOp, mlir::Operation *op) const; +}; +} // namespace + +void FunctionInliningPass::runOnOperation() { mlir::ModuleOp moduleOp = getOperation(); mlir::SymbolTableCollection symbolTable; @@ -855,23 +731,21 @@ void FunctionInliningPass::runOnOperation() mlir::RewritePatternSet patterns(&getContext()); auto inlinableFunctions = 
callGraph.getInlinableFunctions(); - patterns.add( - &getContext(), symbolTable, inlinableFunctions, orderings); + patterns.add(&getContext(), symbolTable, inlinableFunctions, + orderings); mlir::GreedyRewriteConfig config; config.useTopDownTraversal = true; - if (mlir::failed(applyPatternsAndFoldGreedily( - moduleOp, std::move(patterns), config))) { - return signalPassFailure(); + if (mlir::failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns), + config))) { + return signalPassFailure(); } } void FunctionInliningPass::collectGraphNodes( - CallGraph& callGraph, - DefaultOpComputationOrderings& orderings, - mlir::Operation* op) const -{ + CallGraph &callGraph, DefaultOpComputationOrderings &orderings, + mlir::Operation *op) const { if (auto functionOp = mlir::dyn_cast(op)) { if (canBeInlined(functionOp)) { callGraph.addNode(functionOp); @@ -879,15 +753,15 @@ void FunctionInliningPass::collectGraphNodes( llvm::StringMap defaultOps; for (DefaultOp defaultOp : functionOp.getDefaultValues()) { - defaultOps[defaultOp.getVariable()] = defaultOp; + defaultOps[defaultOp.getVariable()] = defaultOp; } llvm::SmallVector inputVariables; for (VariableOp variableOp : functionOp.getVariables()) { - if (variableOp.isInput()) { - inputVariables.push_back(variableOp); - } + if (variableOp.isInput()) { + inputVariables.push_back(variableOp); + } } DefaultValuesGraph defaultValuesGraph(defaultOps); @@ -898,24 +772,21 @@ void FunctionInliningPass::collectGraphNodes( } } - for (auto& region : op->getRegions()) { - for (auto& nested : region.getOps()) { + for (auto ®ion : op->getRegions()) { + for (auto &nested : region.getOps()) { collectGraphNodes(callGraph, orderings, &nested); } } } void FunctionInliningPass::collectGraphEdges( - CallGraph& callGraph, - mlir::SymbolTableCollection& symbolTable, - mlir::ModuleOp moduleOp, - mlir::Operation* op) const -{ + CallGraph &callGraph, mlir::SymbolTableCollection &symbolTable, + mlir::ModuleOp moduleOp, mlir::Operation *op) const { if (auto functionOp = mlir::dyn_cast(op)) { if (callGraph.hasNode(functionOp)) { functionOp.walk([&](CallOp callOp) { - FunctionOp callee = mlir::cast( - callOp.getFunction(moduleOp, symbolTable)); + FunctionOp callee = + mlir::cast(callOp.getFunction(moduleOp, symbolTable)); if (callGraph.hasNode(callee)) { callGraph.addEdge(functionOp, callee); @@ -924,17 +795,15 @@ void FunctionInliningPass::collectGraphEdges( } } - for (auto& region : op->getRegions()) { - for (auto& nested : region.getOps()) { + for (auto ®ion : op->getRegions()) { + for (auto &nested : region.getOps()) { collectGraphEdges(callGraph, symbolTable, moduleOp, &nested); } } } -namespace mlir::bmodelica -{ - std::unique_ptr createFunctionInliningPass() - { - return std::make_unique(); - } +namespace mlir::bmodelica { +std::unique_ptr createFunctionInliningPass() { + return std::make_unique(); } +} // namespace mlir::bmodelica diff --git a/lib/Dialect/BaseModelica/Transforms/IDA.cpp b/lib/Dialect/BaseModelica/Transforms/IDA.cpp index e28c3a055..3a59f8042 100644 --- a/lib/Dialect/BaseModelica/Transforms/IDA.cpp +++ b/lib/Dialect/BaseModelica/Transforms/IDA.cpp @@ -781,9 +781,6 @@ mlir::LogicalResult IDAInstance::addEquationsToIDA( return mlir::failure(); } - auto writtenVar = symbolTableCollection->lookupSymbolIn( - modelOp, writeAccess->getVariable()); - // Collect the independent variables for automatic differentiation. 
llvm::DenseSet independentVariables; diff --git a/lib/Dialect/BaseModelica/Transforms/RecordInlining.cpp b/lib/Dialect/BaseModelica/Transforms/RecordInlining.cpp index ac6e27f16..969876f4d 100644 --- a/lib/Dialect/BaseModelica/Transforms/RecordInlining.cpp +++ b/lib/Dialect/BaseModelica/Transforms/RecordInlining.cpp @@ -3,816 +3,773 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -namespace mlir::bmodelica -{ +namespace mlir::bmodelica { #define GEN_PASS_DEF_RECORDINLININGPASS #include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc" -} +} // namespace mlir::bmodelica using namespace ::mlir::bmodelica; -static std::string getComposedComponentName( - llvm::StringRef record, llvm::StringRef component) -{ +static std::string getComposedComponentName(llvm::StringRef record, + llvm::StringRef component) { return record.str() + "." + component.str(); } -static std::string getComposedComponentName( - VariableOp record, VariableOp component) -{ +static std::string getComposedComponentName(VariableOp record, + VariableOp component) { return getComposedComponentName(record.getSymName(), component.getSymName()); } -namespace -{ - template - class RecordInliningPattern : public mlir::OpRewritePattern - { - public: - RecordInliningPattern( - mlir::MLIRContext* context, - mlir::ModuleOp moduleOp, - mlir::SymbolTableCollection& symbolTable) - : mlir::OpRewritePattern(context), - moduleOp(moduleOp), - symbolTable(&symbolTable) - { - } +namespace { +template +class RecordInliningPattern : public mlir::OpRewritePattern { +public: + RecordInliningPattern(mlir::MLIRContext *context, mlir::ModuleOp moduleOp, + mlir::SymbolTableCollection &symbolTable) + : mlir::OpRewritePattern(context), moduleOp(moduleOp), + symbolTable(&symbolTable) {} + +protected: + mlir::SymbolTableCollection &getSymbolTableCollection() const { + return *symbolTable; + } - protected: - mlir::SymbolTableCollection& getSymbolTableCollection() const - { - return *symbolTable; - } + RecordOp getRecordOp(RecordType recordType) const { + return mlir::cast( + recordType.getRecordOp(getSymbolTableCollection(), moduleOp)); + } - RecordOp getRecordOp(RecordType recordType) const - { - return mlir::cast( - recordType.getRecordOp(getSymbolTableCollection(), moduleOp)); + bool isRecordBased(mlir::Value value) const { + return isRecordBased(value.getType()); + } + + bool isRecordBased(mlir::Type type) const { + if (auto tensorType = type.dyn_cast()) { + return tensorType.getElementType().isa(); + } + + return type.isa(); + } + + void mergeShapes(llvm::SmallVectorImpl &result, + llvm::ArrayRef parent, + llvm::ArrayRef child) const { + result.clear(); + result.append(parent.begin(), parent.end()); + result.append(child.begin(), child.end()); + } + + mlir::LogicalResult replaceRecordGetters( + mlir::PatternRewriter &rewriter, + llvm::function_ref + replaceFn, + llvm::SmallVectorImpl &subscriptions, + mlir::Value usedValue, mlir::Operation *user) const { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(user); + + if (mlir::isa(user)) { + subscriptions.push_back(user); + + for (mlir::Value userResult : user->getResults()) { + for (mlir::Operation *nestedUser : + llvm::make_early_inc_range(userResult.getUsers())) { + if (mlir::failed(replaceRecordGetters(rewriter, replaceFn, + subscriptions, userResult, + nestedUser))) { + return mlir::failure(); + } + } } - bool isRecordBased(mlir::Value value) const - { - return isRecordBased(value.getType()); + if (user->use_empty()) { + 
rewriter.eraseOp(user); } - bool isRecordBased(mlir::Type type) const - { - if (auto tensorType = type.dyn_cast()) { - return tensorType.getElementType().isa(); - } + subscriptions.pop_back(); + return mlir::success(); + } + + if (auto componentGetOp = mlir::dyn_cast(user)) { + mlir::Value replacement = replaceFn(rewriter, componentGetOp.getLoc(), + componentGetOp.getComponentName()); - return type.isa(); + if (!replacement) { + return mlir::failure(); } - void mergeShapes( - llvm::SmallVectorImpl& result, - llvm::ArrayRef parent, - llvm::ArrayRef child) const - { - result.clear(); - result.append(parent.begin(), parent.end()); - result.append(child.begin(), child.end()); + replacement = applySubscriptions(rewriter, replacement, subscriptions); + + if (auto tensorType = replacement.getType().dyn_cast(); + tensorType && !tensorType.hasRank()) { + replacement = rewriter.create( + componentGetOp.getLoc(), replacement, std::nullopt); } - mlir::LogicalResult replaceRecordGetters( - mlir::PatternRewriter& rewriter, - llvm::function_ref replaceFn, - llvm::SmallVectorImpl& subscriptions, - mlir::Value usedValue, - mlir::Operation* user) const - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(user); - - if (mlir::isa(user)) { - subscriptions.push_back(user); - - for (mlir::Value userResult : user->getResults()) { - for (mlir::Operation* nestedUser : - llvm::make_early_inc_range(userResult.getUsers())) { - if (mlir::failed(replaceRecordGetters( - rewriter, replaceFn, subscriptions, userResult, - nestedUser))) { - return mlir::failure(); - } - } - } + rewriter.replaceOp(componentGetOp, replacement); + return mlir::success(); + } - if (user->use_empty()) { - rewriter.eraseOp(user); - } + if (auto callOp = mlir::dyn_cast(user)) { + auto newCallOp = + unpackCallArg(rewriter, callOp, usedValue, replaceFn, subscriptions); - subscriptions.pop_back(); - return mlir::success(); - } + rewriter.replaceOp(callOp, newCallOp->getResults()); + return mlir::success(); + } - if (auto componentGetOp = mlir::dyn_cast(user)) { - mlir::Value replacement = replaceFn( - rewriter, componentGetOp.getLoc(), - componentGetOp.getComponentName()); + return mlir::failure(); + } - if (!replacement) { - return mlir::failure(); - } + CallOp + unpackCallArg(mlir::OpBuilder &builder, CallOp callOp, mlir::Value arg, + llvm::function_ref + componentGetter, + llvm::ArrayRef subscriptions) const { + llvm::SmallVector newArgs; + llvm::SmallVector newArgNames; - replacement = applySubscriptions( - rewriter, replacement, subscriptions); + for (auto currentArg : llvm::enumerate(callOp.getArgs())) { + if (currentArg.value() == arg) { + auto recordType = mlir::cast(currentArg.value().getType()); + RecordOp recordOp = getRecordOp(recordType); + + for (VariableOp component : recordOp.getVariables()) { + mlir::Value componentValue = componentGetter( + builder, currentArg.value().getLoc(), component.getSymName()); + + componentValue = + applySubscriptions(builder, componentValue, subscriptions); if (auto tensorType = - replacement.getType().dyn_cast(); + componentValue.getType().dyn_cast(); tensorType && !tensorType.hasRank()) { - replacement = rewriter.create( - componentGetOp.getLoc(), replacement, std::nullopt); + componentValue = builder.create( + callOp.getLoc(), componentValue, std::nullopt); } - rewriter.replaceOp(componentGetOp, replacement); - return mlir::success(); - } + newArgs.push_back(componentValue); - if (auto callOp = mlir::dyn_cast(user)) { - auto newCallOp = unpackCallArg( - rewriter, callOp, 
usedValue, replaceFn, subscriptions); + if (auto argNames = callOp.getArgNames()) { + auto argName = (*argNames)[currentArg.index()] + .cast() + .getValue(); - rewriter.replaceOp(callOp, newCallOp->getResults()); - return mlir::success(); + auto composedName = mlir::FlatSymbolRefAttr::get( + builder.getContext(), + getComposedComponentName(argName, component.getSymName())); + + newArgNames.push_back(composedName); + } } + } else { + newArgs.push_back(currentArg.value()); - return mlir::failure(); + if (auto argNames = callOp.getArgNames()) { + newArgNames.push_back((*argNames)[currentArg.index()]); + } } + } - CallOp unpackCallArg( - mlir::OpBuilder& builder, - CallOp callOp, - mlir::Value arg, - llvm::function_ref componentGetter, - llvm::ArrayRef subscriptions) const - { - llvm::SmallVector newArgs; - llvm::SmallVector newArgNames; - - for (auto currentArg : llvm::enumerate(callOp.getArgs())) { - if (currentArg.value() == arg) { - auto recordType = mlir::cast(currentArg.value().getType()); - RecordOp recordOp = getRecordOp(recordType); - - for (VariableOp component : recordOp.getVariables()) { - mlir::Value componentValue = componentGetter( - builder, currentArg.value().getLoc(), - component.getSymName()); - - componentValue = applySubscriptions( - builder, componentValue, subscriptions); - - if (auto tensorType = - componentValue.getType().dyn_cast(); - tensorType && !tensorType.hasRank()) { - componentValue = builder.create( - callOp.getLoc(), componentValue, std::nullopt); - } - - newArgs.push_back(componentValue); - - if (auto argNames = callOp.getArgNames()) { - auto argName = (*argNames)[currentArg.index()] - .cast().getValue(); - - auto composedName = mlir::FlatSymbolRefAttr::get( - builder.getContext(), - getComposedComponentName(argName, component.getSymName())); - - newArgNames.push_back(composedName); - } - } - } else { - newArgs.push_back(currentArg.value()); - - if (auto argNames = callOp.getArgNames()) { - newArgNames.push_back((*argNames)[currentArg.index()]); - } - } - } + std::optional argNamesAttr = std::nullopt; - std::optional argNamesAttr = std::nullopt; + if (!newArgNames.empty()) { + argNamesAttr = builder.getArrayAttr(newArgNames); + } - if (!newArgNames.empty()) { - argNamesAttr = builder.getArrayAttr(newArgNames); - } + return builder.create(callOp.getLoc(), callOp.getCallee(), + callOp.getResultTypes(), newArgs, + argNamesAttr); + } - return builder.create( - callOp.getLoc(), - callOp.getCallee(), - callOp.getResultTypes(), - newArgs, - argNamesAttr); - } + mlir::Value + applySubscriptions(mlir::OpBuilder &builder, mlir::Value root, + llvm::ArrayRef subscriptions) const { + mlir::Value result = root; - mlir::Value applySubscriptions( - mlir::OpBuilder& builder, - mlir::Value root, - llvm::ArrayRef subscriptions) const - { - mlir::Value result = root; - - for (mlir::Operation* op : subscriptions) { - if (auto extractOp = mlir::dyn_cast(op)) { - int64_t rank = result.getType().cast().getRank(); - - auto numOfSubscripts = - static_cast(extractOp.getIndices().size()); - - if (numOfSubscripts == rank) { - result = builder.create( - extractOp.getLoc(), result, extractOp.getIndices()); - } else { - result = builder.create( - extractOp.getLoc(), result, extractOp.getIndices()); - } - } else if (auto viewOp = mlir::dyn_cast(op)) { - result = builder.create( - viewOp.getLoc(), result, viewOp.getSubscriptions()); - } - } + for (mlir::Operation *op : subscriptions) { + if (auto extractOp = mlir::dyn_cast(op)) { + int64_t rank = result.getType().cast().getRank(); - return 
result; - } + auto numOfSubscripts = + static_cast(extractOp.getIndices().size()); - private: - mlir::ModuleOp moduleOp; - mlir::SymbolTableCollection* symbolTable; - }; - - /// Unpack the assignment of a record value into multiple assignments - /// involving the components of the record variable. - class VariableSetOpUnpackPattern - : public RecordInliningPattern - { - public: - using RecordInliningPattern::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - VariableSetOp op, mlir::PatternRewriter& rewriter) const override - { - mlir::Type valueType = op.getValue().getType(); - mlir::Type valueBaseType = valueType; - - if (auto tensorType = valueType.dyn_cast()) { - valueBaseType = tensorType.getElementType(); + if (numOfSubscripts == rank) { + result = builder.create(extractOp.getLoc(), result, + extractOp.getIndices()); + } else { + result = builder.create(extractOp.getLoc(), result, + extractOp.getIndices()); } + } else if (auto viewOp = mlir::dyn_cast(op)) { + result = builder.create(viewOp.getLoc(), result, + viewOp.getSubscriptions()); + } + } - auto recordType = valueBaseType.dyn_cast(); + return result; + } - if (!recordType) { - return mlir::failure(); - } +private: + mlir::ModuleOp moduleOp; + mlir::SymbolTableCollection *symbolTable; +}; - auto recordOp = getRecordOp(recordType); +/// Unpack the assignment of a record value into multiple assignments +/// involving the components of the record variable. +class VariableSetOpUnpackPattern : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; - for (VariableOp recordComponentOp : recordOp.getVariables()) { - mlir::Value componentValue = rewriter.create( - op.getLoc(), - recordComponentOp.getVariableType().unwrap(), - op.getValue(), - recordComponentOp.getSymName()); + mlir::LogicalResult + matchAndRewrite(VariableSetOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::Type valueType = op.getValue().getType(); + mlir::Type valueBaseType = valueType; - llvm::SmallVector newPath; + if (auto tensorType = valueType.dyn_cast()) { + valueBaseType = tensorType.getElementType(); + } - newPath.push_back( - mlir::FlatSymbolRefAttr::get(op.getVariableAttr())); + auto recordType = valueBaseType.dyn_cast(); - newPath.push_back(mlir::FlatSymbolRefAttr::get( - recordComponentOp.getSymNameAttr())); + if (!recordType) { + return mlir::failure(); + } - llvm::SmallVector subscripts; - llvm::SmallVector subscriptsAmounts; + auto recordOp = getRecordOp(recordType); - if (auto tensorType = valueType.dyn_cast()) { - mlir::Value unboundedRange = - rewriter.create(op.getLoc()); + for (VariableOp recordComponentOp : recordOp.getVariables()) { + mlir::Value componentValue = rewriter.create( + op.getLoc(), recordComponentOp.getVariableType().unwrap(), + op.getValue(), recordComponentOp.getSymName()); - subscripts.append(tensorType.getRank(), unboundedRange); - subscriptsAmounts.push_back(tensorType.getRank()); - } else { - subscriptsAmounts.push_back(0); - } + llvm::SmallVector newPath; - rewriter.create( - op.getLoc(), - rewriter.getArrayAttr(newPath), - subscripts, - rewriter.getI64ArrayAttr(subscriptsAmounts), - componentValue); - } + newPath.push_back(mlir::FlatSymbolRefAttr::get(op.getVariableAttr())); + + newPath.push_back( + mlir::FlatSymbolRefAttr::get(recordComponentOp.getSymNameAttr())); - rewriter.eraseOp(op); - return mlir::success(); + llvm::SmallVector subscripts; + llvm::SmallVector subscriptsAmounts; + + if (auto tensorType = valueType.dyn_cast()) { + mlir::Value 
unboundedRange = + rewriter.create(op.getLoc()); + + subscripts.append(tensorType.getRank(), unboundedRange); + subscriptsAmounts.push_back(tensorType.getRank()); + } else { + subscriptsAmounts.push_back(0); } - }; - - /// Unpack the assignment of a record component into multiple assignments - /// involving the components of the record variable. - /// Together with the above pattern, this enables the handling of assignments - /// of nested records. - class VariableComponentSetOpUnpackPattern - : public RecordInliningPattern - { - public: - using RecordInliningPattern - ::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - VariableComponentSetOp op, - mlir::PatternRewriter& rewriter) const override - { - mlir::Type valueType = op.getValue().getType(); - mlir::Type valueBaseType = valueType; - - if (auto tensorType = valueType.dyn_cast()) { - valueBaseType = tensorType.getElementType(); - } - auto recordType = valueBaseType.dyn_cast(); + rewriter.create( + op.getLoc(), rewriter.getArrayAttr(newPath), subscripts, + rewriter.getI64ArrayAttr(subscriptsAmounts), componentValue); + } - if (!recordType) { - return mlir::failure(); - } + rewriter.eraseOp(op); + return mlir::success(); + } +}; - auto recordOp = getRecordOp(recordType); - size_t pathLength = op.getPath().size(); +/// Unpack the assignment of a record component into multiple assignments +/// involving the components of the record variable. +/// Together with the above pattern, this enables the handling of assignments +/// of nested records. +class VariableComponentSetOpUnpackPattern + : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; - for (VariableOp recordComponentOp : recordOp.getVariables()) { - mlir::Value componentValue = rewriter.create( - op.getLoc(), - recordComponentOp.getVariableType().unwrap(), - op.getValue(), - recordComponentOp.getSymName()); + mlir::LogicalResult + matchAndRewrite(VariableComponentSetOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::Type valueType = op.getValue().getType(); + mlir::Type valueBaseType = valueType; - llvm::SmallVector newPath; + if (auto tensorType = valueType.dyn_cast()) { + valueBaseType = tensorType.getElementType(); + } - for (size_t component = 1; component < pathLength; ++component) { - newPath.push_back(op.getPath()[component]); - } + auto recordType = valueBaseType.dyn_cast(); - newPath.push_back(mlir::FlatSymbolRefAttr::get( - recordComponentOp.getSymNameAttr())); + if (!recordType) { + return mlir::failure(); + } - llvm::SmallVector subscripts; - llvm::SmallVector subscriptsAmounts; + auto recordOp = getRecordOp(recordType); + size_t pathLength = op.getPath().size(); - for (mlir::Value subscript : op.getSubscriptions()) { - subscripts.push_back(subscript); - } + for (VariableOp recordComponentOp : recordOp.getVariables()) { + mlir::Value componentValue = rewriter.create( + op.getLoc(), recordComponentOp.getVariableType().unwrap(), + op.getValue(), recordComponentOp.getSymName()); - for (mlir::IntegerAttr subscriptsAmount : - op.getSubscriptionsAmounts().getAsRange()) { - subscriptsAmounts.push_back(subscriptsAmount.getInt()); - } + llvm::SmallVector newPath; - if (auto tensorType = valueType.dyn_cast()) { - mlir::Value unboundedRange = - rewriter.create(op.getLoc()); + for (size_t component = 1; component < pathLength; ++component) { + newPath.push_back(op.getPath()[component]); + } - subscripts.append(tensorType.getRank(), unboundedRange); - subscriptsAmounts.push_back(tensorType.getRank()); - } else 
{ - subscriptsAmounts.push_back(0); - } + newPath.push_back( + mlir::FlatSymbolRefAttr::get(recordComponentOp.getSymNameAttr())); - rewriter.create( - op.getLoc(), - rewriter.getArrayAttr(newPath), - subscripts, - rewriter.getI64ArrayAttr(subscriptsAmounts), - componentValue); - } + llvm::SmallVector subscripts; + llvm::SmallVector subscriptsAmounts; - rewriter.eraseOp(op); - return mlir::success(); + for (mlir::Value subscript : op.getSubscriptions()) { + subscripts.push_back(subscript); } - }; - - class EquationSideOpUnpackPattern - : public RecordInliningPattern - { - public: - using RecordInliningPattern::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - EquationSideOp op, mlir::PatternRewriter& rewriter) const override - { - llvm::SmallVector newValues; - bool recordFound = false; - - for (mlir::Value value : op.getValues()) { - if (auto recordType = value.getType().dyn_cast()) { - auto recordOp = getRecordOp(recordType); - - for (VariableOp component : recordOp.getVariables()) { - auto componentGetOp = rewriter.create( - value.getLoc(), - component.getVariableType().unwrap(), - value, - component.getSymName()); - - newValues.push_back(componentGetOp.getResult()); - } - - recordFound = true; - } else { - newValues.push_back(value); - } - } - if (!recordFound) { - return mlir::failure(); - } + for (mlir::IntegerAttr subscriptsAmount : + op.getSubscriptionsAmounts().getAsRange()) { + subscriptsAmounts.push_back(subscriptsAmount.getInt()); + } + + if (auto tensorType = valueType.dyn_cast()) { + mlir::Value unboundedRange = + rewriter.create(op.getLoc()); - rewriter.replaceOpWithNewOp(op, newValues); - return mlir::success(); + subscripts.append(tensorType.getRank(), unboundedRange); + subscriptsAmounts.push_back(tensorType.getRank()); + } else { + subscriptsAmounts.push_back(0); } - }; - - /// Unpack record variables into their components. - class VariableOpUnpackPattern : public RecordInliningPattern - { - public: - using RecordInliningPattern::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - VariableOp op, mlir::PatternRewriter& rewriter) const override - { - VariableType variableType = op.getVariableType(); - mlir::Type elementType = variableType.getElementType(); - - if (!elementType.isa()) { - // Not a record or an array of records. - return mlir::failure(); - } - // Create a variable for each component and map it for faster lookups. - auto recordType = elementType.cast(); - auto recordOp = getRecordOp(recordType); + rewriter.create( + op.getLoc(), rewriter.getArrayAttr(newPath), subscripts, + rewriter.getI64ArrayAttr(subscriptsAmounts), componentValue); + } - llvm::StringMap componentsMap; + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class EquationSideOpUnpackPattern + : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; + + mlir::LogicalResult + matchAndRewrite(EquationSideOp op, + mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector newValues; + bool recordFound = false; + + for (mlir::Value value : op.getValues()) { + if (auto recordType = value.getType().dyn_cast()) { + auto recordOp = getRecordOp(recordType); for (VariableOp component : recordOp.getVariables()) { - llvm::SmallVector dimensions; + auto componentGetOp = rewriter.create( + value.getLoc(), component.getVariableType().unwrap(), value, + component.getSymName()); - // Use the shape of the original record variable. 
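The dimension handling being reformatted here is plain concatenation: the unpacked component variable is indexed first by the record variable's dimensions, then by the component's own. A standalone sketch of that rule (hypothetical helper mirroring mergeShapes defined earlier in this file):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

// Compose the shape of an unpacked component variable: record dimensions
// first, component dimensions second.
static llvm::SmallVector<int64_t>
composeShape(llvm::ArrayRef<int64_t> record,
             llvm::ArrayRef<int64_t> component) {
  llvm::SmallVector<int64_t> shape(record.begin(), record.end());
  shape.append(component.begin(), component.end());
  return shape; // e.g. record [3] + component [2] -> [3, 2]
}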
- for (int64_t dimension : op.getVariableType().getShape()) { - dimensions.push_back(dimension); - } + newValues.push_back(componentGetOp.getResult()); + } - // Append the dimensions of the component. - for (int64_t dimension : component.getVariableType().getShape()) { - dimensions.push_back(dimension); - } + recordFound = true; + } else { + newValues.push_back(value); + } + } - // Start from the original variable type in order to keep the - // modifiers. - auto componentVariableType = - variableType - .withShape(dimensions) - .withType(component.getVariableType().getElementType()); + if (!recordFound) { + return mlir::failure(); + } - // Create the variable for the component. - auto unpackedComponent = rewriter.create( - op.getLoc(), - getComposedComponentName(op, component), - componentVariableType); + rewriter.replaceOpWithNewOp(op, newValues); + return mlir::success(); + } +}; - componentsMap[component.getSymName()] = unpackedComponent; - } +/// Unpack record variables into their components. +class VariableOpUnpackPattern : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; - // Replace the uses of the original record. - auto cls = op->getParentOfType(); + mlir::LogicalResult + matchAndRewrite(VariableOp op, + mlir::PatternRewriter &rewriter) const override { + VariableType variableType = op.getVariableType(); + mlir::Type elementType = variableType.getElementType(); - llvm::SmallVector startOps; - llvm::SmallVector defaultOps; - llvm::SmallVector bindingEquationOps; + if (!elementType.isa()) { + // Not a record or an array of records. + return mlir::failure(); + } - for (auto& bodyOp : cls->getRegion(0).getOps()) { - if (auto startOp = mlir::dyn_cast(bodyOp)) { - if (startOp.getVariable().getRootReference() != op.getSymName()) { - continue; - } + // Create a variable for each component and map it for faster lookups. + auto recordType = elementType.cast(); + auto recordOp = getRecordOp(recordType); - auto startVariableOp = - getSymbolTableCollection().lookupSymbolIn( - cls, startOp.getVariable().getRootReference()); + llvm::StringMap componentsMap; - if (!startVariableOp) { - continue; - } + for (VariableOp component : recordOp.getVariables()) { + VariableType componentVariableType = component.getVariableType(); + llvm::SmallVector dimensions; - mlir::Type startVariableElementType = - startVariableOp.getVariableType().getElementType(); + // Use the shape of the original record variable. + for (int64_t dimension : variableType.getShape()) { + dimensions.push_back(dimension); + } - if (!startVariableElementType.isa()) { - continue; - } + // Append the dimensions of the component. + for (int64_t dimension : componentVariableType.getShape()) { + dimensions.push_back(dimension); + } - startOps.push_back(startOp); - } + // Start from the original variable type in order to keep the + // modifiers. + componentVariableType = + variableType.withShape(dimensions) + .withType(component.getVariableType().getElementType()); - if (auto defaultOp = mlir::dyn_cast(bodyOp)) { - if (defaultOp.getVariable() == op.getSymName() && - defaultOp.getVariableOp(getSymbolTableCollection()) - .getVariableType().getElementType().isa()) { - defaultOps.push_back(defaultOp); - } - } + // Create the variable for the component. 
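The variable created below is named through the composition rule at the top of this file: a record variable r with component x unpacks into a variable named "r.x", and nesting composes further. A self-contained sketch of that scheme (the helper name is hypothetical):

#include <iostream>
#include <string>

// Mirrors getComposedComponentName above: unpacked variables are named
// "record.component".
static std::string compose(const std::string &record,
                           const std::string &component) {
  return record + "." + component;
}

int main() {
  // A record variable "r" with a nested component "pos.x" unpacks into a
  // variable named "r.pos.x".
  std::cout << compose(compose("r", "pos"), "x") << "\n";
  return 0;
}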
+ auto unpackedComponent = rewriter.create( + op.getLoc(), getComposedComponentName(op, component), + componentVariableType); - if (auto bindingEquationOp = - mlir::dyn_cast(bodyOp)) { - if (bindingEquationOp.getVariable() == op.getSymName() && - bindingEquationOp.getVariableOp( - getSymbolTableCollection()).getVariableType() - .getElementType().isa()) { - bindingEquationOps.push_back(bindingEquationOp); - } - } + componentsMap[component.getSymName()] = unpackedComponent; + } + + // Replace the uses of the original record. + auto cls = op->getParentOfType(); + + llvm::SmallVector startOps; + llvm::SmallVector defaultOps; + llvm::SmallVector bindingEquationOps; + + for (auto &bodyOp : cls->getRegion(0).getOps()) { + if (auto startOp = mlir::dyn_cast(bodyOp)) { + if (startOp.getVariable().getRootReference() != op.getSymName()) { + continue; } - for (StartOp startOp : startOps) { - unpackStartOp(rewriter, startOp); + auto startVariableOp = + getSymbolTableCollection().lookupSymbolIn( + cls, startOp.getVariable().getRootReference()); + + if (!startVariableOp) { + continue; } - for (DefaultOp defaultOp : defaultOps) { - unpackDefaultOp(rewriter, defaultOp); + mlir::Type startVariableElementType = + startVariableOp.getVariableType().getElementType(); + + if (!startVariableElementType.isa()) { + continue; } - for (BindingEquationOp bindingEquationOp : bindingEquationOps) { - unpackBindingEquationOp(rewriter, bindingEquationOp); + startOps.push_back(startOp); + } + + if (auto defaultOp = mlir::dyn_cast(bodyOp)) { + if (defaultOp.getVariable() == op.getSymName() && + defaultOp.getVariableOp(getSymbolTableCollection()) + .getVariableType() + .getElementType() + .isa()) { + defaultOps.push_back(defaultOp); } + } - llvm::SmallVector getOps; - llvm::SmallVector setOps; - - cls->getRegion(0).walk([&](mlir::Operation* nestedOp) { - if (auto getOp = mlir::dyn_cast(nestedOp)) { - if (getOp.getVariable() == op.getSymName()) { - getOps.push_back(getOp); - } - } else if (auto setOp = mlir::dyn_cast( - nestedOp)) { - auto rootNameAttr = setOp.getPath()[0].cast(); - - if (rootNameAttr.getValue() == op.getSymName()) { - setOps.push_back(setOp); - } - } - }); + if (auto bindingEquationOp = mlir::dyn_cast(bodyOp)) { + if (bindingEquationOp.getVariable() == op.getSymName() && + bindingEquationOp.getVariableOp(getSymbolTableCollection()) + .getVariableType() + .getElementType() + .isa()) { + bindingEquationOps.push_back(bindingEquationOp); + } + } + } - for (VariableGetOp getOp : getOps) { - if (mlir::failed(replaceVariableGetOp( - rewriter, getOp, componentsMap))) { - return mlir::failure(); - } + for (StartOp startOp : startOps) { + unpackStartOp(rewriter, startOp); + } + + for (DefaultOp defaultOp : defaultOps) { + unpackDefaultOp(rewriter, defaultOp); + } + + for (BindingEquationOp bindingEquationOp : bindingEquationOps) { + unpackBindingEquationOp(rewriter, bindingEquationOp); + } + + llvm::SmallVector getOps; + llvm::SmallVector setOps; + + cls->getRegion(0).walk([&](mlir::Operation *nestedOp) { + if (auto getOp = mlir::dyn_cast(nestedOp)) { + if (getOp.getVariable() == op.getSymName()) { + getOps.push_back(getOp); } + } else if (auto setOp = + mlir::dyn_cast(nestedOp)) { + auto rootNameAttr = setOp.getPath()[0].cast(); - for (VariableComponentSetOp setOp : setOps) { - if (mlir::failed(replaceVariableComponentSetOp( - rewriter, op, setOp, componentsMap))) { - return mlir::failure(); - } + if (rootNameAttr.getValue() == op.getSymName()) { + setOps.push_back(setOp); } + } + }); - rewriter.eraseOp(op); - return 
mlir::success(); + for (VariableGetOp getOp : getOps) { + if (mlir::failed(replaceVariableGetOp(rewriter, getOp, componentsMap))) { + return mlir::failure(); } + } - private: - void unpackStartOp( - mlir::PatternRewriter& rewriter, - StartOp startOp) const - { - if (auto oldNestedRefs = startOp.getVariable().getNestedReferences(); - !oldNestedRefs.empty()) { - // The StartOp already goes through the components of the record. - std::string newRoot = getComposedComponentName( - startOp.getVariable().getRootReference(), - oldNestedRefs.front().getValue()); - - startOp.setVariableAttr(mlir::SymbolRefAttr::get( - rewriter.getStringAttr(newRoot), oldNestedRefs.drop_front())); - - return; - } + for (VariableComponentSetOp setOp : setOps) { + if (mlir::failed(replaceVariableComponentSetOp(rewriter, op, setOp, + componentsMap))) { + return mlir::failure(); + } + } - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(startOp); + rewriter.eraseOp(op); + return mlir::success(); + } - auto modelOp = startOp->getParentOfType(); +private: + void unpackStartOp(mlir::PatternRewriter &rewriter, StartOp startOp) const { + if (auto oldNestedRefs = startOp.getVariable().getNestedReferences(); + !oldNestedRefs.empty()) { + // The StartOp already goes through the components of the record. + std::string newRoot = + getComposedComponentName(startOp.getVariable().getRootReference(), + oldNestedRefs.front().getValue()); - auto variableOp = - getSymbolTableCollection().lookupSymbolIn( - modelOp, startOp.getVariable().getRootReference()); + startOp.setVariableAttr(mlir::SymbolRefAttr::get( + rewriter.getStringAttr(newRoot), oldNestedRefs.drop_front())); - auto recordType = variableOp.getVariableType() - .getElementType().cast(); + return; + } - auto recordOp = getRecordOp(recordType); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(startOp); - for (VariableOp component : recordOp.getVariables()) { - llvm::SmallVector shape; + auto modelOp = startOp->getParentOfType(); - mergeShapes( - shape, - variableOp.getVariableType().getShape(), - component.getVariableType().getShape()); + auto variableOp = getSymbolTableCollection().lookupSymbolIn( + modelOp, startOp.getVariable().getRootReference()); - auto clonedOp = mlir::cast( - rewriter.clone(*startOp.getOperation())); + auto recordType = + variableOp.getVariableType().getElementType().cast(); - clonedOp.setVariableAttr(mlir::SymbolRefAttr::get( - rewriter.getStringAttr( - getComposedComponentName(variableOp, component)))); + auto recordOp = getRecordOp(recordType); - auto yieldOp = mlir::cast( - clonedOp.getBody()->getTerminator()); + for (VariableOp component : recordOp.getVariables()) { + llvm::SmallVector shape; - rewriter.setInsertionPointAfter(yieldOp); + mergeShapes(shape, variableOp.getVariableType().getShape(), + component.getVariableType().getShape()); - mlir::Value componentValue = rewriter.create( - yieldOp.getLoc(), - component.getVariableType().withShape(shape).unwrap(), - yieldOp.getValues()[0], - component.getSymName()); + auto clonedOp = + mlir::cast(rewriter.clone(*startOp.getOperation())); - rewriter.replaceOpWithNewOp(yieldOp, componentValue); - rewriter.setInsertionPointAfter(clonedOp); - } + clonedOp.setVariableAttr(mlir::SymbolRefAttr::get(rewriter.getStringAttr( + getComposedComponentName(variableOp, component)))); - rewriter.eraseOp(startOp); - } + auto yieldOp = mlir::cast(clonedOp.getBody()->getTerminator()); - void unpackDefaultOp( - mlir::PatternRewriter& rewriter, - DefaultOp 
op) const - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(op); + rewriter.setInsertionPointAfter(yieldOp); - VariableOp variableOp = op.getVariableOp(getSymbolTableCollection()); + mlir::Value componentValue = rewriter.create( + yieldOp.getLoc(), + component.getVariableType().withShape(shape).unwrap(), + yieldOp.getValues()[0], component.getSymName()); - auto recordType = variableOp.getVariableType() - .getElementType().cast(); + rewriter.replaceOpWithNewOp(yieldOp, componentValue); + rewriter.setInsertionPointAfter(clonedOp); + } - auto recordOp = getRecordOp(recordType); + rewriter.eraseOp(startOp); + } - for (VariableOp component : recordOp.getVariables()) { - llvm::SmallVector shape; + void unpackDefaultOp(mlir::PatternRewriter &rewriter, DefaultOp op) const { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); - mergeShapes( - shape, - variableOp.getVariableType().getShape(), - component.getVariableType().getShape()); + VariableOp variableOp = op.getVariableOp(getSymbolTableCollection()); - auto clonedOp = mlir::cast( - rewriter.clone(*op.getOperation())); + auto recordType = + variableOp.getVariableType().getElementType().cast(); - clonedOp.setVariable( - getComposedComponentName(variableOp, component)); + auto recordOp = getRecordOp(recordType); - auto yieldOp = mlir::cast( - clonedOp.getBody()->getTerminator()); + for (VariableOp component : recordOp.getVariables()) { + llvm::SmallVector shape; - rewriter.setInsertionPointAfter(yieldOp); + mergeShapes(shape, variableOp.getVariableType().getShape(), + component.getVariableType().getShape()); - mlir::Value componentValue = rewriter.create( - yieldOp.getLoc(), - component.getVariableType().withShape(shape).unwrap(), - yieldOp.getValues()[0], - component.getSymName()); + auto clonedOp = mlir::cast(rewriter.clone(*op.getOperation())); - rewriter.replaceOpWithNewOp(yieldOp, componentValue); - rewriter.setInsertionPointAfter(clonedOp); - } + clonedOp.setVariable(getComposedComponentName(variableOp, component)); - rewriter.eraseOp(op); - } + auto yieldOp = mlir::cast(clonedOp.getBody()->getTerminator()); - void unpackBindingEquationOp( - mlir::PatternRewriter& rewriter, - BindingEquationOp op) const - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(op); + rewriter.setInsertionPointAfter(yieldOp); - VariableOp variableOp = op.getVariableOp(getSymbolTableCollection()); + mlir::Value componentValue = rewriter.create( + yieldOp.getLoc(), + component.getVariableType().withShape(shape).unwrap(), + yieldOp.getValues()[0], component.getSymName()); - auto recordType = variableOp.getVariableType() - .getElementType().cast(); + rewriter.replaceOpWithNewOp(yieldOp, componentValue); + rewriter.setInsertionPointAfter(clonedOp); + } - auto recordOp = getRecordOp(recordType); + rewriter.eraseOp(op); + } - for (VariableOp component : recordOp.getVariables()) { - llvm::SmallVector shape; + void unpackBindingEquationOp(mlir::PatternRewriter &rewriter, + BindingEquationOp op) const { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); - mergeShapes( - shape, - variableOp.getVariableType().getShape(), - component.getVariableType().getShape()); + VariableOp variableOp = op.getVariableOp(getSymbolTableCollection()); - auto clonedOp = mlir::cast( - rewriter.clone(*op.getOperation())); + auto recordType = + variableOp.getVariableType().getElementType().cast(); - clonedOp.setVariable( - 
getComposedComponentName(variableOp, component)); + auto recordOp = getRecordOp(recordType); - auto yieldOp = mlir::cast( - clonedOp.getBodyRegion().back().getTerminator()); + for (VariableOp component : recordOp.getVariables()) { + llvm::SmallVector shape; - rewriter.setInsertionPointAfter(yieldOp); + mergeShapes(shape, variableOp.getVariableType().getShape(), + component.getVariableType().getShape()); - mlir::Value componentValue = rewriter.create( - yieldOp.getLoc(), - component.getVariableType().withShape(shape).unwrap(), - yieldOp.getValues()[0], - component.getSymName()); + auto clonedOp = + mlir::cast(rewriter.clone(*op.getOperation())); - rewriter.replaceOpWithNewOp(yieldOp, componentValue); - rewriter.setInsertionPointAfter(clonedOp); - } + clonedOp.setVariable(getComposedComponentName(variableOp, component)); + + auto yieldOp = + mlir::cast(clonedOp.getBodyRegion().back().getTerminator()); - rewriter.eraseOp(op); + rewriter.setInsertionPointAfter(yieldOp); + + mlir::Value componentValue = rewriter.create( + yieldOp.getLoc(), + component.getVariableType().withShape(shape).unwrap(), + yieldOp.getValues()[0], component.getSymName()); + + rewriter.replaceOpWithNewOp(yieldOp, componentValue); + rewriter.setInsertionPointAfter(clonedOp); + } + + rewriter.eraseOp(op); + } + + mlir::LogicalResult + replaceVariableGetOp(mlir::PatternRewriter &rewriter, VariableGetOp getOp, + const llvm::StringMap &componentsMap) const { + llvm::SmallVector subscriptions; + + auto componentGetter = [&](mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef componentName) -> mlir::Value { + auto componentIt = componentsMap.find(componentName); + + if (componentIt == componentsMap.end()) { + return nullptr; } - mlir::LogicalResult replaceVariableGetOp( - mlir::PatternRewriter& rewriter, - VariableGetOp getOp, - const llvm::StringMap& componentsMap) const - { - llvm::SmallVector subscriptions; + return builder.create(loc, componentIt->getValue()); + }; - auto componentGetter = - [&](mlir::OpBuilder& builder, - mlir::Location loc, - llvm::StringRef componentName) -> mlir::Value { - auto componentIt = componentsMap.find(componentName); + for (mlir::Operation *user : + llvm::make_early_inc_range(getOp->getUsers())) { + if (mlir::failed(replaceRecordGetters(rewriter, componentGetter, + subscriptions, getOp.getResult(), + user))) { + return mlir::failure(); + } + } - if (componentIt == componentsMap.end()) { - return nullptr; - } + rewriter.eraseOp(getOp); + return mlir::success(); + } - return builder.create(loc, componentIt->getValue()); - }; + mlir::LogicalResult replaceVariableComponentSetOp( + mlir::PatternRewriter &rewriter, VariableOp variableOp, + VariableComponentSetOp setOp, + const llvm::StringMap &componentsMap) const { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(setOp); - for (mlir::Operation* user : - llvm::make_early_inc_range(getOp->getUsers())) { - if (mlir::failed(replaceRecordGetters( - rewriter, componentGetter, subscriptions, - getOp.getResult(), user))) { - return mlir::failure(); - } - } + int64_t rootVariableRank = variableOp.getVariableType().getRank(); + size_t pathLength = setOp.getPath().size(); - rewriter.eraseOp(getOp); - return mlir::success(); + if (pathLength > 2) { + std::string composedName = getComposedComponentName( + setOp.getPath()[0].cast().getValue(), + setOp.getPath()[1].cast().getValue()); + + llvm::SmallVector destination; + + destination.push_back( + mlir::FlatSymbolRefAttr::get(rewriter.getContext(), composedName)); + + for 
(size_t i = 2; i < pathLength; ++i) { + destination.push_back(setOp.getPath()[i]); } - mlir::LogicalResult replaceVariableComponentSetOp( - mlir::PatternRewriter& rewriter, - VariableOp variableOp, - VariableComponentSetOp setOp, - const llvm::StringMap& componentsMap) const - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(setOp); + llvm::SmallVector subscripts; + llvm::SmallVector subscriptsAmounts; + + subscriptsAmounts.push_back(rootVariableRank); - int64_t rootVariableRank = variableOp.getVariableType().getRank(); - size_t pathLength = setOp.getPath().size(); + getFullRankSubscripts(rewriter, setOp.getLoc(), rootVariableRank, + setOp.getComponentSubscripts(0), subscripts); - if (pathLength > 2) { - std::string composedName = getComposedComponentName( - setOp.getPath()[0].cast().getValue(), - setOp.getPath()[1].cast().getValue()); + for (size_t component = 1; component < pathLength; ++component) { + auto componentSubscripts = setOp.getComponentSubscripts(component); - llvm::SmallVector destination; + subscripts.append(componentSubscripts.begin(), + componentSubscripts.end()); + } - destination.push_back(mlir::FlatSymbolRefAttr::get( - rewriter.getContext(), composedName)); + for (mlir::IntegerAttr subscriptsAmount : + setOp.getSubscriptionsAmounts().getAsRange()) { + subscriptsAmounts.push_back(subscriptsAmount.getInt()); + } - for (size_t i = 2; i < pathLength; ++i) { - destination.push_back(setOp.getPath()[i]); - } + rewriter.create( + setOp.getLoc(), rewriter.getArrayAttr(destination), subscripts, + rewriter.getI64ArrayAttr(subscriptsAmounts), setOp.getValue()); + } else { + auto componentName = setOp.getPath()[1].cast(); - llvm::SmallVector subscripts; - llvm::SmallVector subscriptsAmounts; + if (!componentsMap.contains(componentName.getValue())) { + return mlir::failure(); + } + + VariableOp componentVariableOp = + componentsMap.lookup(componentName.getValue()); + + auto subscriptions = setOp.getSubscriptions(); - subscriptsAmounts.push_back(rootVariableRank); + if (setOp.getValue().getType().isa()) { + if (subscriptions.empty()) { + rewriter.create(setOp.getLoc(), componentVariableOp, + setOp.getValue()); + } else { + mlir::Value previousValue = rewriter.create( + setOp.getLoc(), componentVariableOp); + + llvm::SmallVector subscripts; - getFullRankSubscripts( - rewriter, setOp.getLoc(), - rootVariableRank, setOp.getComponentSubscripts(0), - subscripts); + getFullRankSubscripts(rewriter, setOp.getLoc(), rootVariableRank, + setOp.getComponentSubscripts(0), subscripts); for (size_t component = 1; component < pathLength; ++component) { auto componentSubscripts = setOp.getComponentSubscripts(component); @@ -821,503 +778,424 @@ namespace componentSubscripts.end()); } - for (mlir::IntegerAttr subscriptsAmount : - setOp.getSubscriptionsAmounts().getAsRange()) { - subscriptsAmounts.push_back(subscriptsAmount.getInt()); - } + mlir::Value newValue = rewriter.create( + setOp.getLoc(), setOp.getValue(), previousValue, subscripts); - rewriter.create( - setOp.getLoc(), - rewriter.getArrayAttr(destination), - subscripts, - rewriter.getI64ArrayAttr(subscriptsAmounts), - setOp.getValue()); + rewriter.create(setOp.getLoc(), componentVariableOp, + newValue); + } + } else { + if (subscriptions.empty()) { + rewriter.create(setOp.getLoc(), componentVariableOp, + setOp.getValue()); } else { - auto componentName = - setOp.getPath()[1].cast(); + mlir::Value previousValue = rewriter.create( + setOp.getLoc(), componentVariableOp); - if 
(!componentsMap.contains(componentName.getValue())) { - return mlir::failure(); - } + mlir::Value newValue = rewriter.create( + setOp.getLoc(), setOp.getValue(), previousValue, subscriptions); - VariableOp componentVariableOp = - componentsMap.lookup(componentName.getValue()); - - auto subscriptions = setOp.getSubscriptions(); - - if (setOp.getValue().getType().isa()) { - if (subscriptions.empty()) { - rewriter.create( - setOp.getLoc(), componentVariableOp, setOp.getValue()); - } else { - mlir::Value previousValue = rewriter.create( - setOp.getLoc(), componentVariableOp); - - llvm::SmallVector subscripts; - - getFullRankSubscripts( - rewriter, setOp.getLoc(), - rootVariableRank, setOp.getComponentSubscripts(0), - subscripts); - - for (size_t component = 1; component < pathLength; ++component) { - auto componentSubscripts = setOp.getComponentSubscripts(component); - - subscripts.append(componentSubscripts.begin(), - componentSubscripts.end()); - } - - mlir::Value newValue = rewriter.create( - setOp.getLoc(), setOp.getValue(), previousValue, subscripts); - - rewriter.create( - setOp.getLoc(), componentVariableOp, newValue); - } - } else { - if (subscriptions.empty()) { - rewriter.create( - setOp.getLoc(), componentVariableOp, setOp.getValue()); - } else { - mlir::Value previousValue = rewriter.create( - setOp.getLoc(), componentVariableOp); - - mlir::Value newValue = rewriter.create( - setOp.getLoc(), setOp.getValue(), previousValue, - subscriptions); - - rewriter.create( - setOp.getLoc(), componentVariableOp, newValue); - } - } + rewriter.create(setOp.getLoc(), componentVariableOp, + newValue); } - - rewriter.eraseOp(setOp); - return mlir::success(); } + } + + rewriter.eraseOp(setOp); + return mlir::success(); + } + + void getFullRankSubscripts(mlir::OpBuilder &builder, mlir::Location loc, + int64_t rank, mlir::ValueRange givenSubscripts, + llvm::SmallVectorImpl &result) const { + size_t numOfGivenSubscripts = givenSubscripts.size(); + result.append(givenSubscripts.begin(), givenSubscripts.end()); + + int64_t numOfAdditionalSubscripts = + rank - static_cast(numOfGivenSubscripts); - void getFullRankSubscripts( - mlir::OpBuilder& builder, - mlir::Location loc, - int64_t rank, - mlir::ValueRange givenSubscripts, - llvm::SmallVectorImpl& result) const - { - size_t numOfGivenSubscripts = givenSubscripts.size(); - result.append(givenSubscripts.begin(), givenSubscripts.end()); - - int64_t numOfAdditionalSubscripts = - rank - static_cast(numOfGivenSubscripts); - - for (int64_t i = 0; i < numOfAdditionalSubscripts; ++i) { - result.push_back(builder.create(loc)); + for (int64_t i = 0; i < numOfAdditionalSubscripts; ++i) { + result.push_back(builder.create(loc)); + } + } +}; + +class CallResultUnpackPattern : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; + + mlir::LogicalResult + matchAndRewrite(CallOp op, mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector newResultTypes; + llvm::DenseMap> components; + + for (auto result : llvm::enumerate(op.getResults())) { + mlir::Type resultType = result.value().getType(); + + llvm::SmallVector unpackedNames; + llvm::SmallVector unpackedTypes; + + if (auto componentsCount = + unpackResultType(resultType, unpackedNames, unpackedTypes); + componentsCount > 0) { + for (size_t i = 0; i < componentsCount; ++i) { + components[result.index()][unpackedNames[i]] = newResultTypes.size(); + newResultTypes.push_back(unpackedTypes[i]); } + } else { + newResultTypes.push_back(resultType); } - }; - - class 
CallResultUnpackPattern : public RecordInliningPattern - { - public: - using RecordInliningPattern::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - CallOp op, mlir::PatternRewriter& rewriter) const override - { - llvm::SmallVector newResultTypes; - llvm::DenseMap> components; - - for (auto result : llvm::enumerate(op.getResults())) { - mlir::Type resultType = result.value().getType(); - - llvm::SmallVector unpackedNames; - llvm::SmallVector unpackedTypes; - - if (auto componentsCount = unpackResultType( - resultType, unpackedNames, unpackedTypes); - componentsCount > 0) { - for (size_t i = 0; i < componentsCount; ++i) { - components[result.index()][unpackedNames[i]] = newResultTypes.size(); - newResultTypes.push_back(unpackedTypes[i]); - } - } else { - newResultTypes.push_back(resultType); - } - } + } - if (components.empty()) { - return mlir::failure(); - } + if (components.empty()) { + return mlir::failure(); + } + + auto newCallOp = + rewriter.create(op.getLoc(), op.getCallee(), newResultTypes, + op.getArgs(), op.getArgNames()); + + size_t newResultsCounter = 0; + + for (auto oldResult : llvm::enumerate(op.getResults())) { + if (isRecordBased(oldResult.value())) { + llvm::SmallVector subscriptions; + + auto componentGetter = + [&](mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef componentName) -> mlir::Value { + return newCallOp.getResult( + components[oldResult.index()][componentName]); + }; - auto newCallOp = rewriter.create( - op.getLoc(), - op.getCallee(), newResultTypes, op.getArgs(), op.getArgNames()); - - size_t newResultsCounter = 0; - - for (auto oldResult : llvm::enumerate(op.getResults())) { - if (isRecordBased(oldResult.value())) { - llvm::SmallVector subscriptions; - - auto componentGetter = - [&](mlir::OpBuilder& builder, - mlir::Location loc, - llvm::StringRef componentName) -> mlir::Value { - return newCallOp.getResult( - components[oldResult.index()][componentName]); - }; - - for (mlir::Operation* user : - llvm::make_early_inc_range(oldResult.value().getUsers())) { - if (mlir::failed(replaceRecordGetters( - rewriter, componentGetter, - subscriptions, oldResult.value(), user))) { - return mlir::failure(); - } - } - } else { - oldResult.value().replaceAllUsesWith( - newCallOp.getResult(newResultsCounter++)); + for (mlir::Operation *user : + llvm::make_early_inc_range(oldResult.value().getUsers())) { + if (mlir::failed(replaceRecordGetters(rewriter, componentGetter, + subscriptions, + oldResult.value(), user))) { + return mlir::failure(); } } - - rewriter.eraseOp(op); - return mlir::success(); + } else { + oldResult.value().replaceAllUsesWith( + newCallOp.getResult(newResultsCounter++)); } + } - private: - size_t unpackResultType( - mlir::Type resultType, - llvm::SmallVectorImpl& unpackedNames, - llvm::SmallVectorImpl& unpackedTypes) const - { - size_t result = 0; - mlir::Type baseType = resultType; - - if (auto tensorType = resultType.dyn_cast()) { - baseType = tensorType.getElementType(); - } + rewriter.eraseOp(op); + return mlir::success(); + } - auto recordType = baseType.dyn_cast(); +private: + size_t + unpackResultType(mlir::Type resultType, + llvm::SmallVectorImpl &unpackedNames, + llvm::SmallVectorImpl &unpackedTypes) const { + size_t result = 0; + mlir::Type baseType = resultType; - if (!recordType) { - return result; - } + if (auto tensorType = resultType.dyn_cast()) { + baseType = tensorType.getElementType(); + } - llvm::SmallVector baseDimensions; + auto recordType = baseType.dyn_cast(); - if (auto tensorType = 
resultType.dyn_cast()) { - auto shape = tensorType.getShape(); - baseDimensions.append(shape.begin(), shape.end()); - } + if (!recordType) { + return result; + } - auto recordOp = getRecordOp(recordType); - llvm::SmallVector dimensions; + llvm::SmallVector baseDimensions; - for (VariableOp component : recordOp.getVariables()) { - unpackedNames.push_back(component.getSymNameAttr()); + if (auto tensorType = resultType.dyn_cast()) { + auto shape = tensorType.getShape(); + baseDimensions.append(shape.begin(), shape.end()); + } - dimensions.clear(); - dimensions.append(baseDimensions); + auto recordOp = getRecordOp(recordType); + llvm::SmallVector dimensions; - auto variableType = component.getVariableType(); - auto shape = variableType.getShape(); - dimensions.append(shape.begin(), shape.end()); + for (VariableOp component : recordOp.getVariables()) { + unpackedNames.push_back(component.getSymNameAttr()); - if (dimensions.empty()) { - unpackedTypes.push_back(variableType.unwrap()); - } else { - unpackedTypes.push_back(mlir::RankedTensorType::get( - dimensions, variableType.unwrap())); - } + dimensions.clear(); + dimensions.append(baseDimensions); - ++result; - } + auto variableType = component.getVariableType(); + auto shape = variableType.getShape(); + dimensions.append(shape.begin(), shape.end()); - return result; + if (dimensions.empty()) { + unpackedTypes.push_back(variableType.unwrap()); + } else { + unpackedTypes.push_back( + mlir::RankedTensorType::get(dimensions, variableType.unwrap())); } - }; - - class RecordCreateOpUnpackPattern - : public RecordInliningPattern - { - public: - using RecordInliningPattern::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - RecordCreateOp op, mlir::PatternRewriter& rewriter) const override - { - auto recordType = op.getResult().getType().cast(); - auto recordOp = getRecordOp(recordType); - llvm::SmallVector componentGetOps; + ++result; + } - for (mlir::Operation* user : - llvm::make_early_inc_range(op->getUsers())) { - if (auto getOp = mlir::dyn_cast(user)) { - componentGetOps.push_back(getOp); - } - } + return result; + } +}; - if (componentGetOps.empty()) { - return mlir::failure(); - } +class RecordCreateOpUnpackPattern + : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; - llvm::StringMap componentsMap; + mlir::LogicalResult + matchAndRewrite(RecordCreateOp op, + mlir::PatternRewriter &rewriter) const override { + auto recordType = op.getResult().getType().cast(); + auto recordOp = getRecordOp(recordType); - for (auto component : llvm::enumerate(recordOp.getVariables())) { - componentsMap[component.value().getSymName()] = - op.getValues()[component.index()]; - } + llvm::SmallVector componentGetOps; - for (ComponentGetOp getOp : componentGetOps) { - rewriter.replaceOp(getOp, componentsMap[getOp.getComponentName()]); - } - - return mlir::success(); + for (mlir::Operation *user : llvm::make_early_inc_range(op->getUsers())) { + if (auto getOp = mlir::dyn_cast(user)) { + componentGetOps.push_back(getOp); } - }; - - class TensorFromElementsUnpackPattern - : public RecordInliningPattern - { - public: - using RecordInliningPattern::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - TensorFromElementsOp op, - mlir::PatternRewriter& rewriter) const override - { - mlir::Type resultType = op.getResult().getType(); - mlir::Type resultBaseType = resultType; - - if (auto tensorType = resultType.dyn_cast()) { - resultBaseType = tensorType.getElementType(); - } + } - auto recordType = 
resultBaseType.dyn_cast(); + if (componentGetOps.empty()) { + return mlir::failure(); + } - if (!recordType) { - return mlir::failure(); - } + llvm::StringMap componentsMap; - auto recordOp = getRecordOp(recordType); + for (auto component : llvm::enumerate(recordOp.getVariables())) { + componentsMap[component.value().getSymName()] = + op.getValues()[component.index()]; + } - llvm::SmallVector componentGetOps; + for (ComponentGetOp getOp : componentGetOps) { + rewriter.replaceOp(getOp, componentsMap[getOp.getComponentName()]); + } - for (mlir::Operation* user : - llvm::make_early_inc_range(op->getUsers())) { - if (auto getOp = mlir::dyn_cast(user)) { - componentGetOps.push_back(getOp); - } - } + return mlir::success(); + } +}; - if (componentGetOps.empty()) { - return mlir::failure(); - } +class TensorFromElementsUnpackPattern + : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; - llvm::StringMap componentsMap; + mlir::LogicalResult + matchAndRewrite(TensorFromElementsOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::Type resultType = op.getResult().getType(); + mlir::Type resultBaseType = resultType; - for (VariableOp component : recordOp.getVariables()) { - llvm::SmallVector componentValues; + if (auto tensorType = resultType.dyn_cast()) { + resultBaseType = tensorType.getElementType(); + } - for (mlir::Value element : op.getValues()) { - llvm::SmallVector shape; - llvm::ArrayRef elementShape = std::nullopt; + auto recordType = resultBaseType.dyn_cast(); - if (auto elementTensorType = - element.getType().dyn_cast()) { - elementShape = elementTensorType.getShape(); - } + if (!recordType) { + return mlir::failure(); + } - mergeShapes( - shape, elementShape, component.getVariableType().getShape()); + auto recordOp = getRecordOp(recordType); - auto componentGetOp = rewriter.create( - op.getLoc(), - component.getVariableType().withShape(shape).unwrap(), - element, - component.getSymName()); + llvm::SmallVector componentGetOps; - componentValues.push_back(componentGetOp); - } + for (mlir::Operation *user : llvm::make_early_inc_range(op->getUsers())) { + if (auto getOp = mlir::dyn_cast(user)) { + componentGetOps.push_back(getOp); + } + } - llvm::SmallVector shape; + if (componentGetOps.empty()) { + return mlir::failure(); + } - mergeShapes( - shape, - op.getTensor().getType().getShape(), - component.getVariableType().getShape()); + llvm::StringMap componentsMap; - auto sliceOp = rewriter.create( - op.getLoc(), - op.getTensor().getType().clone(shape).clone( - component.getVariableType().getElementType()), - componentValues); + for (VariableOp component : recordOp.getVariables()) { + llvm::SmallVector componentValues; - componentsMap[component.getSymName()] = sliceOp; - } + for (mlir::Value element : op.getValues()) { + llvm::SmallVector shape; + llvm::ArrayRef elementShape = std::nullopt; - llvm::SmallVector subscriptions; + if (auto elementTensorType = + element.getType().dyn_cast()) { + elementShape = elementTensorType.getShape(); + } - auto componentGetter = - [&](mlir::OpBuilder& builder, - mlir::Location loc, - llvm::StringRef componentName) -> mlir::Value { - return componentsMap[componentName]; - }; + mergeShapes(shape, elementShape, + component.getVariableType().getShape()); - for (mlir::Operation* user : - llvm::make_early_inc_range(op.getResult().getUsers())) { - if (mlir::failed(replaceRecordGetters( - rewriter, componentGetter, subscriptions, op.getResult(), - user))) { - return mlir::failure(); - } - } + auto 
componentGetOp = rewriter.create( + op.getLoc(), component.getVariableType().withShape(shape).unwrap(), + element, component.getSymName()); - return mlir::success(); + componentValues.push_back(componentGetOp); } - }; - - class TensorBroadcastUnpackPattern - : public RecordInliningPattern - { - public: - using RecordInliningPattern::RecordInliningPattern; - - mlir::LogicalResult matchAndRewrite( - TensorBroadcastOp op, - mlir::PatternRewriter& rewriter) const override - { - mlir::Type resultType = op.getResult().getType(); - mlir::Type resultBaseType = resultType; - - if (auto tensorType = resultType.dyn_cast()) { - resultBaseType = tensorType.getElementType(); - } - auto recordType = resultBaseType.dyn_cast(); + llvm::SmallVector shape; - if (!recordType) { - return mlir::failure(); - } + mergeShapes(shape, op.getTensor().getType().getShape(), + component.getVariableType().getShape()); - auto recordOp = getRecordOp(recordType); + auto sliceOp = rewriter.create( + op.getLoc(), + op.getTensor().getType().clone(shape).clone( + component.getVariableType().getElementType()), + componentValues); - llvm::SmallVector componentGetOps; + componentsMap[component.getSymName()] = sliceOp; + } - for (mlir::Operation* user : - llvm::make_early_inc_range(op->getUsers())) { - if (auto getOp = mlir::dyn_cast(user)) { - componentGetOps.push_back(getOp); - } - } + llvm::SmallVector subscriptions; - if (componentGetOps.empty()) { - return mlir::failure(); - } + auto componentGetter = [&](mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef componentName) -> mlir::Value { + return componentsMap[componentName]; + }; - llvm::StringMap componentsMap; + for (mlir::Operation *user : + llvm::make_early_inc_range(op.getResult().getUsers())) { + if (mlir::failed(replaceRecordGetters(rewriter, componentGetter, + subscriptions, op.getResult(), + user))) { + return mlir::failure(); + } + } - for (VariableOp component : recordOp.getVariables()) { - llvm::SmallVector componentValues; - mlir::Value element = op.getValue(); - llvm::SmallVector getResultShape; - llvm::ArrayRef elementShape = std::nullopt; - - if (auto elementTensorType = - element.getType().dyn_cast()) { - elementShape = elementTensorType.getShape(); - } + return mlir::success(); + } +}; - mergeShapes( - getResultShape, elementShape, component.getVariableType().getShape()); +class TensorBroadcastUnpackPattern + : public RecordInliningPattern { +public: + using RecordInliningPattern::RecordInliningPattern; - auto componentGetOp = rewriter.create( - op.getLoc(), - component.getVariableType().withShape(getResultShape).unwrap(), - element, - component.getSymName()); + mlir::LogicalResult + matchAndRewrite(TensorBroadcastOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::Type resultType = op.getResult().getType(); + mlir::Type resultBaseType = resultType; - componentValues.push_back(componentGetOp); + if (auto tensorType = resultType.dyn_cast()) { + resultBaseType = tensorType.getElementType(); + } - llvm::SmallVector shape; + auto recordType = resultBaseType.dyn_cast(); - mergeShapes( - shape, - op.getTensor().getType().getShape(), - component.getVariableType().getShape()); + if (!recordType) { + return mlir::failure(); + } - auto sliceOp = rewriter.create( - op.getLoc(), - op.getTensor().getType().clone(shape).clone( - component.getVariableType().getElementType()), - componentValues); + auto recordOp = getRecordOp(recordType); - componentsMap[component.getSymName()] = sliceOp; - } + llvm::SmallVector componentGetOps; - 
llvm::SmallVector subscriptions; + for (mlir::Operation *user : llvm::make_early_inc_range(op->getUsers())) { + if (auto getOp = mlir::dyn_cast(user)) { + componentGetOps.push_back(getOp); + } + } - auto componentGetter = - [&](mlir::OpBuilder& builder, - mlir::Location loc, - llvm::StringRef componentName) -> mlir::Value { - return componentsMap[componentName]; - }; + if (componentGetOps.empty()) { + return mlir::failure(); + } - for (mlir::Operation* user : - llvm::make_early_inc_range(op.getResult().getUsers())) { - if (mlir::failed(replaceRecordGetters( - rewriter, componentGetter, subscriptions, op.getResult(), - user))) { - return mlir::failure(); - } - } + llvm::StringMap componentsMap; + + for (VariableOp component : recordOp.getVariables()) { + llvm::SmallVector componentValues; + mlir::Value element = op.getValue(); + llvm::SmallVector getResultShape; + llvm::ArrayRef elementShape = std::nullopt; - return mlir::success(); + if (auto elementTensorType = + element.getType().dyn_cast()) { + elementShape = elementTensorType.getShape(); } - }; - - class RecordCreateOpFoldPattern - : public mlir::OpRewritePattern - { - public: - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - RecordCreateOp op, mlir::PatternRewriter& rewriter) const override - { - if (op->use_empty()) { - rewriter.eraseOp(op); - return mlir::success(); - } + mergeShapes(getResultShape, elementShape, + component.getVariableType().getShape()); + + auto componentGetOp = rewriter.create( + op.getLoc(), + component.getVariableType().withShape(getResultShape).unwrap(), + element, component.getSymName()); + + componentValues.push_back(componentGetOp); + + llvm::SmallVector shape; + + mergeShapes(shape, op.getTensor().getType().getShape(), + component.getVariableType().getShape()); + + auto sliceOp = rewriter.create( + op.getLoc(), + op.getTensor().getType().clone(shape).clone( + component.getVariableType().getElementType()), + componentValues); + + componentsMap[component.getSymName()] = sliceOp; + } + + llvm::SmallVector subscriptions; + + auto componentGetter = [&](mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef componentName) -> mlir::Value { + return componentsMap[componentName]; + }; + + for (mlir::Operation *user : + llvm::make_early_inc_range(op.getResult().getUsers())) { + if (mlir::failed(replaceRecordGetters(rewriter, componentGetter, + subscriptions, op.getResult(), + user))) { return mlir::failure(); } - }; -} + } + + return mlir::success(); + } +}; + +class RecordCreateOpFoldPattern + : public mlir::OpRewritePattern { +public: + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(RecordCreateOp op, + mlir::PatternRewriter &rewriter) const override { + if (op->use_empty()) { + rewriter.eraseOp(op); + return mlir::success(); + } + + return mlir::failure(); + } +}; +} // namespace -namespace -{ - class RecordInliningPass - : public mlir::bmodelica::impl::RecordInliningPassBase< - RecordInliningPass> - { - public: - using RecordInliningPassBase::RecordInliningPassBase; +namespace { +class RecordInliningPass + : public mlir::bmodelica::impl::RecordInliningPassBase { +public: + using RecordInliningPassBase::RecordInliningPassBase; - void runOnOperation() override; + void runOnOperation() override; - mlir::LogicalResult explicitateAccesses(); + mlir::LogicalResult explicitateAccesses(); - mlir::LogicalResult unpackRecordVariables(); + mlir::LogicalResult unpackRecordVariables(); - mlir::LogicalResult foldRecordCreateOps(); - 
}; -} + mlir::LogicalResult foldRecordCreateOps(); +}; +} // namespace -void RecordInliningPass::runOnOperation() -{ +void RecordInliningPass::runOnOperation() { if (mlir::failed(explicitateAccesses())) { return signalPassFailure(); } @@ -1331,17 +1209,15 @@ void RecordInliningPass::runOnOperation() } } -mlir::LogicalResult RecordInliningPass::explicitateAccesses() -{ +mlir::LogicalResult RecordInliningPass::explicitateAccesses() { mlir::ModuleOp moduleOp = getOperation(); mlir::SymbolTableCollection symbolTable; mlir::RewritePatternSet patterns(&getContext()); - patterns.add< - VariableSetOpUnpackPattern, - VariableComponentSetOpUnpackPattern, - EquationSideOpUnpackPattern>(&getContext(), moduleOp, symbolTable); + patterns.add<VariableSetOpUnpackPattern, VariableComponentSetOpUnpackPattern, + EquationSideOpUnpackPattern>(&getContext(), moduleOp, + symbolTable); mlir::GreedyRewriteConfig config; config.useTopDownTraversal = true; @@ -1350,16 +1226,14 @@ mlir::LogicalResult RecordInliningPass::explicitateAccesses() return applyPatternsAndFoldGreedily(moduleOp, std::move(patterns), config); } -mlir::LogicalResult RecordInliningPass::unpackRecordVariables() -{ +mlir::LogicalResult RecordInliningPass::unpackRecordVariables() { mlir::ModuleOp moduleOp = getOperation(); mlir::SymbolTableCollection symbolTable; mlir::RewritePatternSet patterns(&getContext()); - patterns.add< - VariableOpUnpackPattern, - CallResultUnpackPattern>(&getContext(), moduleOp, symbolTable); + patterns.add<VariableOpUnpackPattern, CallResultUnpackPattern>( + &getContext(), moduleOp, symbolTable); mlir::GreedyRewriteConfig config; config.useTopDownTraversal = true; @@ -1368,17 +1242,15 @@ mlir::LogicalResult RecordInliningPass::unpackRecordVariables() return applyPatternsAndFoldGreedily(moduleOp, std::move(patterns), config); } -mlir::LogicalResult RecordInliningPass::foldRecordCreateOps() -{ +mlir::LogicalResult RecordInliningPass::foldRecordCreateOps() { mlir::ModuleOp moduleOp = getOperation(); mlir::SymbolTableCollection symbolTable; mlir::RewritePatternSet patterns(&getContext()); - patterns.add< - RecordCreateOpUnpackPattern, - TensorFromElementsUnpackPattern, - TensorBroadcastUnpackPattern>(&getContext(), moduleOp, symbolTable); + patterns.add<RecordCreateOpUnpackPattern, TensorFromElementsUnpackPattern, + TensorBroadcastUnpackPattern>(&getContext(), moduleOp, + symbolTable); patterns.add<RecordCreateOpFoldPattern>(&getContext()); @@ -1388,10 +1260,8 @@ return applyPatternsAndFoldGreedily(moduleOp, std::move(patterns), config); } -namespace mlir::bmodelica -{ - std::unique_ptr<mlir::Pass> createRecordInliningPass() - { - return std::make_unique<RecordInliningPass>(); - } +namespace mlir::bmodelica { +std::unique_ptr<mlir::Pass> createRecordInliningPass() { + return std::make_unique<RecordInliningPass>(); } +} // namespace mlir::bmodelica diff --git a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp index 9cb2f2d39..5788696c2 100644 --- a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp +++ b/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp @@ -30,7 +30,8 @@ struct CyclicEquation { using Cycle = llvm::SmallVector<CyclicEquation>; -static void printCycle(llvm::raw_ostream &os, const Cycle &cycle) { +[[maybe_unused]] static void printCycle(llvm::raw_ostream &os, + const Cycle &cycle) { for (const CyclicEquation &cyclicEquation : cycle) { os << cyclicEquation.writeAccess.getVariable() << " -> "; } diff --git a/lib/Dialect/BaseModelica/Transforms/SCCSolvingWithKINSOL.cpp index f56d3c966..6786ee1fc 100644 --- a/lib/Dialect/BaseModelica/Transforms/SCCSolvingWithKINSOL.cpp +++
b/lib/Dialect/BaseModelica/Transforms/SCCSolvingWithKINSOL.cpp @@ -591,9 +591,6 @@ mlir::LogicalResult KINSOLInstance::addEquationsToKINSOL( return mlir::failure(); } - auto writtenVar = symbolTableCollection->lookupSymbolIn( - modelOp, writeAccess->getVariable()); - // Collect the independent variables for automatic differentiation. llvm::DenseSet<VariableOp> independentVariables; diff --git a/lib/Frontend/CompilerInvocation.cpp index 47270616e..0c3d2fad9 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -35,19 +35,17 @@ using ArgumentConsumer = CompilerInvocation::ArgumentConsumer; #include "clang/Driver/Options.inc" #undef SIMPLE_ENUM_VALUE_TABLE -static std::optional<bool> normalizeSimpleFlag(OptSpecifier Opt, - unsigned TableIndex, - const ArgList &Args, - DiagnosticsEngine &Diags) { +[[maybe_unused]] static std::optional<bool> +normalizeSimpleFlag(OptSpecifier Opt, unsigned TableIndex, const ArgList &Args, + DiagnosticsEngine &Diags) { if (Args.hasArg(Opt)) return true; return std::nullopt; } -static std::optional<bool> normalizeSimpleNegativeFlag(OptSpecifier Opt, - unsigned, - const ArgList &Args, - DiagnosticsEngine &) { +[[maybe_unused]] static std::optional<bool> +normalizeSimpleNegativeFlag(OptSpecifier Opt, unsigned, const ArgList &Args, + DiagnosticsEngine &) { if (Args.hasArg(Opt)) return false; return std::nullopt; @@ -57,9 +55,10 @@ static std::optional<bool> normalizeSimpleNegativeFlag(OptSpecifier Opt, /// denormalizeSimpleFlags never looks at it. Avoid bloating compile-time with /// unnecessary template instantiations and just ignore it with a variadic /// argument. -static void denormalizeSimpleFlag(ArgumentConsumer Consumer, - const Twine &Spelling, Option::OptionClass, - unsigned, /*T*/...) { +[[maybe_unused]] static void denormalizeSimpleFlag(ArgumentConsumer Consumer, + const Twine &Spelling, + Option::OptionClass, + unsigned, /*T*/...)
{ Consumer(Spelling); } @@ -97,7 +96,7 @@ static auto makeBooleanOptionNormalizer(bool Value, bool OtherValue, }; } -static auto makeBooleanOptionDenormalizer(bool Value) { +[[maybe_unused]] static auto makeBooleanOptionDenormalizer(bool Value) { return [Value](ArgumentConsumer Consumer, const Twine &Spelling, Option::OptionClass, unsigned, bool KeyPath) { if (KeyPath == Value) @@ -227,11 +226,10 @@ normalizeStringVector(OptSpecifier Opt, int, const ArgList &Args, return Args.getAllArgValues(Opt); } -static void denormalizeStringVector(ArgumentConsumer Consumer, - const Twine &Spelling, - Option::OptionClass OptClass, - unsigned TableIndex, - const std::vector<std::string> &Values) { +[[maybe_unused]] static void +denormalizeStringVector(ArgumentConsumer Consumer, const Twine &Spelling, + Option::OptionClass OptClass, unsigned TableIndex, + const std::vector<std::string> &Values) { switch (OptClass) { case Option::CommaJoinedClass: { std::string CommaJoinedValue; diff --git a/lib/Modeling/AccessFunctionConstant.cpp index 2e66be7b4..c35cac419 100644 --- a/lib/Modeling/AccessFunctionConstant.cpp +++ b/lib/Modeling/AccessFunctionConstant.cpp @@ -22,7 +22,7 @@ bool AccessFunctionConstant::canBeBuilt( bool AccessFunctionConstant::canBeBuilt(mlir::AffineMap affineMap) { return llvm::all_of(affineMap.getResults(), [](mlir::AffineExpr expression) { - return expression.isa<mlir::AffineConstantExpr>(); + return mlir::isa<mlir::AffineConstantExpr>(expression); }); } diff --git a/lib/Modeling/DimensionAccess.cpp index aa5a2b5c5..22b1cc617 100644 --- a/lib/Modeling/DimensionAccess.cpp +++ b/lib/Modeling/DimensionAccess.cpp @@ -9,226 +9,195 @@ using namespace ::marco::modeling; -namespace marco::modeling -{ - DimensionAccess::Redirect::Redirect() - { - } +namespace marco::modeling { +DimensionAccess::Redirect::Redirect() {} - DimensionAccess::Redirect::Redirect( - std::unique_ptr<DimensionAccess> dimensionAccess) - : dimensionAccess(std::move(dimensionAccess)) - { - } +DimensionAccess::Redirect::Redirect( + std::unique_ptr<DimensionAccess> dimensionAccess) + : dimensionAccess(std::move(dimensionAccess)) {} - DimensionAccess::Redirect::Redirect(const Redirect& other) - : dimensionAccess(other.dimensionAccess->clone()) - { - } +DimensionAccess::Redirect::Redirect(const Redirect &other) + : dimensionAccess(other.dimensionAccess->clone()) {} - DimensionAccess::Redirect::Redirect( - DimensionAccess::Redirect&& other) = default; +DimensionAccess::Redirect::Redirect(DimensionAccess::Redirect &&other) = + default; - DimensionAccess::Redirect::~Redirect() = default; +DimensionAccess::Redirect::~Redirect() = default; - DimensionAccess::Redirect& DimensionAccess::Redirect::operator=( - const DimensionAccess::Redirect& other) - { - DimensionAccess::Redirect result(other); - swap(*this, result); - return *this; - } +DimensionAccess::Redirect & +DimensionAccess::Redirect::operator=(const DimensionAccess::Redirect &other) { + DimensionAccess::Redirect result(other); + swap(*this, result); + return *this; +} - DimensionAccess::Redirect& DimensionAccess::Redirect::operator=( - DimensionAccess::Redirect&& other) = default; +DimensionAccess::Redirect &DimensionAccess::Redirect::operator=( + DimensionAccess::Redirect &&other) = default; - void swap( - DimensionAccess::Redirect& first, - DimensionAccess::Redirect& second) - { - using std::swap; - swap(first.dimensionAccess, second.dimensionAccess); - } +void swap(DimensionAccess::Redirect &first, DimensionAccess::Redirect &second) { + using std::swap; + swap(first.dimensionAccess, second.dimensionAccess); +}
- bool DimensionAccess::Redirect::operator==(const Redirect& other) const - { - return *dimensionAccess == *other.dimensionAccess; - } +bool DimensionAccess::Redirect::operator==(const Redirect &other) const { + return *dimensionAccess == *other.dimensionAccess; +} - bool DimensionAccess::Redirect::operator!=(const Redirect& other) const - { - return !(*this == other); - } +bool DimensionAccess::Redirect::operator!=(const Redirect &other) const { + return !(*this == other); +} + +const DimensionAccess &DimensionAccess::Redirect::operator*() const { + assert(dimensionAccess && "Dimension access not set"); + return *dimensionAccess; +} + +const DimensionAccess *DimensionAccess::Redirect::operator->() const { + assert(dimensionAccess && "Dimension access not set"); + return dimensionAccess.get(); +} - const DimensionAccess& DimensionAccess::Redirect::operator*() const - { - assert(dimensionAccess && "Dimension access not set"); - return *dimensionAccess; +std::unique_ptr<DimensionAccess> +DimensionAccess::build(mlir::AffineExpr expression) { + if (auto constantExpr = + mlir::dyn_cast<mlir::AffineConstantExpr>(expression)) { + return std::make_unique<DimensionAccessConstant>(constantExpr.getContext(), + constantExpr.getValue()); } - const DimensionAccess* DimensionAccess::Redirect::operator->() const - { - assert(dimensionAccess && "Dimension access not set"); - return dimensionAccess.get(); + if (auto dimExpr = mlir::dyn_cast<mlir::AffineDimExpr>(expression)) { + return std::make_unique<DimensionAccessDimension>(dimExpr.getContext(), + dimExpr.getPosition()); } - std::unique_ptr<DimensionAccess> DimensionAccess::build( - mlir::AffineExpr expression) - { - if (auto constantExpr = expression.dyn_cast<mlir::AffineConstantExpr>()) { - return std::make_unique<DimensionAccessConstant>( - constantExpr.getContext(), constantExpr.getValue()); + if (auto binaryExpr = mlir::dyn_cast<mlir::AffineBinaryOpExpr>(expression)) { + auto kind = binaryExpr.getKind(); + + if (kind == mlir::AffineExprKind::Add) { + return std::make_unique<DimensionAccessAdd>( + binaryExpr.getContext(), DimensionAccess::build(binaryExpr.getLHS()), + DimensionAccess::build(binaryExpr.getRHS())); } - if (auto dimExpr = expression.dyn_cast<mlir::AffineDimExpr>()) { - return std::make_unique<DimensionAccessDimension>( - dimExpr.getContext(), dimExpr.getPosition()); + if (kind == mlir::AffineExprKind::Mul) { + return std::make_unique<DimensionAccessMul>( + binaryExpr.getContext(), DimensionAccess::build(binaryExpr.getLHS()), + DimensionAccess::build(binaryExpr.getRHS())); } - if (auto binaryExpr = expression.dyn_cast<mlir::AffineBinaryOpExpr>()) { - auto kind = binaryExpr.getKind(); - - if (kind == mlir::AffineExprKind::Add) { - return std::make_unique<DimensionAccessAdd>( - binaryExpr.getContext(), - DimensionAccess::build(binaryExpr.getLHS()), - DimensionAccess::build(binaryExpr.getRHS())); - } - - if (kind == mlir::AffineExprKind::Mul) { - return std::make_unique<DimensionAccessMul>( - binaryExpr.getContext(), - DimensionAccess::build(binaryExpr.getLHS()), - DimensionAccess::build(binaryExpr.getRHS())); - } - - if (kind == mlir::AffineExprKind::FloorDiv) { - return std::make_unique<DimensionAccessDiv>( - binaryExpr.getContext(), - DimensionAccess::build(binaryExpr.getLHS()), - DimensionAccess::build(binaryExpr.getRHS())); - } + if (kind == mlir::AffineExprKind::FloorDiv) { + return std::make_unique<DimensionAccessDiv>( + binaryExpr.getContext(), DimensionAccess::build(binaryExpr.getLHS()), + DimensionAccess::build(binaryExpr.getRHS())); } + } - llvm_unreachable("Unexpected expression type"); - return nullptr; + llvm_unreachable("Unexpected expression type"); + return nullptr; +} + +std::unique_ptr<DimensionAccess> +DimensionAccess::getDimensionAccessFromExtendedMap( + mlir::AffineExpr expression, + const DimensionAccess::FakeDimensionsMap &fakeDimensionsMap) { + if (auto constantExpr = + mlir::dyn_cast<mlir::AffineConstantExpr>(expression)) { + return
std::make_unique<DimensionAccessConstant>(constantExpr.getContext(), + constantExpr.getValue()); } - std::unique_ptr<DimensionAccess> - DimensionAccess::getDimensionAccessFromExtendedMap( - mlir::AffineExpr expression, - const DimensionAccess::FakeDimensionsMap& fakeDimensionsMap) - { - if (auto constantExpr = expression.dyn_cast<mlir::AffineConstantExpr>()) { - return std::make_unique<DimensionAccessConstant>( - constantExpr.getContext(), constantExpr.getValue()); + if (auto dimExpr = mlir::dyn_cast<mlir::AffineDimExpr>(expression)) { + auto fakeDimReplacementIt = fakeDimensionsMap.find(dimExpr.getPosition()); + + if (fakeDimReplacementIt != fakeDimensionsMap.end()) { + return fakeDimReplacementIt->getSecond()->clone(); } - if (auto dimExpr = expression.dyn_cast<mlir::AffineDimExpr>()) { - auto fakeDimReplacementIt = fakeDimensionsMap.find(dimExpr.getPosition()); + return std::make_unique<DimensionAccessDimension>(dimExpr.getContext(), + dimExpr.getPosition()); + } - if (fakeDimReplacementIt != fakeDimensionsMap.end()) { - return fakeDimReplacementIt->getSecond()->clone(); - } + if (auto binaryExpr = mlir::dyn_cast<mlir::AffineBinaryOpExpr>(expression)) { + auto kind = binaryExpr.getKind(); - return std::make_unique<DimensionAccessDimension>( - dimExpr.getContext(), dimExpr.getPosition()); + if (kind == mlir::AffineExprKind::Add) { + return std::make_unique<DimensionAccessAdd>( + binaryExpr.getContext(), + DimensionAccess::getDimensionAccessFromExtendedMap( + binaryExpr.getLHS(), fakeDimensionsMap), + DimensionAccess::getDimensionAccessFromExtendedMap( + binaryExpr.getRHS(), fakeDimensionsMap)); } - if (auto binaryExpr = expression.dyn_cast<mlir::AffineBinaryOpExpr>()) { - auto kind = binaryExpr.getKind(); - - if (kind == mlir::AffineExprKind::Add) { - return std::make_unique<DimensionAccessAdd>( - binaryExpr.getContext(), - DimensionAccess::getDimensionAccessFromExtendedMap( - binaryExpr.getLHS(), fakeDimensionsMap), - DimensionAccess::getDimensionAccessFromExtendedMap( - binaryExpr.getRHS(), fakeDimensionsMap)); - } - - if (kind == mlir::AffineExprKind::Mul) { - return std::make_unique<DimensionAccessMul>( - binaryExpr.getContext(), - DimensionAccess::getDimensionAccessFromExtendedMap( - binaryExpr.getLHS(), fakeDimensionsMap), - DimensionAccess::getDimensionAccessFromExtendedMap( - binaryExpr.getRHS(), fakeDimensionsMap)); - } - - if (kind == mlir::AffineExprKind::FloorDiv) { - return std::make_unique<DimensionAccessDiv>( - binaryExpr.getContext(), - DimensionAccess::getDimensionAccessFromExtendedMap( - binaryExpr.getLHS(), fakeDimensionsMap), - DimensionAccess::getDimensionAccessFromExtendedMap( - binaryExpr.getRHS(), fakeDimensionsMap)); - } + if (kind == mlir::AffineExprKind::Mul) { + return std::make_unique<DimensionAccessMul>( + binaryExpr.getContext(), + DimensionAccess::getDimensionAccessFromExtendedMap( + binaryExpr.getLHS(), fakeDimensionsMap), + DimensionAccess::getDimensionAccessFromExtendedMap( + binaryExpr.getRHS(), fakeDimensionsMap)); } - llvm_unreachable("Unexpected expression type"); - return nullptr; + if (kind == mlir::AffineExprKind::FloorDiv) { + return std::make_unique<DimensionAccessDiv>( + binaryExpr.getContext(), + DimensionAccess::getDimensionAccessFromExtendedMap( + binaryExpr.getLHS(), fakeDimensionsMap), + DimensionAccess::getDimensionAccessFromExtendedMap( + binaryExpr.getRHS(), fakeDimensionsMap)); + } } - DimensionAccess::DimensionAccess(Kind kind, mlir::MLIRContext* context) - : kind(kind), - context(context) - { - } + llvm_unreachable("Unexpected expression type"); + return nullptr; +} - DimensionAccess::DimensionAccess(const DimensionAccess& other) = default; +DimensionAccess::DimensionAccess(Kind kind, mlir::MLIRContext *context) + : kind(kind), context(context) {} - DimensionAccess::~DimensionAccess() = default; +DimensionAccess::DimensionAccess(const DimensionAccess &other) = default; - void
swap(DimensionAccess& first, DimensionAccess& second) - { - using std::swap; - swap(first.kind, second.kind); - swap(first.context, second.context); - } +DimensionAccess::~DimensionAccess() = default; - std::unique_ptr<DimensionAccess> DimensionAccess::operator+( - const DimensionAccess& other) const - { - return std::make_unique<DimensionAccessAdd>( - getContext(), this->clone(), other.clone()); - } +void swap(DimensionAccess &first, DimensionAccess &second) { + using std::swap; + swap(first.kind, second.kind); + swap(first.context, second.context); +} - std::unique_ptr<DimensionAccess> DimensionAccess::operator-( - const DimensionAccess& other) const - { - return std::make_unique<DimensionAccessSub>( - getContext(), this->clone(), other.clone()); - } +std::unique_ptr<DimensionAccess> +DimensionAccess::operator+(const DimensionAccess &other) const { + return std::make_unique<DimensionAccessAdd>(getContext(), this->clone(), + other.clone()); +} - std::unique_ptr<DimensionAccess> DimensionAccess::operator*( - const DimensionAccess& other) const - { - return std::make_unique<DimensionAccessMul>( - getContext(), this->clone(), other.clone()); - } +std::unique_ptr<DimensionAccess> +DimensionAccess::operator-(const DimensionAccess &other) const { + return std::make_unique<DimensionAccessSub>(getContext(), this->clone(), + other.clone()); +} - std::unique_ptr<DimensionAccess> DimensionAccess::operator/( - const DimensionAccess& other) const - { - return std::make_unique<DimensionAccessDiv>( - getContext(), this->clone(), other.clone()); - } +std::unique_ptr<DimensionAccess> +DimensionAccess::operator*(const DimensionAccess &other) const { + return std::make_unique<DimensionAccessMul>(getContext(), this->clone(), + other.clone()); +} - mlir::MLIRContext* DimensionAccess::getContext() const - { - assert(context && "MLIR context not set"); - return context; - } +std::unique_ptr<DimensionAccess> +DimensionAccess::operator/(const DimensionAccess &other) const { + return std::make_unique<DimensionAccessDiv>(getContext(), this->clone(), + other.clone()); +} - bool DimensionAccess::isAffine() const - { - return false; - } +mlir::MLIRContext *DimensionAccess::getContext() const { + assert(context && "MLIR context not set"); + return context; +} - mlir::AffineExpr DimensionAccess::getAffineExpr() const - { - llvm_unreachable("Not an affine expression"); - return nullptr; - } +bool DimensionAccess::isAffine() const { return false; } + +mlir::AffineExpr DimensionAccess::getAffineExpr() const { + llvm_unreachable("Not an affine expression"); + return nullptr; } +} // namespace marco::modeling diff --git a/lib/Modeling/DimensionAccessConstant.cpp index 6af9232b1..d4e54c7be 100644 --- a/lib/Modeling/DimensionAccessConstant.cpp +++ b/lib/Modeling/DimensionAccessConstant.cpp @@ -3,130 +3,73 @@ using namespace ::marco::modeling; -namespace marco::modeling -{ - DimensionAccessConstant::DimensionAccessConstant( - mlir::MLIRContext* context, int64_t value) - : DimensionAccess(DimensionAccess::Kind::Constant, context), - value(value) - { - } - - DimensionAccessConstant::DimensionAccessConstant( - const DimensionAccessConstant& other) - : DimensionAccess(other), - value(other.value) - { - } - - DimensionAccessConstant::DimensionAccessConstant( - DimensionAccessConstant&& other) noexcept = default; - - DimensionAccessConstant::~DimensionAccessConstant() = default; - - DimensionAccessConstant& DimensionAccessConstant::operator=( - const DimensionAccessConstant& other) - { - DimensionAccessConstant result(other); - swap(*this, result); - return *this; - } - - DimensionAccessConstant& DimensionAccessConstant::operator=( - DimensionAccessConstant&& other) noexcept = default; - - void swap(DimensionAccessConstant& first, DimensionAccessConstant& second) - { - using
std::swap; +namespace marco::modeling { +DimensionAccessConstant::DimensionAccessConstant(mlir::MLIRContext *context, + int64_t value) + : DimensionAccess(DimensionAccess::Kind::Constant, context), value(value) {} - swap(static_cast<DimensionAccess&>(first), - static_cast<DimensionAccess&>(second)); - - swap(first.value, second.value); - } - - std::unique_ptr<DimensionAccess> DimensionAccessConstant::clone() const - { - return std::make_unique<DimensionAccessConstant>(*this); - } - - bool DimensionAccessConstant::operator==(const DimensionAccess& other) const - { - if (auto otherCasted = other.dyn_cast<DimensionAccessConstant>()) { - return *this == *otherCasted; - } +std::unique_ptr<DimensionAccess> DimensionAccessConstant::clone() const { + return std::make_unique<DimensionAccessConstant>(*this); +} - return false; +bool DimensionAccessConstant::operator==(const DimensionAccess &other) const { + if (auto otherCasted = other.dyn_cast<DimensionAccessConstant>()) { + return *this == *otherCasted; } - bool DimensionAccessConstant::operator==( - const DimensionAccessConstant& other) const - { - return getValue() == other.getValue(); - } + return false; +} - bool DimensionAccessConstant::operator!=(const DimensionAccess& other) const - { - if (auto otherCasted = other.dyn_cast<DimensionAccessConstant>()) { - return *this != *otherCasted; - } +bool DimensionAccessConstant::operator==( + const DimensionAccessConstant &other) const { + return getValue() == other.getValue(); +} - return true; +bool DimensionAccessConstant::operator!=(const DimensionAccess &other) const { + if (auto otherCasted = other.dyn_cast<DimensionAccessConstant>()) { + return *this != *otherCasted; } - bool DimensionAccessConstant::operator!=( - const DimensionAccessConstant& other) const - { - return getValue() != other.getValue(); - } + return true; +} - llvm::raw_ostream& DimensionAccessConstant::dump( - llvm::raw_ostream& os, - const llvm::DenseMap< - const IndexSet*, uint64_t>& iterationSpacesIds) const - { - return os << getValue(); - } +bool DimensionAccessConstant::operator!=( + const DimensionAccessConstant &other) const { + return getValue() != other.getValue(); +} - void DimensionAccessConstant::collectIterationSpaces( - llvm::DenseSet<const IndexSet*>& iterationSpaces) const - { - } +llvm::raw_ostream & +DimensionAccessConstant::dump(llvm::raw_ostream &os, + const llvm::DenseMap<const IndexSet *, uint64_t> + &iterationSpacesIds) const { + return os << getValue(); +} - void DimensionAccessConstant::collectIterationSpaces( - llvm::SmallVectorImpl<const IndexSet*>& iterationSpaces, - llvm::DenseMap< - const IndexSet*, - llvm::DenseSet<uint64_t>>& dependentDimensions) const - { - } +void DimensionAccessConstant::collectIterationSpaces( + llvm::DenseSet<const IndexSet *> &iterationSpaces) const {} - bool DimensionAccessConstant::isAffine() const - { - return true; - } +void DimensionAccessConstant::collectIterationSpaces( + llvm::SmallVectorImpl<const IndexSet *> &iterationSpaces, + llvm::DenseMap<const IndexSet *, llvm::DenseSet<uint64_t>> + &dependentDimensions) const {} - mlir::AffineExpr DimensionAccessConstant::getAffineExpr() const - { - return mlir::getAffineConstantExpr(getValue(), getContext()); - } +bool DimensionAccessConstant::isAffine() const { return true; } - mlir::AffineExpr DimensionAccessConstant::getAffineExpr( - unsigned int numOfDimensions, - DimensionAccess::FakeDimensionsMap& fakeDimensionsMap) const - { - return mlir::getAffineConstantExpr(getValue(), getContext()); - } +mlir::AffineExpr DimensionAccessConstant::getAffineExpr() const { + return mlir::getAffineConstantExpr(getValue(), getContext()); +} - IndexSet DimensionAccessConstant::map( - const Point& point, - llvm::DenseMap<const IndexSet*, Point>& currentIndexSetsPoint) const - { - return {Point(getValue())}; - } +mlir::AffineExpr DimensionAccessConstant::getAffineExpr( + unsigned int numOfDimensions, +
DimensionAccess::FakeDimensionsMap &fakeDimensionsMap) const { + return mlir::getAffineConstantExpr(getValue(), getContext()); +} - int64_t DimensionAccessConstant::getValue() const - { - return value; - } +IndexSet DimensionAccessConstant::map( + const Point &point, + llvm::DenseMap<const IndexSet *, Point> &currentIndexSetsPoint) const { + return {Point(getValue())}; } + +int64_t DimensionAccessConstant::getValue() const { return value; } +} // namespace marco::modeling diff --git a/lib/Modeling/DimensionAccessDimension.cpp index b1181f0de..20c940d5a 100644 --- a/lib/Modeling/DimensionAccessDimension.cpp +++ b/lib/Modeling/DimensionAccessDimension.cpp @@ -3,126 +3,74 @@ using namespace ::marco::modeling; -namespace marco::modeling -{ - DimensionAccessDimension::DimensionAccessDimension( - mlir::MLIRContext* context, uint64_t dimension) - : DimensionAccess(DimensionAccess::Kind::Dimension, context), - dimension(dimension) - { - } - - DimensionAccessDimension::DimensionAccessDimension( - const DimensionAccessDimension& other) = default; - - DimensionAccessDimension::DimensionAccessDimension( - DimensionAccessDimension&& other) noexcept = default; - - DimensionAccessDimension::~DimensionAccessDimension() = default; - - DimensionAccessDimension& DimensionAccessDimension::operator=( - const DimensionAccessDimension& other) - { - DimensionAccessDimension result(other); - swap(*this, result); - return *this; - } - - DimensionAccessDimension& DimensionAccessDimension::operator=( - DimensionAccessDimension&& other) noexcept = default; - - void swap(DimensionAccessDimension& first, DimensionAccessDimension& second) - { - using std::swap; - - swap(static_cast<DimensionAccess&>(first), - static_cast<DimensionAccess&>(second)); - - swap(first.dimension, second.dimension); - } - - std::unique_ptr<DimensionAccess> DimensionAccessDimension::clone() const - { - return std::make_unique<DimensionAccessDimension>(*this); - } - - bool DimensionAccessDimension::operator==(const DimensionAccess& other) const - { - if (auto otherCasted = other.dyn_cast<DimensionAccessDimension>()) { - return *this == *otherCasted; - } +namespace marco::modeling { +DimensionAccessDimension::DimensionAccessDimension(mlir::MLIRContext *context, + uint64_t dimension) + : DimensionAccess(DimensionAccess::Kind::Dimension, context), + dimension(dimension) {} + +std::unique_ptr<DimensionAccess> DimensionAccessDimension::clone() const { + return std::make_unique<DimensionAccessDimension>(*this); +} - return false; +bool DimensionAccessDimension::operator==(const DimensionAccess &other) const { + if (auto otherCasted = other.dyn_cast<DimensionAccessDimension>()) { + return *this == *otherCasted; } - bool DimensionAccessDimension::operator==( - const DimensionAccessDimension& other) const - { - return getDimension() == other.getDimension(); - } + return false; +} - bool DimensionAccessDimension::operator!=(const DimensionAccess& other) const - { - if (auto otherCasted = other.dyn_cast<DimensionAccessDimension>()) { - return *this != *otherCasted; - } +bool DimensionAccessDimension::operator==( + const DimensionAccessDimension &other) const { + return getDimension() == other.getDimension(); +} - return true; +bool DimensionAccessDimension::operator!=(const DimensionAccess &other) const { + if (auto otherCasted = other.dyn_cast<DimensionAccessDimension>()) { + return *this != *otherCasted; } - llvm::raw_ostream& DimensionAccessDimension::dump( - llvm::raw_ostream& os, - const llvm::DenseMap< - const IndexSet*, uint64_t>& iterationSpacesIds) const - { - return os << "d" << getDimension(); - } + return true; +} - bool DimensionAccessDimension::operator!=( - const DimensionAccessDimension& other) const - { - return getDimension() !=
other.getDimension(); - } +llvm::raw_ostream & +DimensionAccessDimension::dump(llvm::raw_ostream &os, + const llvm::DenseMap<const IndexSet *, uint64_t> + &iterationSpacesIds) const { + return os << "d" << getDimension(); +} - void DimensionAccessDimension::collectIterationSpaces( - llvm::DenseSet<const IndexSet*>& iterationSpaces) const - { - } +bool DimensionAccessDimension::operator!=( + const DimensionAccessDimension &other) const { + return getDimension() != other.getDimension(); +} - void DimensionAccessDimension::collectIterationSpaces( - llvm::SmallVectorImpl<const IndexSet*>& iterationSpaces, - llvm::DenseMap< - const IndexSet*, - llvm::DenseSet<uint64_t>>& dependentDimensions) const - { - } +void DimensionAccessDimension::collectIterationSpaces( + llvm::DenseSet<const IndexSet *> &iterationSpaces) const {} - bool DimensionAccessDimension::isAffine() const - { - return true; - } +void DimensionAccessDimension::collectIterationSpaces( + llvm::SmallVectorImpl<const IndexSet *> &iterationSpaces, + llvm::DenseMap<const IndexSet *, llvm::DenseSet<uint64_t>> + &dependentDimensions) const {} - mlir::AffineExpr DimensionAccessDimension::getAffineExpr() const - { - return mlir::getAffineDimExpr(getDimension(), getContext()); - } +bool DimensionAccessDimension::isAffine() const { return true; } - mlir::AffineExpr DimensionAccessDimension::getAffineExpr( - unsigned int numOfDimensions, - DimensionAccess::FakeDimensionsMap& fakeDimensionsMap) const - { - return mlir::getAffineDimExpr(getDimension(), getContext()); - } +mlir::AffineExpr DimensionAccessDimension::getAffineExpr() const { + return mlir::getAffineDimExpr(getDimension(), getContext()); +} - IndexSet DimensionAccessDimension::map( - const Point& point, - llvm::DenseMap<const IndexSet*, Point>& currentIndexSetsPoint) const - { - return {Point(point[getDimension()])}; - } +mlir::AffineExpr DimensionAccessDimension::getAffineExpr( + unsigned int numOfDimensions, + DimensionAccess::FakeDimensionsMap &fakeDimensionsMap) const { + return mlir::getAffineDimExpr(getDimension(), getContext()); +} - uint64_t DimensionAccessDimension::getDimension() const - { - return dimension; - } +IndexSet DimensionAccessDimension::map( + const Point &point, + llvm::DenseMap<const IndexSet *, Point> &currentIndexSetsPoint) const { + return {Point(point[getDimension()])}; } + +uint64_t DimensionAccessDimension::getDimension() const { return dimension; } +} // namespace marco::modeling diff --git a/lib/Modeling/DimensionAccessIndices.cpp index 2974e3922..c8f093434 100644 --- a/lib/Modeling/DimensionAccessIndices.cpp +++ b/lib/Modeling/DimensionAccessIndices.cpp @@ -3,161 +3,107 @@ using namespace ::marco::modeling; -namespace marco::modeling -{ - DimensionAccessIndices::DimensionAccessIndices( - mlir::MLIRContext* context, - std::shared_ptr<IndexSet> space, - uint64_t dimension, - llvm::DenseSet<uint64_t> dimensionDependencies) - : DimensionAccess(DimensionAccess::Kind::Indices, context), - space(space), - dimension(dimension), - dimensionDependencies(std::move(dimensionDependencies)) - { - assert(dimension < space->rank()); - } - - DimensionAccessIndices::DimensionAccessIndices( - const DimensionAccessIndices& other) = default; - - DimensionAccessIndices::DimensionAccessIndices( - DimensionAccessIndices&& other) noexcept = default; - - DimensionAccessIndices::~DimensionAccessIndices() = default; - - DimensionAccessIndices& DimensionAccessIndices::operator=( - const DimensionAccessIndices& other) - { - DimensionAccessIndices result(other); - swap(*this, result); - return *this; - } - - DimensionAccessIndices& DimensionAccessIndices::operator=( - DimensionAccessIndices&& other) noexcept = default; - - void swap(DimensionAccessIndices&
first, DimensionAccessIndices& second) - { - using std::swap; - - swap(static_cast<DimensionAccess&>(first), - static_cast<DimensionAccess&>(second)); +namespace marco::modeling { +DimensionAccessIndices::DimensionAccessIndices( + mlir::MLIRContext *context, std::shared_ptr<IndexSet> space, + uint64_t dimension, llvm::DenseSet<uint64_t> dimensionDependencies) + : DimensionAccess(DimensionAccess::Kind::Indices, context), space(space), + dimension(dimension), + dimensionDependencies(std::move(dimensionDependencies)) { + assert(dimension < space->rank()); +} - swap(first.space, second.space); - swap(first.dimension, second.dimension); - swap(first.dimensionDependencies, second.dimensionDependencies); - } +std::unique_ptr<DimensionAccess> DimensionAccessIndices::clone() const { + return std::make_unique<DimensionAccessIndices>(*this); +} - std::unique_ptr<DimensionAccess> DimensionAccessIndices::clone() const - { - return std::make_unique<DimensionAccessIndices>(*this); +bool DimensionAccessIndices::operator==(const DimensionAccess &other) const { + if (auto otherCasted = other.dyn_cast<DimensionAccessIndices>()) { + return *this == *otherCasted; } - bool DimensionAccessIndices::operator==(const DimensionAccess& other) const - { - if (auto otherCasted = other.dyn_cast<DimensionAccessIndices>()) { - return *this == *otherCasted; - } + return false; +} - return false; - } +bool DimensionAccessIndices::operator==( + const DimensionAccessIndices &other) const { + return space == other.space && dimension == other.dimension; +} - bool DimensionAccessIndices::operator==( - const DimensionAccessIndices& other) const - { - return space == other.space && dimension == other.dimension; +bool DimensionAccessIndices::operator!=(const DimensionAccess &other) const { + if (auto otherCasted = other.dyn_cast<DimensionAccessIndices>()) { + return *this != *otherCasted; } - bool DimensionAccessIndices::operator!=(const DimensionAccess& other) const - { - if (auto otherCasted = other.dyn_cast<DimensionAccessIndices>()) { - return *this != *otherCasted; - } - - return true; - } + return true; +} - bool DimensionAccessIndices::operator!=( - const DimensionAccessIndices& other) const - { - return !(*this == other); - } +bool DimensionAccessIndices::operator!=( + const DimensionAccessIndices &other) const { + return !(*this == other); +} - llvm::raw_ostream& DimensionAccessIndices::dump( - llvm::raw_ostream& os, - const llvm::DenseMap< - const IndexSet*, uint64_t>& iterationSpacesIds) const - { - auto it = iterationSpacesIds.find(space.get()); - assert(it != iterationSpacesIds.end()); - return os << "e" << it->getSecond() << "[" << dimension << "]"; - } +llvm::raw_ostream & +DimensionAccessIndices::dump(llvm::raw_ostream &os, + const llvm::DenseMap<const IndexSet *, uint64_t> + &iterationSpacesIds) const { + auto it = iterationSpacesIds.find(space.get()); + assert(it != iterationSpacesIds.end()); + return os << "e" << it->getSecond() << "[" << dimension << "]"; +} - void DimensionAccessIndices::collectIterationSpaces( - llvm::DenseSet<const IndexSet*>& iterationSpaces) const - { - iterationSpaces.insert(space.get()); - } +void DimensionAccessIndices::collectIterationSpaces( + llvm::DenseSet<const IndexSet *> &iterationSpaces) const { + iterationSpaces.insert(space.get()); +} - void DimensionAccessIndices::collectIterationSpaces( - llvm::SmallVectorImpl<const IndexSet*>& iterationSpaces, - llvm::DenseMap< - const IndexSet*, - llvm::DenseSet<uint64_t>>& dependentDimensions) const - { - iterationSpaces.push_back(space.get()); +void DimensionAccessIndices::collectIterationSpaces( + llvm::SmallVectorImpl<const IndexSet *> &iterationSpaces, + llvm::DenseMap<const IndexSet *, llvm::DenseSet<uint64_t>> + &dependentDimensions) const { + iterationSpaces.push_back(space.get()); - if (!dimensionDependencies.empty()) { - dependentDimensions[space.get()].insert(dimension); + if
(!dimensionDependencies.empty()) { + dependentDimensions[space.get()].insert(dimension); - dependentDimensions[space.get()].insert( - dimensionDependencies.begin(), dimensionDependencies.end()); + dependentDimensions[space.get()].insert(dimensionDependencies.begin(), + dimensionDependencies.end()); } +} - mlir::AffineExpr DimensionAccessIndices::getAffineExpr( - unsigned int numOfDimensions, - DimensionAccess::FakeDimensionsMap& fakeDimensionsMap) const - { - unsigned int numOfFakeDimensions = fakeDimensionsMap.size(); - - fakeDimensionsMap[numOfDimensions + numOfFakeDimensions] = - Redirect(clone()); +mlir::AffineExpr DimensionAccessIndices::getAffineExpr( + unsigned int numOfDimensions, + DimensionAccess::FakeDimensionsMap &fakeDimensionsMap) const { + unsigned int numOfFakeDimensions = fakeDimensionsMap.size(); - return mlir::getAffineDimExpr( - numOfDimensions + numOfFakeDimensions, getContext()); - } + fakeDimensionsMap[numOfDimensions + numOfFakeDimensions] = Redirect(clone()); - IndexSet DimensionAccessIndices::map( - const Point& point, - llvm::DenseMap<const IndexSet*, Point>& currentIndexSetsPoint) const - { - IndexSet allIndices = getIndices(); + return mlir::getAffineDimExpr(numOfDimensions + numOfFakeDimensions, + getContext()); +} - if (dimensionDependencies.empty()) { - IndexSet result; +IndexSet DimensionAccessIndices::map( + const Point &point, + llvm::DenseMap<const IndexSet *, Point> &currentIndexSetsPoint) const { + IndexSet allIndices = getIndices(); - for (const MultidimensionalRange& range : llvm::make_range( - allIndices.rangesBegin(), allIndices.rangesEnd())) { - result += MultidimensionalRange(range[dimension]); - } + if (dimensionDependencies.empty()) { + IndexSet result; - return result; + for (const MultidimensionalRange &range : + llvm::make_range(allIndices.rangesBegin(), allIndices.rangesEnd())) { + result += MultidimensionalRange(range[dimension]); } - auto pointIt = currentIndexSetsPoint.find(space.get()); - assert(pointIt != currentIndexSetsPoint.end()); - return {Point(pointIt->getSecond()[dimension])}; + return result; } - IndexSet& DimensionAccessIndices::getIndices() - { - return *space; - } - - const IndexSet& DimensionAccessIndices::getIndices() const - { - return *space; - } + auto pointIt = currentIndexSetsPoint.find(space.get()); + assert(pointIt != currentIndexSetsPoint.end()); + return {Point(pointIt->getSecond()[dimension])}; } + +IndexSet &DimensionAccessIndices::getIndices() { return *space; } + +const IndexSet &DimensionAccessIndices::getIndices() const { return *space; } +} // namespace marco::modeling