From a5aff0f11431c2ed3f60093ab98690d55e8576ba Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 9 Jun 2024 20:16:01 -0400 Subject: [PATCH] Implement scatter --- candle-core/src/tensor_indexing.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/candle-core/src/tensor_indexing.rs b/candle-core/src/tensor_indexing.rs index 140876456b..d5b987a0d4 100644 --- a/candle-core/src/tensor_indexing.rs +++ b/candle-core/src/tensor_indexing.rs @@ -376,4 +376,19 @@ impl Tensor { let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim)); Ok(from_storage(storage, dims, op, false)) } + + pub fn scatter(&self, indexes: &Self, source: &Self, dim: D) -> Result { + if self.dims().len() != indexes.dims().len() { + bail!("Self and indexes must have the same number of dimensions"); + } + let dim = dim.to_index(self.shape(), "gather")?; + for (i, (&index_d, &self_d)) in indexes.dims().iter().zip(self.dims().iter()).enumerate() { + if i != dim && index_d > self_d { + bail!("Required that index dim <= self dim at every dim except than `dim`, got {index_d} > {self_d}"); + } + } + let zeroed = + self.index_add(indexes, &(self.index_select(indexes, dim)? * -1.0f64)?, dim)?; + zeroed.index_add(indexes, source, dim) + } }