Skip to content

Commit

Permalink
Add min, max
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Aug 21, 2024
1 parent 798bbf0 commit fc713f2
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 0 deletions.
42 changes: 42 additions & 0 deletions ptx_parser/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,24 @@ gen::generate_instruction_type!(
src2: T,
}
},
// `min` instruction entry. Its type is derived from the `MinMaxDetails`
// payload through `data.type_()`, and it carries one destination operand and
// two source operands.
Min {
type: { Type::from(data.type_()) },
data: MinMaxDetails,
arguments<T>: {
dst: T,
src1: T,
src2: T,
}
},
// `max` instruction entry. Its type is derived from the `MinMaxDetails`
// payload through `data.type_()`, and it carries one destination operand and
// two source operands.
Max {
type: { Type::from(data.type_()) },
data: MinMaxDetails,
arguments<T>: {
dst: T,
src1: T,
src2: T,
}
},
Trap { }
}
);
Expand Down Expand Up @@ -1075,3 +1093,27 @@ impl MadDetails {
}
}
}

/// Payload attached to the `Min` and `Max` instructions, describing which
/// flavour of comparison was parsed.
#[derive(Copy, Clone)]
pub enum MinMaxDetails {
/// Integer operation on a signed scalar type.
Signed(ScalarType),
/// Integer operation on an unsigned scalar type.
Unsigned(ScalarType),
/// Floating-point operation together with its modifiers.
Float(MinMaxFloat),
}

impl MinMaxDetails {
pub fn type_(&self) -> ScalarType {
match self {
MinMaxDetails::Signed(t) => *t,
MinMaxDetails::Unsigned(t) => *t,
MinMaxDetails::Float(float) => float.type_,
}
}
}

