# data.py
import random
import json
import torch
from torch.utils.data import Dataset
from transformers import DefaultDataCollator

class InstrutionDataset(Dataset):
    """Loads a JSONL instruction-tuning file and formats each record as a
    prompt/answer pair."""

    def __init__(self, data_path, prefix=""):
        self.dataset = []
        with open(data_path, "r", encoding="utf-8") as fh:
            for line in fh:
                sample = json.loads(line.strip())
                # Fold instruction and input into a single prompt ending in
                # "answer: ", leaving the model to generate the output.
                prompt = (f'instruction: {prefix}{sample["instruction"]}\n'
                          f'input: {sample["input"]}\n'
                          "answer: ")
                self.dataset.append({"input": prompt, "answer": sample["output"]})

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        return self.dataset[item]
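
# The JSONL schema InstrutionDataset expects, one object per line. The values
# below are illustrative, not taken from this repository's data:
#   {"instruction": "Translate to French.", "input": "Hello, world.",
#    "output": "Bonjour, le monde."}
# A record like this is formatted into the prompt
#   "instruction: Translate to French.\ninput: Hello, world.\nanswer: "
# paired with the answer "Bonjour, le monde.".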

class InstrutionCollator(DefaultDataCollator):
    """Pads each batch to max_len and masks the prompt tokens out of the
    labels, following ChatGLM's "prompt [gMASK] <sop> answer <eop>" layout."""

    def __init__(self, tokenizer, max_len, max_input_len):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_input_len = max_input_len
        # Instance-local RNG, created here directly: defining __init__ by hand
        # bypasses the dataclass-generated __init__, so a __post_init__ hook
        # on this subclass would never fire.
        self.rng = random.Random()

    def __call__(self, examples):
        input_ids_list = []
        labels_list = []
        # Reserve 3 positions for the [gMASK], <sop> and <eop> special tokens.
        max_tgt_len = self.max_len - self.max_input_len - 3
        for example in examples:
            src_tokens = self.tokenizer.tokenize(example["input"])
            if len(src_tokens) > self.max_input_len:
                src_tokens = src_tokens[:self.max_input_len]

            tgt_tokens = self.tokenizer.tokenize(example["answer"])
            if len(tgt_tokens) > max_tgt_len:
                tgt_tokens = tgt_tokens[:max_tgt_len]

            # ChatGLM prompt layout: prompt [gMASK] <sop> answer <eop>.
            tokens = src_tokens + ["[gMASK]", "<sop>"] + tgt_tokens + ["<eop>"]
            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

            # The tokenizer's BOS token is assumed to be <sop>; everything
            # before it is context and is masked out of the loss with -100.
            context_length = input_ids.index(self.tokenizer.bos_token_id)
            labels = [-100] * context_length + input_ids[context_length:]

            # Right-pad both sequences to max_len; padding positions never
            # contribute to the loss.
            pad_len = self.max_len - len(input_ids)
            input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
            labels = labels + [-100] * pad_len

            input_ids_list.append(torch.LongTensor(input_ids))
            labels_list.append(torch.LongTensor(labels))

        return {
            "input_ids": torch.stack(input_ids_list),
            "labels": torch.stack(labels_list),
        }
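
# Minimal usage sketch. The checkpoint name, data path, batch size, and length
# limits below are assumptions for illustration, not values from this repo.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    # A ChatGLM-style tokenizer is assumed, since the collator relies on the
    # [gMASK]/<sop>/<eop> special tokens and on <sop> being the BOS token.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

    dataset = InstrutionDataset("train.jsonl")  # hypothetical path
    collator = InstrutionCollator(tokenizer, max_len=512, max_input_len=128)
    loader = DataLoader(dataset, batch_size=4, collate_fn=collator)

    batch = next(iter(loader))
    print(batch["input_ids"].shape)  # torch.Size([4, 512])
    print(batch["labels"].shape)     # torch.Size([4, 512])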