Skip to content

Commit

Permalink
Update train.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
hongjin-su authored Dec 30, 2023
1 parent e24880e commit e749023
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -92,9 +92,9 @@ def _get_train_sampler(self) :
)

def compute_loss(self, model, inputs, return_outputs=False):
for task_id in inputs['task_name']:
assert task_id==inputs['task_name'][0],f"Examples in the same batch should come from the same task, " \
f"but task {task_id} and task {inputs['task_name'][0]} are found"
for task_id in inputs['task_id']:
assert task_id==inputs['task_id'][0],f"Examples in the same batch should come from the same task, " \
f"but task {task_id} and task {inputs['task_id'][0]} are found"
cur_results = {}
for k in ['query', 'pos', 'neg']:
cur_inputs = {
Expand Down Expand Up @@ -447,12 +447,12 @@ def main():
def get_examples_raw(old_examples_raw, total_n, real_batch_size):
examples_raw = []
for idx in range(0, total_n, real_batch_size):
local_task_name = old_examples_raw[idx]['task_name']
local_task_name = old_examples_raw[idx]['task_id']
cur_batch = []
include_batch = True
for idx1 in range(idx, min(idx + real_batch_size, total_n)):
if not old_examples_raw[idx1]['task_name'] == local_task_name:
print(f'one batch in task {old_examples_raw[idx1]["task_name"]} is skipped')
if not old_examples_raw[idx1]['task_id'] == local_task_name:
print(f'one batch in task {old_examples_raw[idx1]["task_id"]} is skipped')
include_batch = False
break
else:
Expand All @@ -478,7 +478,7 @@ def get_examples_raw(old_examples_raw, total_n, real_batch_size):
train_examples_raw = train_examples_raw[:int(data_args.debug_mode)]

def get_dataset(examples_raw):
examples = {'query':[],'pos':[],'neg':[],'task_name':[]}
examples = {'query':[],'pos':[],'neg':[],'task_id':[]}
task_name_map = {}
total_num = len(examples_raw)
task_count = 0
Expand All @@ -492,10 +492,10 @@ def get_dataset(examples_raw):
cur_e[k][0] = ''
assert cur_e[k][0].startswith('Represent ') or cur_e[k][0]==''
examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k]))
if not cur_e['task_name'] in task_name_map:
task_name_map[cur_e['task_name']] = task_count
if not cur_e['task_id'] in task_name_map:
task_name_map[cur_e['task_id']] = task_count
task_count += 1
examples['task_name'].append(task_name_map[cur_e['task_name']])
examples['task_id'].append(task_name_map[cur_e['task_id']])
return examples

train_raw_datasets = DatasetDict({'train':Dataset.from_dict(get_dataset(train_examples_raw))})
Expand Down Expand Up @@ -530,7 +530,7 @@ def preprocess_function(examples):
all_tokenized[k] = all_tokenized[k].tolist()
for k in keys:
all_tokenized[f'{key}_{k}'] = tokenized[k].tolist()
all_tokenized['task_name'] = examples['task_name']
all_tokenized['task_id'] = examples['task_id']
return all_tokenized

train_dataset = train_raw_datasets["train"]
Expand Down

0 comments on commit e749023

Please sign in to comment.