Add tensor.one_hot int operation (#2413)
* make zeros_like, ones_like and full_like general numeric tensor functions,
correct some docs.

* add argtopk

* rename one_hot, correct float one_hot doc in book

* remove argtopk

* correct function name in book

---------

Co-authored-by: Tiago Sanona <[email protected]>
tsanona and Tiago Sanona authored Oct 25, 2024
1 parent d5e8e31 commit 2775ec3
Showing 5 changed files with 109 additions and 57 deletions.
62 changes: 32 additions & 30 deletions burn-book/src/building-blocks/tensor.md
@@ -194,6 +194,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
| `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value)` |
| `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` |
| `tensor.greater(other)` | `tensor.gt(other)` |
| `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` |
@@ -221,6 +222,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.mul_scalar(scalar)` or `tensor * scalar` | `tensor * scalar` |
| `tensor.neg()` or `-tensor` | `-tensor` |
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.ones_like()` | `torch.ones_like(tensor)` |
| `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
@@ -243,41 +245,40 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.topk_with_indices(k, dim)` | `tensor.topk(k, dim)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
| `tensor.zeros_like()` | `torch.zeros_like(tensor)` |
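
Not part of the diff: a minimal usage sketch of the `*_like` helpers listed above, which this commit makes available for both `Float` and `Int` tensors. The backend `B` and the concrete values are placeholders.

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn like_helpers_example<B: Backend>() {
    let device = B::Device::default();
    let t: Tensor<B, 1, Int> = Tensor::from_ints([1, 2, 3], &device);

    let zeros = t.zeros_like();  // same shape and device, filled with 0
    let ones = t.ones_like();    // same shape and device, filled with 1
    let sevens = t.full_like(7); // same shape and device, filled with 7

    println!("{zeros} {ones} {sevens}");
}
```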

### Float Operations

Those operations are only available for `Float` tensors.

| Burn API | PyTorch Equivalent |
| -------------------------------------------- | ---------------------------------- |
| `tensor.ceil()` | `tensor.ceil()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.erf()` | `tensor.erf()` |
| `tensor.exp()` | `tensor.exp()` |
| `tensor.floor()` | `tensor.floor()` |
| `tensor.from_floats(floats, device)` | N/A |
| `tensor.from_full_precision(tensor)` | N/A |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.log()` | `tensor.log()` |
| `tensor.log1p()` | `tensor.log1p()` |
| `tensor.matmul(other)` | `tensor.matmul(other)` |
| `tensor.one_hot(index, num_classes, device)` | N/A |
| `tensor.ones_like()` | `torch.ones_like(tensor)` |
| `tensor.random(shape, distribution, device)` | N/A |
| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.round()` | `tensor.round()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.tanh()` | `tensor.tanh()` |
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
| `tensor.transpose()` | `tensor.T` |
| `tensor.var(dim)` | `tensor.var(dim)` |
| `tensor.var_bias(dim)` | N/A |
| `tensor.var_mean(dim)` | N/A |
| `tensor.var_mean_bias(dim)` | N/A |
| `tensor.zeros_like()` | `torch.zeros_like(tensor)` |
| Burn API | PyTorch Equivalent |
|-----------------------------------------------| ---------------------------------- |
| `Tensor::one_hot(index, num_classes, device)` | N/A |
| `tensor.ceil()` | `tensor.ceil()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.erf()` | `tensor.erf()` |
| `tensor.exp()` | `tensor.exp()` |
| `tensor.floor()` | `tensor.floor()` |
| `tensor.from_floats(floats, device)` | N/A |
| `tensor.from_full_precision(tensor)` | N/A |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.log()` | `tensor.log()` |
| `tensor.log1p()` | `tensor.log1p()` |
| `tensor.matmul(other)` | `tensor.matmul(other)` |
| `tensor.random(shape, distribution, device)` | N/A |
| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.round()` | `tensor.round()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.tanh()` | `tensor.tanh()` |
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
| `tensor.transpose()` | `tensor.T` |
| `tensor.var(dim)` | `tensor.var(dim)` |
| `tensor.var_bias(dim)` | N/A |
| `tensor.var_mean(dim)` | N/A |
| `tensor.var_mean_bias(dim)` | N/A |
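
Not part of the diff: a minimal sketch of the float one-hot constructor, now listed as `Tensor::one_hot` in the table above, assuming the default device.

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn float_one_hot_example<B: Backend>() {
    let device = B::Device::default();
    // Encode class index 2 out of 5 classes; the last dimension holds the classes.
    let one_hot: Tensor<B, 1> = Tensor::one_hot(2, 5, &device);
    println!("{}", one_hot.to_data()); // expected: [0.0, 0.0, 1.0, 0.0, 0.0]
}
```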

### Int Operations

@@ -291,6 +292,7 @@ Those operations are only available for `Int` tensors.
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
| `tensor.cartesian_grid(shape, device)` | N/A |
| `tensor.one_hot(num_classes)` | N/A |
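
Not part of the diff: a minimal sketch of the new `Int` one-hot listed above (the rustdoc example in `int.rs` further down shows the same call). Note the shape: `[num_samples]` indices produce a `[num_samples, num_classes]` matrix.

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn int_one_hot_example<B: Backend>() {
    let device = B::Device::default();
    let indices: Tensor<B, 1, Int> = Tensor::from_ints([0, 2, 1], &device);
    // Shape [3] of class indices -> shape [3, 4] one-hot matrix.
    let one_hot: Tensor<B, 2, Int> = indices.one_hot(4);
    println!("{}", one_hot.to_data());
    // expected: [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
}
```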

### Bool Operations

31 changes: 26 additions & 5 deletions crates/burn-tensor/src/tensor/api/check.rs
@@ -1,4 +1,4 @@
use crate::{backend::Backend, BasicOps, Shape, Tensor};
use crate::{backend::Backend, BasicOps, Int, Shape, Tensor};
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
@@ -30,8 +30,8 @@ use core::ops::Range;
/// Maybe the Backend API should return a result for each operation, which would allow handling
/// all checks, even the ones that can't be efficiently checked before performing an operation,
/// such as the `index_select` operation. The downside of that approach is that all backend
/// implementation might re-implement the same checks, which may result in uncessary code
/// duplication. Maybe a combination of both strategies could help to cover all usecases.
/// implementation might re-implement the same checks, which may result in unnecessary code
/// duplication. Maybe a combination of both strategies could help to cover all use cases.
pub(crate) enum TensorCheck {
Ok,
Failed(FailedTensorCheck),
@@ -447,7 +447,7 @@ impl TensorCheck {
check
}

pub(crate) fn one_hot(index: usize, num_classes: usize) -> Self {
pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self {
let mut check = Self::Ok;
if index >= num_classes {
check = check.register(
Expand All @@ -461,6 +461,27 @@ impl TensorCheck {
check
}

pub(crate) fn one_hot_tensor<B: Backend>(
index_tensor: Tensor<B, 1, Int>,
num_classes: usize,
) -> Self {
let mut check = Self::Ok;
if index_tensor
.clone()
.greater_equal_elem(num_classes as i32)
.any()
.into_scalar()
{
check = check.register(
"One Hot",
TensorError::new(format!(
"Can't create a one hot tensor from ({index_tensor:?}) containing indexes greater or equal to the number of classes ({num_classes})",
)),
);
}
check
}

pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
let mut check = Self::Ok;

@@ -1152,7 +1173,7 @@ impl TensorError {
/// Module where we defined macros that can be used only in the project.
pub(crate) mod macros {
/// We use a macro for all checks, since the panic message file and line number will match the
/// function that does the check instead of a the generic error.rs crate private unrelated file
/// function that does the check instead of a generic error.rs crate private unrelated file
/// and line number.
macro_rules! check {
($check:expr) => {
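
For context, and not part of the diff: a sketch of what the new `one_hot_tensor` check guards against. Any index greater than or equal to `num_classes` makes the predicate below true, and the corresponding `one_hot` call panics with the "One Hot" error registered above.

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn out_of_range_example<B: Backend>() {
    let device = B::Device::default();
    let num_classes = 3usize;
    let indices: Tensor<B, 1, Int> = Tensor::from_ints([0, 1, 4], &device);
    // Same predicate the check evaluates: is any index >= num_classes?
    let out_of_range: bool = indices
        .greater_equal_elem(num_classes as i32)
        .any()
        .into_scalar();
    assert!(out_of_range); // so calling `.one_hot(3)` on these indices would panic
}
```
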
26 changes: 5 additions & 21 deletions crates/burn-tensor/src/tensor/api/float.rs
@@ -19,7 +19,7 @@ where
///
/// # Notes
///
/// This won't necessary reuse the same tensor data/buffer, but it should if there is
/// This won't necessarily reuse the same tensor data/buffer, but it should if there is
/// no other reference pointing to the same tensor.
///
/// Wrapping operations with inplace is not an optimization, it's mainly there if you
@@ -147,7 +147,7 @@
}

/// Returns a new tensor with the same shape and device as the current tensor and the data
/// casted to Integer.
/// cast to Integer.
///
/// # Example
///
@@ -165,22 +165,6 @@
Tensor::new(B::float_into_int(self.primitive.tensor()))
}

/// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
pub fn zeros_like(&self) -> Self {
Tensor::new(TensorPrimitive::Float(B::float_zeros(
self.shape(),
&self.device(),
)))
}

/// Returns a new tensor with the same shape and device as the current tensor filled with ones.
pub fn ones_like(&self) -> Self {
Tensor::new(TensorPrimitive::Float(B::float_ones(
self.shape(),
&self.device(),
)))
}

/// Returns a new tensor with the same shape and device as the current tensor filled random
/// values sampled from the given distribution.
pub fn random_like(&self, distribution: Distribution) -> Self {
Expand All @@ -207,7 +191,7 @@ where
/// }
/// ```
pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self {
check!(TensorCheck::one_hot(index, num_classes));
check!(TensorCheck::one_hot_index(index, num_classes));

let mut dims = [1; D];
dims[D - 1] = num_classes;
@@ -226,7 +210,7 @@
///
/// # Panics
///
/// If the two tensors dont' have a compatible shape.
/// If the two tensors don't have a compatible shape.
pub fn matmul(self, other: Self) -> Self {
check!(TensorCheck::matmul(&self, &other));
Self::new(TensorPrimitive::Float(B::float_matmul(
@@ -299,7 +283,7 @@
}
}

/// Mark the tensor as tracked or untracked depending on the require grad argument.
/// Mark the tensor as tracked or untracked depending on the require_grad argument.
/// When tracked, the gradients will be available after the backward pass.
///
/// This function does nothing when autodiff is not enabled.
32 changes: 31 additions & 1 deletion crates/burn-tensor/src/tensor/api/int.rs
@@ -1,3 +1,5 @@
use crate::check;
use crate::check::TensorCheck;
use crate::{
backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive,
};
@@ -27,6 +29,34 @@ where
pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::int_arange_step(range, step, device))
}

/// Create a one hot tensor from an index tensor.
///
/// # Arguments
///
/// * `num_classes` - The number of classes to use in encoding.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Int};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let indices: Tensor<B, 1, Int> = Tensor::from_ints([0, 1, 2, 3], &device);
/// let one_hot = indices.one_hot(4);
/// println!("{}", one_hot.to_data());
/// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
/// }
/// ```
pub fn one_hot(self, num_classes: usize) -> Tensor<B, 2, Int> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
let [num_samples] = self.dims();
let indices = self.unsqueeze_dim(1); // shape [num_samples, 1]
let values = indices.ones_like();
Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values)
}
}

impl<const D: usize, B> Tensor<B, D, Int>
@@ -52,7 +82,7 @@
}

/// Returns a new tensor with the same shape and device as the current tensor and the data
/// casted to Float.
/// cast to Float.
///
/// # Example
///
15 changes: 15 additions & 0 deletions crates/burn-tensor/src/tensor/api/numeric.rs
@@ -106,13 +106,23 @@ where
Self::new(K::zeros(shape, device))
}

/// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
pub fn zeros_like(&self) -> Self {
Self::zeros(self.shape(), &self.device())
}

/// Create a tensor of the given shape where each element is one.
pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
let shape = shape.into();
check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
Self::new(K::ones(shape, device))
}

/// Returns a new tensor with the same shape and device as the current tensor filled with ones.
pub fn ones_like(&self) -> Self {
Self::ones(self.shape(), &self.device())
}

/// Create a tensor of the given shape where each element is equal to the provided value.
pub fn full<S: Into<Shape>, E: ElementConversion>(
shape: S,
Expand All @@ -124,6 +134,11 @@ where
Self::new(K::full(shape, fill_value, device))
}

/// Returns a new tensor with the same shape and device as the current tensor filled with the provided value.
pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
Self::full(self.shape(), fill_value, &self.device())
}

/// Aggregate all elements in the tensor with the mean operation.
pub fn mean(self) -> Tensor<B, 1, K> {
Tensor::new(K::mean(self.primitive))
