From 2d80b2a265400962274d7a415e87e2aa23b34d2b Mon Sep 17 00:00:00 2001 From: Jehovan <143548520+Jehovan@users.noreply.github.com> Date: Fri, 8 Sep 2023 00:16:34 +0200 Subject: [PATCH] beam_search fix for running with torch.use_deterministic_algorithms(True) (#1096) --- CHANGELOG.md | 6 ++++++ sockeye/__init__.py | 2 +- sockeye/beam_search.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f4c632a5..317a51cef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [3.1.37] + +### Fixed + +- Fixed beam_search for running with torch.use_deterministic_algorithms(True) + ## [3.1.36] ### Changed diff --git a/sockeye/__init__.py b/sockeye/__init__.py index a8897516d..b017b08c9 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '3.1.36' +__version__ = '3.1.37' diff --git a/sockeye/beam_search.py b/sockeye/beam_search.py index 34e65e4f2..1476be88b 100644 --- a/sockeye/beam_search.py +++ b/sockeye/beam_search.py @@ -905,7 +905,7 @@ def forward(self, # locations of each batch item when first dimension is (batch * beam) batch_indices = pt.arange(0, batch_size * self.beam_size, self.beam_size, dtype=pt.int64, device=self.device) first_step_mask = pt.full((batch_size * self.beam_size, 1), fill_value=np.inf, device=self.device, dtype=self.dtype) - first_step_mask[batch_indices] = 0.0 + first_step_mask[batch_indices] = pt.full((batch_size, 1), fill_value=0.0, device=self.device, dtype=self.dtype) if target_prefix is not None: first_step_mask = utils.adjust_first_step_masking(target_prefix, first_step_mask)