Skip to content

Commit

Permalink
Replace tuple unpacking with jnp.s_ for variable indexing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668199363
  • Loading branch information
ychzhang authored and copybara-github committed Aug 28, 2024
1 parent 8bbf1f5 commit b391500
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aqt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.
"""Accurate Quantized Training library."""

__version__ = "0.8.0"
__version__ = "0.8.1"
2 changes: 1 addition & 1 deletion aqt/jax/v2/tiled_dot_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def unapply(self, tiled_x: jnp.ndarray) -> jnp.ndarray:
num_bcast_axes = len(self.get_broadcasted_tile_map_indexes())
# All elements of broadcast axes are the same, thus we take the first one.
first_index = (0,) * num_bcast_axes
x = tiled_x[*first_index]
x = tiled_x[jnp.s_[first_index]]
x = x.reshape(self.untiled_shape)
return x

Expand Down

0 comments on commit b391500

Please sign in to comment.