Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Aug 26, 2024
1 parent 0c9b446 commit ac4948e
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions turbo_alignment/generators/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _generate_from_batch_records(
inputs=batched_input_ids,
attention_mask=batched_attention_mask,
generation_config=self._transformers_generator_parameters,
tokenizer=self._tokenizer,
# tokenizer=self._tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
)

Expand Down Expand Up @@ -84,7 +84,7 @@ def _generate_from_single_record(
inputs=input_ids,
attention_mask=attention_mask,
generation_config=self._transformers_generator_parameters,
tokenizer=self._tokenizer,
# tokenizer=self._tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
)

Expand Down
2 changes: 1 addition & 1 deletion turbo_alignment/generators/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _generate_from_single_record(
output_indices = self._model.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
tokenizer=self._tokenizer,
# tokenizer=self._tokenizer,
generation_config=self._transformers_generator_parameters,
)

Expand Down
2 changes: 1 addition & 1 deletion turbo_alignment/generators/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _generate_from_single_record(
answer_indices, document_indices, doc_scores = self._model.generate(
inputs=input_ids,
generation_config=self._transformers_generator_parameters,
tokenizer=self._tokenizer.current_tokenizer,
# tokenizer=self._tokenizer.current_tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
)

Expand Down
2 changes: 1 addition & 1 deletion turbo_alignment/modeling/rag/rag_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def generate(
input_ids=joined_input_ids,
generation_config=generation_config,
pad_token_id=self.tokenizer.pad_token_id,
tokenizer=kwargs.get('tokenizer', None),
# tokenizer=kwargs.get('tokenizer', None),
)
# TODO chose max-prob sequence with accounting for doc probs
only_answer_output = output_sequences[:, joined_input_ids.shape[-1] :]
Expand Down

0 comments on commit ac4948e

Please sign in to comment.