Skip to content

Commit

Permalink
NoNan transform dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 25, 2025
1 parent 00a49b0 commit b54a71d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7219,6 +7219,14 @@ void mlir::transform::addIotaSimplify(RewritePatternSet &patterns,
patterns.insert<IotaSimplify>(maxConstantExpansion, &context, benefit);
}

void mlir::transform::addNoNanAddSubSimplify(RewritePatternSet &patterns,
bool allowOnFloatingPointMath,
MLIRContext &context,
PatternBenefit benefit) {
patterns.insert<NoNanAddSubSimplify>(allowOnFloatingPointMath, &context,
benefit);
}

void mlir::transform::addBroadcastInDimSimplify(RewritePatternSet &patterns,
int64_t maxConstantExpansion,
MLIRContext &context,
Expand Down
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class RewritePatternSet;
namespace mlir::transform {
void addPadDotGeneral(RewritePatternSet &patterns, bool postPad,
MLIRContext &context, PatternBenefit benefit);
void addNoNanAddSubSimplify(RewritePatternSet &patterns,
bool allowOnFloatingPointMath, MLIRContext &context,
PatternBenefit benefit);
void addIotaSimplify(RewritePatternSet &patterns, int64_t maxConstantExpansion,
MLIRContext &context, PatternBenefit benefit);
void addBroadcastInDimSimplify(RewritePatternSet &patterns,
Expand Down
16 changes: 14 additions & 2 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,21 @@ def ApplyNoNanSelfSubSimplify : EnzymeHLOPatternOp<
"no_nan_self_sub_simplify"> {
let patterns = ["NoNanSelfSubSimplify"];
}
def ApplyNoNanAddSubSimplify : EnzymeHLOPatternOp<

// benefit 65k + max_constant_expansion flag
def ApplyNoNanAddSubSimplify : EnzymeHLOParameterizedPatternOp<
"no_nan_add_sub_simplify"> {
let patterns = ["NoNanAddSubSimplify"];
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
let assemblyFormat = "attr-dict";
// TODO: this should be made better searchable.
let extraClassDeclaration = [{
::llvm::SmallVector<::mlir::DictionaryAttr>
static getPossibleAttrCombinations(::mlir::Builder &builder) {
return {builder.getDictionaryAttr(
builder.getNamedAttr("parameter",
builder.getBoolAttr(true)))};
}
}];
}

def ApplyConcatPushBinopAddPatterns : EnzymeHLOPatternOp<
Expand Down

0 comments on commit b54a71d

Please sign in to comment.