/// Modifiers and type of a floating-point `min`/`max` instruction.
#[derive(Copy, Clone)]
pub struct MinMaxFloat {
/// State of the `.ftz` modifier; `None` when the instruction form does not
/// accept `.ftz` at all.
pub flush_to_zero: Option<bool>,
/// State of the `.NaN` modifier.
pub nan: bool,
/// Floating-point scalar type the instruction operates on.
pub type_: ScalarType,
}
210 changes: 210 additions & 0 deletions ptx_parser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2034,6 +2034,216 @@ derive_parser!(
.rnd: RawRoundingMode = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min
// Integer `min`. A type from `.atype` is stored as `Signed` or `Unsigned`
// depending on what `atype.kind()` reports.
min.atype d, a, b => {
ast::Instruction::Min {
data: if atype.kind() == ast::ScalarKind::Signed {
ast::MinMaxDetails::Signed(atype)
} else {
ast::MinMaxDetails::Unsigned(atype)
},
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
// The `.relu` modifier shown in the syntax below is not parsed yet; only the
// plain `.btype` form is accepted. Both `.btype` types are signed, hence the
// unconditional `Signed`.
//min{.relu}.btype d, a, b => { todo!() }
min.btype d, a, b => {
ast::Instruction::Min {
data: ast::MinMaxDetails::Signed(btype),
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
// Type suffixes accepted by the two integer rules above.
.atype: ScalarType = { .u16, .u32, .u64,
.u16x2, .s16, .s64 };
.btype: ScalarType = { .s16x2, .s32 };

// Single- and double-precision `min`.
// Reference syntax for `.f32`; the `{.xorsign.abs}` part is not parsed by the
// rule below.
//min{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b;
min{.ftz}{.NaN}.f32 d, a, b => {
ast::Instruction::Min {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
// `.f32` accepts `.ftz`, so its state is always recorded.
flush_to_zero: Some(ftz),
nan,
type_: f32
}
),
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
// `.f64` takes no modifiers: there is no flush-to-zero setting and `nan` is
// fixed to `false`.
min.f64 d, a, b => {
ast::Instruction::Min {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: None,
nan: false,
type_: f64
}
),
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
// NOTE(review): presumably declares that these type suffixes are `ScalarType`
// values for the rules above; confirm against `derive_parser!`.
ScalarType = { .f32, .f64 };

// Half-precision `min`.
// Reference syntax; the `{.xorsign.abs}` part is not parsed by the rules below.
//min{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b;
//min{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b;
//min{.NaN}{.xorsign.abs}.bf16 d, a, b;
//min{.NaN}{.xorsign.abs}.bf16x2 d, a, b;
// `.f16` and `.f16x2` accept `.ftz`, so its state is always recorded.
min{.ftz}{.NaN}.f16 d, a, b => {
ast::Instruction::Min {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: Some(ftz),
nan,
type_: f16
}
),
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
min{.ftz}{.NaN}.f16x2 d, a, b => {
ast::Instruction::Min {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: Some(ftz),
nan,
type_: f16x2
}
),
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
// `.bf16` and `.bf16x2` have no `.ftz` form, so `flush_to_zero` is `None`.
min{.NaN}.bf16 d, a, b => {
ast::Instruction::Min {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: None,
nan,
type_: bf16
}
),
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
min{.NaN}.bf16x2 d, a, b => {
ast::Instruction::Min {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: None,
nan,
type_: bf16x2
}
),
arguments: MinArgs { dst: d, src1: a, src2: b }
}
}
// NOTE(review): presumably declares that these type suffixes are `ScalarType`
// values for the rules above; confirm against `derive_parser!`.
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max
// Integer `max`. A type from `.atype` is stored as `Signed` or `Unsigned`
// depending on what `atype.kind()` reports.
max.atype d, a, b => {
ast::Instruction::Max {
data: if atype.kind() == ast::ScalarKind::Signed {
ast::MinMaxDetails::Signed(atype)
} else {
ast::MinMaxDetails::Unsigned(atype)
},
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
// The `.relu` modifier shown in the syntax below is not parsed yet; only the
// plain `.btype` form is accepted. Both `.btype` types are signed, hence the
// unconditional `Signed`.
//max{.relu}.btype d, a, b => { todo!() }
max.btype d, a, b => {
ast::Instruction::Max {
data: ast::MinMaxDetails::Signed(btype),
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
// Type suffixes accepted by the two integer rules above.
.atype: ScalarType = { .u16, .u32, .u64,
.u16x2, .s16, .s64 };
.btype: ScalarType = { .s16x2, .s32 };

// Single- and double-precision `max`.
// Reference syntax for `.f32`; the `{.xorsign.abs}` part is not parsed by the
// rule below.
//max{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b;
max{.ftz}{.NaN}.f32 d, a, b => {
ast::Instruction::Max {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
// `.f32` accepts `.ftz`, so its state is always recorded.
flush_to_zero: Some(ftz),
nan,
type_: f32
}
),
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
// `.f64` takes no modifiers: there is no flush-to-zero setting and `nan` is
// fixed to `false`.
max.f64 d, a, b => {
ast::Instruction::Max {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: None,
nan: false,
type_: f64
}
),
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
// NOTE(review): presumably declares that these type suffixes are `ScalarType`
// values for the rules above; confirm against `derive_parser!`.
ScalarType = { .f32, .f64 };

// Half-precision `max`.
// Reference syntax; the `{.xorsign.abs}` part is not parsed by the rules below.
//max{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b;
//max{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b;
//max{.NaN}{.xorsign.abs}.bf16 d, a, b;
//max{.NaN}{.xorsign.abs}.bf16x2 d, a, b;
// `.f16` and `.f16x2` accept `.ftz`, so its state is always recorded.
max{.ftz}{.NaN}.f16 d, a, b => {
ast::Instruction::Max {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: Some(ftz),
nan,
type_: f16
}
),
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
max{.ftz}{.NaN}.f16x2 d, a, b => {
ast::Instruction::Max {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: Some(ftz),
nan,
type_: f16x2
}
),
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
// `.bf16` and `.bf16x2` have no `.ftz` form, so `flush_to_zero` is `None`.
max{.NaN}.bf16 d, a, b => {
ast::Instruction::Max {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: None,
nan,
type_: bf16
}
),
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
max{.NaN}.bf16x2 d, a, b => {
ast::Instruction::Max {
data: ast::MinMaxDetails::Float(
MinMaxFloat {
flush_to_zero: None,
nan,
type_: bf16x2
}
),
arguments: MaxArgs { dst: d, src1: a, src2: b }
}
}
// NOTE(review): presumably declares that these type suffixes are `ScalarType`
// values for the rules above; confirm against `derive_parser!`.
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
Expand Down

0 comments on commit fc713f2

Please sign in to comment.