Skip to content

Commit

Permalink
Implement scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jun 10, 2024
1 parent 0936406 commit a5aff0f
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions candle-core/src/tensor_indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,19 @@ impl Tensor {
let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
Ok(from_storage(storage, dims, op, false))
}

pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
if self.dims().len() != indexes.dims().len() {
bail!("Self and indexes must have the same number of dimensions");
}
let dim = dim.to_index(self.shape(), "gather")?;
for (i, (&index_d, &self_d)) in indexes.dims().iter().zip(self.dims().iter()).enumerate() {
if i != dim && index_d > self_d {
bail!("Required that index dim <= self dim at every dim except than `dim`, got {index_d} > {self_d}");
}
}
let zeroed =
self.index_add(indexes, &(self.index_select(indexes, dim)? * -1.0f64)?, dim)?;
zeroed.index_add(indexes, source, dim)
}
}

0 comments on commit a5aff0f

Please sign in to comment.