Skip to content

Commit

Permalink
Remove BOS/EOS tokens when concatenating
Browse files Browse the repository at this point in the history
Signed-off-by: yintong-lu <[email protected]>
  • Loading branch information
yintong-lu committed Jun 6, 2024
1 parent 0975c5d commit 89b685e
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions auto_round/calib_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,19 +260,41 @@ def filter_func(example):
def concat_dataset_element(dataset):
input_ids, concat_input_ids = [eg['input_ids'] for eg in dataset], []
attention_mask_list, attention_mask = [], torch.ones([1, seqlen]).to(torch.int64)
buffer_input_id = torch.Tensor()
buffer_input_id = torch.Tensor().to(torch.int64)
bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
os_cnt, have_bos, have_eos = 0, False, False

for input_id in input_ids:
if buffer_input_id.shape[-1] + input_id.shape[-1] > seqlen:
idx_keep = seqlen - buffer_input_id.shape[-1]
concat_input_ids.append(torch.cat([buffer_input_id, input_id[:idx_keep]]).to(torch.int64))
if input_id[0] == bos_token_id:
input_id = input_id[1:]
os_cnt, have_bos = os_cnt + 1, True
if input_id[-1] == eos_token_id:
input_id = input_id[:-1]
os_cnt, have_eos = os_cnt + 1, True

if buffer_input_id.shape[-1] + input_id.shape[-1] + os_cnt > seqlen:
idx_keep = seqlen - buffer_input_id.shape[-1] - os_cnt
input_id_to_append = [buffer_input_id, input_id[:idx_keep]]
if have_bos:
input_id_to_append = [torch.tensor([bos_token_id])] + input_id_to_append
if have_eos:
input_id_to_append.append(torch.tensor([eos_token_id]))

concat_input_ids.append(torch.cat(input_id_to_append).to(torch.int64))
attention_mask_list.append(attention_mask)
buffer_input_id = input_id[idx_keep:]
else:
buffer_input_id = torch.cat([buffer_input_id, input_id])
if buffer_input_id.shape[-1] == seqlen:
concat_input_ids.append(buffer_input_id.to(torch.int64))

if buffer_input_id.shape[-1] + os_cnt == seqlen:
input_id_to_append = [buffer_input_id]
if have_bos:
input_id_to_append = [torch.tensor([bos_token_id])] + input_id_to_append
if have_eos:
input_id_to_append.append(torch.tensor([eos_token_id]))
concat_input_ids.append(torch.cat(input_id_to_append).to(torch.int64))
attention_mask_list.append(attention_mask)
buffer_input_id = torch.Tensor()
buffer_input_id = torch.Tensor().to(torch.int64)
data = [{'input_ids': a, 'attention_mask': b} for a, b in zip(concat_input_ids, attention_mask_list)]
import datasets
dataset_new = datasets.Dataset.from_list(data)
Expand Down Expand Up @@ -324,7 +346,8 @@ def concat_dataset_element(dataset):
for i in range(len(datasets)):
name = dataset_names[i].split(':')[0]
if name not in data_lens:
target_cnt = (n_samples - cnt) // (len(datasets) - len(data_lens)) if data_lens else (n_samples - cnt) // (len(datasets) - i)
target_cnt = (n_samples - cnt) // (len(datasets) - len(data_lens)) if data_lens \
else (n_samples - cnt) // (len(datasets) - i)
target_cnt = min(target_cnt, len(datasets[i]))
cnt += target_cnt
else:
Expand Down

0 comments on commit 89b685e

Please sign in to comment.