Currently, in multiple places (where exactly?), we assume for some tensor `x: Tensor` that `max(x.dims[i].dyn_size_ext.raw_tensor) == x.raw_tensor.shape[i]`.

We want to support the case where `max(x.dims[i].dyn_size_ext.raw_tensor) < x.raw_tensor.shape[i]`. Specifically, since we always have `x.dims[i].capacity == x.raw_tensor.shape[i]`, this means supporting the case `dim.capacity > max(dim.dyn_size_ext.raw_tensor)`.

This is needed e.g. for JAX, where we can only have static shapes. There we would compile our code for a few predefined batch sizes (fixed `batch_dim`, fixed spatial dim), and batching would then prepare each batch to fit one of these predefined sizes (e.g. via bucketing). But this often means that we do not reach the maximum possible sequence length for a given batch.
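The bucketing idea can be sketched in plain NumPy (a minimal illustration, not RETURNN API: `BUCKET_CAPACITIES` and `make_batch` are hypothetical names, and the set of predefined sizes is an assumption):

```python
import numpy as np

# Predefined capacities we would compile for (assumed for illustration).
BUCKET_CAPACITIES = (16, 32, 64)

def make_batch(seqs):
    """Pad variable-length 1D sequences to the smallest fitting bucket.

    Returns (raw, dyn_sizes) where raw.shape[1] plays the role of
    dim.capacity and dyn_sizes that of dim.dyn_size_ext.raw_tensor.
    Typically max(dyn_sizes) < raw.shape[1], i.e. the relaxed invariant.
    """
    dyn_sizes = np.array([len(s) for s in seqs])
    max_len = int(dyn_sizes.max())
    # Pick the smallest predefined capacity that fits this batch.
    capacity = next(c for c in BUCKET_CAPACITIES if c >= max_len)
    raw = np.zeros((len(seqs), capacity), dtype=np.float32)
    for i, s in enumerate(seqs):
        raw[i, : len(s)] = s
    return raw, dyn_sizes

raw, sizes = make_batch([np.ones(5), np.ones(9)])
assert raw.shape == (2, 16)        # capacity 16: smallest bucket >= 9
assert sizes.max() < raw.shape[1]  # max dyn size (9) < capacity (16)
```

With only a few distinct capacities, JAX only retraces/recompiles once per bucket instead of once per observed sequence length, at the cost of some padding overhead.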