diff --git a/aphrodite/modeling/models/granite.py b/aphrodite/modeling/models/granite.py index 842519635..68380d50c 100644 --- a/aphrodite/modeling/models/granite.py +++ b/aphrodite/modeling/models/granite.py @@ -439,11 +439,12 @@ def forward( return model_output def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> Optional[torch.Tensor]: - logits = self.logits_processor( - self.lm_head, hidden_states, sampling_metadata - ) + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + if logits is not None: + logits /= self.config.logits_scaling return logits def sample(