Skip to content

Commit

Permalink
[CUDA] Implement Clamp + Remainder (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Aug 12, 2024
1 parent d94a07a commit c3ef475
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 32 deletions.
2 changes: 1 addition & 1 deletion crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ impl CudaCompiler {
}
gpu::Operator::Ceil(op) => instructions.push(Instruction::Ceil(self.compile_unary(op))),
gpu::Operator::Remainder(op) => {
instructions.push(Instruction::Modulo(self.compile_binary(op)))
instructions.push(Instruction::Remainder(self.compile_binary(op)))
}
gpu::Operator::Fma(op) => instructions.push(Instruction::Fma {
a: self.compile_variable(op.a),
Expand Down
24 changes: 12 additions & 12 deletions crates/cubecl-cuda/src/compiler/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ pub trait Binary {

let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
let [lhs, rhs, out] = optimized.args;
let (is_optimized, index) = match optimized.optimization_factor {
Some(factor) => (true, index / factor),
None => (false, index),
let index = match optimized.optimization_factor {
Some(factor) => index / factor,
None => index,
};

for i in 0..index {
let lhsi = lhs.index(i, is_optimized);
let rhsi = rhs.index(i, is_optimized);
let outi = out.index(i, is_optimized);
let lhsi = lhs.index(i);
let rhsi = rhs.index(i);
let outi = out.index(i);

Self::format_scalar(f, lhsi, rhsi, outi, elem)?;
}
Expand Down Expand Up @@ -212,8 +212,8 @@ impl Binary for IndexAssign {
}

for i in 0..index {
let lhsi = lhs.index(i, lhs.item().is_optimized());
let rhsi = rhs.index(i, rhs.item().is_optimized());
let lhsi = lhs.index(i);
let rhsi = rhs.index(i);
Self::format_scalar(f, lhsi, rhsi, *out, elem)?;
}

Expand Down Expand Up @@ -352,8 +352,8 @@ impl IndexVector {
}
};

let out = out.index(index, false);
let lhs = lhs.index(index, false);
let out = out.index(index);
let lhs = lhs.index(index);

f.write_fmt(format_args!("{out} = {lhs};\n"))
}
Expand All @@ -374,8 +374,8 @@ impl IndexAssignVector {
}
};

let out = out.index(index, false);
let rhs = rhs.index(index, false);
let out = out.index(index);
let rhs = rhs.index(index);

f.write_fmt(format_args!("{out} = {rhs};\n"))
}
Expand Down
8 changes: 4 additions & 4 deletions crates/cubecl-cuda/src/compiler/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ impl Component for IndexedVariable {
}

fn index(&self, index: usize) -> IndexedVariable {
self.var.index(index, self.var.is_optimized())
self.var.index(index)
}
}
impl Component for Variable {
fn index(&self, index: usize) -> IndexedVariable {
self.index(index, self.is_optimized())
self.index(index)
}

fn item(&self) -> Item {
Expand Down Expand Up @@ -361,11 +361,11 @@ impl Variable {
}
}

pub fn index(&self, index: usize, optimized: bool) -> IndexedVariable {
pub fn index(&self, index: usize) -> IndexedVariable {
IndexedVariable {
var: *self,
index,
optimized,
optimized: self.is_optimized(),
}
}
}
Expand Down
76 changes: 66 additions & 10 deletions crates/cubecl-cuda/src/compiler/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub enum Instruction {
var: Variable,
},
Modulo(BinaryInstruction),
Remainder(BinaryInstruction),
Add(BinaryInstruction),
Fma {
a: Variable,
Expand Down Expand Up @@ -238,12 +239,7 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{
min_value,
max_value,
out,
} => f.write_fmt(format_args!(
"
{out} = min({input}, {max_value});
{out} = max({out}, {min_value});
"
)),
} => Clamp::format(f, input, min_value, max_value, out),
Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
Expand Down Expand Up @@ -277,6 +273,7 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{
Instruction::Wrap(it) => f.write_fmt(format_args!("{it}")),
Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
Instruction::Wmma(it) => f.write_fmt(format_args!("{it}")),
Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out),
}
}
}
Expand All @@ -294,14 +291,73 @@ impl Fma {
let num = out.item().vectorization;

for i in 0..num {
let ai = a.index(i, false);
let bi = b.index(i, false);
let ci = c.index(i, false);
let outi = out.index(i, false);
let ai = a.index(i);
let bi = b.index(i);
let ci = c.index(i);
let outi = out.index(i);

f.write_fmt(format_args!("{outi} = fma({ai}, {bi}, {ci});\n"))?;
}

Ok(())
}
}

/// Emits the CUDA source for a clamp instruction, one statement per
/// vector component of the output.
struct Clamp;

impl Clamp {
    /// Writes `out[i] = max(min_value[i], min(max_value[i], input[i]));`
    /// for every component index `i` in `0..out.item().vectorization`.
    ///
    /// All four operands are first passed through `Variable::optimized()`;
    /// the component count is taken from the optimized `out`.
    ///
    /// Returns the first formatter error encountered, if any.
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable,
        min_value: &Variable,
        max_value: &Variable,
        out: &Variable,
    ) -> core::fmt::Result {
        let (src, lower, upper, dst) = (
            input.optimized(),
            min_value.optimized(),
            max_value.optimized(),
            out.optimized(),
        );

        (0..dst.item().vectorization).try_for_each(|lane| {
            let value = src.index(lane);
            let lo = lower.index(lane);
            let hi = upper.index(lane);
            let target = dst.index(lane);

            // Upper bound is applied first (inner `min`), then the lower bound.
            write!(f, "{target} = max({lo}, min({hi}, {value}));\n")
        })
    }
}

/// Emits the CUDA source for a remainder instruction, one statement per
/// vector component of the output.
struct Remainder;

impl Remainder {
    /// Writes `out[i] = lhs[i] - rhs[i] * floor(lhs[i] / rhs[i]);` for every
    /// component index `i` in `0..out.item().vectorization`.
    ///
    /// All three operands are first passed through `Variable::optimized()`;
    /// the component count is taken from the optimized `out`.
    ///
    /// Returns the first formatter error encountered, if any.
    // NOTE(review): the `/` is emitted verbatim, so with integer operands the
    // generated code would divide before `floor` is applied — confirm which
    // element types reach this path.
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
    ) -> core::fmt::Result {
        let (dividend, divisor, dst) = (lhs.optimized(), rhs.optimized(), out.optimized());

        (0..dst.item().vectorization).try_for_each(|lane| {
            let a = dividend.index(lane);
            let b = divisor.index(lane);
            let target = dst.index(lane);

            write!(f, "{target} = {a} - {b} * floor({a} / {b});\n")
        })
    }
}
10 changes: 5 additions & 5 deletions crates/cubecl-cuda/src/compiler/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ pub trait Unary {
) -> std::fmt::Result {
let optimized = Variable::optimized_args([*input, *out]);
let [input, out] = optimized.args;
let (is_optimized, index, elem) = match optimized.optimization_factor {
Some(factor) => (true, index / factor, out.elem()),
None => (false, index, elem),
let (index, elem) = match optimized.optimization_factor {
Some(factor) => (index / factor, out.elem()),
None => (index, elem),
};

for i in 0..index {
let inputi = input.index(i, is_optimized);
let outi = out.index(i, is_optimized);
let inputi = input.index(i);
let outi = out.index(i);

Self::format_scalar(f, inputi, outi, elem)?;
}
Expand Down

0 comments on commit c3ef475

Please sign in to comment.