diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index 5c5e09a..dd19c1b 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -3,12 +3,12 @@ from itertools import accumulate from pathlib import Path from typing import Any, overload -from typing_extensions import Self import numpy as np import numpy.typing as npt import torch from transformers import PreTrainedTokenizerBase +from typing_extensions import Self from turbo_alignment.common.data.io import read_jsonl from turbo_alignment.common.logging import get_project_logger @@ -375,8 +375,7 @@ def get_slice(self, start: int, end: int) -> Self: new_instance.records = self.records[start:end] new_instance.original_records_map = { - record['id']: self.get_original_record_by_id(record['id']) - for record in new_instance.records + record['id']: self.get_original_record_by_id(record['id']) for record in new_instance.records } return new_instance