Skip to content

Commit

Permalink
Use a simple int as a prompt length instead of a tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Oct 30, 2024
1 parent f9c1c94 commit 3204a26
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 256 deletions.
4 changes: 1 addition & 3 deletions sentence_transformers/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
for key, value in tokenized.items():
batch[f"{column_name}_{key}"] = value
if prompt_len is not None:
batch[f"{column_name}_prompt_length"] = torch.tensor(
[prompt_len] * len(values), device=batch[f"{column_name}_input_ids"].device, dtype=torch.long
)
batch[f"{column_name}_prompt_length"] = prompt_len
return batch

def add_prompts_to_column(
Expand Down
Loading

0 comments on commit 3204a26

Please sign in to comment.