From ef9df64044afad7a062ce39292e7ae7dfdcb7051 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 9 Jan 2025 05:39:00 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 713639384 --- official/projects/pix2seq/modeling/pix2seq_model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/official/projects/pix2seq/modeling/pix2seq_model.py b/official/projects/pix2seq/modeling/pix2seq_model.py index 0510ce39a7..9985cbb7e7 100644 --- a/official/projects/pix2seq/modeling/pix2seq_model.py +++ b/official/projects/pix2seq/modeling/pix2seq_model.py @@ -327,9 +327,17 @@ def checkpoint_items( # For backward-compatibility with prior checkpoints, the first backbone # should be named "backbone" and the second one should be named # "backbone_2", etc. - items = dict(backbone=self.backbones[0], transformer=self.transformer) + items = dict( + backbone=self.backbones[0], + transformer=self.transformer, + stem_projection=self._stem_projections[0], + stem_ln=self._stem_lns[0], + ) for i in range(1, len(self.backbones)): items[f"backbone_{i+1}"] = self.backbones[i] + items[f"stem_projection_{i+1}"] = self._stem_projections[i] + items[f"stem_ln_{i+1}"] = self._stem_lns[i] + return items def _generate_image_mask(