From 85d5045117055466ece970dbeec28c9aab5053b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D0=BB=D0=B0=D1=85=D0=BE=D0=B2=20=D0=90=D0=BB?= =?UTF-8?q?=D0=B5=D0=BA=D1=81=D0=B5=D0=B9=20=D0=9F=D0=B0=D0=B2=D0=BB=D0=BE?= =?UTF-8?q?=D0=B2=D0=B8=D1=87?= Date: Tue, 22 Oct 2024 00:01:49 +0300 Subject: [PATCH] fix to float32 --- turbo_alignment/dataset/chat/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index e6a9d93..5a20c4c 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -307,7 +307,7 @@ def _encode( encoded_record: dict[str, Any] = { # 'id': record.id, FIXME: dont work with collators - 'input_ids': torch.LongTensor(input_ids.astype(np.int64)), + 'input_ids': torch.LongTensor(input_ids.astype(np.float32)), 'labels': torch.LongTensor(labels), 'attention_mask': torch.ones(input_ids.shape, dtype=torch.int64), }