Adjust Attention Mechanism and Dataset Handling #295

Open · wants to merge 3 commits into main
2 changes: 1 addition & 1 deletion open_lm/attention.py
@@ -111,7 +111,7 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None):
     if attention_mask is None:
         bias = None
         # If we only have one query, assume we don't need to be in causal mode (can attend to all keys).
-        if queries.shape[1] == 1:
+        if queries.shape == 1:
             is_causal = False
     else:
         if not is_causal:
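The single-query shortcut touched above follows the existing comment in torch_attn: during incremental decoding the lone new query sits at the last position, so letting it attend to every cached key is exactly what a causal mask would allow for that position. A minimal standalone sketch of that equivalence (not part of the PR; it calls scaled_dot_product_attention directly with a [batch, heads, seq, dim] layout rather than the repo's torch_attn signature):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 5, 8  # batch, heads, sequence length, head dim
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Causal attention over the full sequence, keeping only the last position.
full_causal_last = F.scaled_dot_product_attention(q, k, v, is_causal=True)[:, :, -1:, :]

# Single-query pass: the last query against all keys, with no causal mask.
single_query = F.scaled_dot_product_attention(q[:, :, -1:, :], k, v, is_causal=False)

print(torch.allclose(full_causal_last, single_query, atol=1e-5))  # True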
123 changes: 120 additions & 3 deletions open_lm/data.py
@@ -13,6 +13,7 @@
from functools import partial
from itertools import islice
import copy
from datasets import concatenate_datasets

import numpy as np
import pandas as pd
@@ -545,21 +546,136 @@ def get_synthetic_dataset(args, is_train, epoch, tokenizer, data_key, floor):
    return DataInfo(dataloader, sampler)


class JSONLDataset(Dataset):
    def __init__(self, file_path, tokenizer, seq_len, padding_side):
        self.padding_side = padding_side
        self.urls = [file_path]
        self.eot_token = 0
        self.pad_token = 1
        self.ignore_tok = -100
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.data, self.long_answer_tokens = self.load_data(file_path)
        print(f"Loaded {len(self.data)} samples from {file_path}")

    def load_data(self, file_path):
        data = []
        long_answer_tokens = []
        with open(file_path, 'r') as f:
            for line in f:
                item = json.loads(line.strip())
                chunks, long_answer = self.create_chunks(item)
                data.append(chunks)
                long_answer_tokens.append(long_answer)
        # return data, long_answer_tokens
        return torch.tensor(data), torch.tensor(long_answer_tokens)

    def create_chunks(self, item):
        inputs = self.tokenizer(item['instruction'] + item['input'])
        outputs = self.tokenizer(item['output']) + [self.eot_token]

        input_tokens = inputs + outputs
        target_tokens = [self.ignore_tok] * len(inputs) + outputs
        # if the tokens exceed the chunksize, truncate to chunksize
        assert len(input_tokens) == len(target_tokens)
        input_tokens = input_tokens[-self.seq_len:]
        target_tokens = target_tokens[-self.seq_len:]

        # if the input is less than chunksize, auto padding
        input_tokens = self.pad_input(input_tokens, self.pad_token)
        target_tokens = self.pad_input(target_tokens, self.ignore_tok)
        return input_tokens, target_tokens

    def pad_input(self, tokens, pad_token):
        if len(tokens) < self.seq_len:
            padding = [pad_token] * (self.seq_len - len(tokens))
            if self.padding_side == "right":
                tokens = tokens + padding
            elif self.padding_side == "left":
                tokens = padding + tokens
            else:
                raise Exception("PADDING SIDE should either be left or right")
        return tokens

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_ids = self.data[idx]
        target_ids = self.long_answer_tokens[idx]
        if len(input_ids) != len(target_ids):
            raise ValueError(f"Input and target sizes do not match at index {idx}: {input_ids.size()} vs {target_ids.size()}")

        return input_ids, target_ids
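Read together, create_chunks and pad_input turn one JSONL record with instruction / input / output fields into an aligned (input_tokens, target_tokens) pair in which the prompt portion is masked with -100. A toy walk-through, illustrative only, assuming a whitespace "tokenizer" that returns integer ids, seq_len=8, and right padding:

record = {"instruction": "Answer briefly:", "input": " What is 2+2?", "output": " 4"}

eot_token, pad_token, ignore_tok, seq_len = 0, 1, -100, 8
toy_vocab = {}
def toy_tokenize(text):
    # Assign each new whitespace-separated word the next free id (0 and 1 are reserved).
    return [toy_vocab.setdefault(word, len(toy_vocab) + 2) for word in text.split()]

prompt = toy_tokenize(record["instruction"] + record["input"])
answer = toy_tokenize(record["output"]) + [eot_token]

input_tokens = prompt + answer
target_tokens = [ignore_tok] * len(prompt) + answer

# Truncate from the left to seq_len, then pad on the right, mirroring create_chunks/pad_input.
input_tokens = (input_tokens[-seq_len:] + [pad_token] * seq_len)[:seq_len]
target_tokens = (target_tokens[-seq_len:] + [ignore_tok] * seq_len)[:seq_len]

print(input_tokens)   # [2, 3, 4, 5, 6, 7, 0, 1]
print(target_tokens)  # [-100, -100, -100, -100, -100, 7, 0, -100]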


def get_jsonl_dataloader(args, is_train, tokenizer=None, floor=True, epoch=0, data_key="json", force_num_samples=None):
    file_paths = args.train_data if is_train else args.val_data
    datasets = [JSONLDataset(file_path, tokenizer, args.seq_len, args.padding_side) for file_path in file_paths]

    if is_train:
        # TODO: handle the case where the dataset consists of a list of files
        dataset = datasets[0]
    # Initialize shared_epoch

    if is_train:
        shared_epoch = SharedEpoch(epoch=epoch)
    else:
        shared_epoch = None

    if is_train:
        global_batch_size = args.per_gpu_batch_size * args.world_size
        round_fn = math.floor if floor else math.ceil
        total_num_batches = 0
        total_num_samples = 0
        # for dataset in datasets:  ## dataset has already been concatenated
        num_worker_batches = round_fn(len(dataset) / (global_batch_size * max(1, args.workers)))
        num_batches = num_worker_batches * max(1, args.workers)
        num_samples = num_batches * global_batch_size
        total_num_batches += num_batches
        total_num_samples += num_samples
    else:
        # For validation, just use the original dataset
        dataset = datasets[0]
        total_num_batches = math.ceil(len(dataset) / (args.per_gpu_val_batch_size * args.world_size))
        total_num_samples = len(dataset)

    # Create the dataloader
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and sampler is None
    # shuffle = True
    dataloader = DataLoader(
        dataset,
        batch_size=args.per_gpu_batch_size if is_train else args.per_gpu_val_batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )

    dataloader.num_batches = total_num_batches
    dataloader.num_samples = total_num_samples
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch, sampler=sampler)
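A hypothetical invocation sketch of the new loader. The args fields below mirror the ones get_jsonl_dataloader reads; the tokenizer is a stand-in that only has to map a string to a list of token ids, and train.jsonl is a made-up path:

from argparse import Namespace

# Placeholder tokenizer for illustration only; in practice this would be the model's tokenizer.
toy_tokenizer = lambda text: [ord(c) % 254 + 2 for c in text]

args = Namespace(
    train_data=["train.jsonl"],      # hypothetical instruction-tuning file
    val_data=None,
    dataset_type="jsonl",
    seq_len=2048,
    padding_side="left",
    per_gpu_batch_size=8,
    per_gpu_val_batch_size=8,
    workers=2,
    world_size=1,
    distributed=False,
)

train_info = get_jsonl_dataloader(args, is_train=True, tokenizer=toy_tokenizer)
for input_ids, target_ids in train_info.dataloader:
    # Both tensors have shape [per_gpu_batch_size, seq_len]; the shift for next-token
    # prediction happens later (see the jsonl branch added in evaluate.py below).
    break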


def get_dataset_fn(dataset_type):
    if dataset_type == "synthetic":
        return get_synthetic_dataset
    elif dataset_type == "jsonl":
        return get_jsonl_dataloader
    else:
        return get_wds_dataset


def get_data(args, epoch=0, tokenizer=None, skip_train=False, floor=True):
    data = {}

    if skip_train:
        data["train"] = None
    else:
        if args.train_data or args.dataset_type == "synthetic":
            # train data is treated as a shard list where all data is combined and trained on
            args.train_num_samples = 1000
            data["train"] = get_dataset_fn(args.dataset_type)(
                args, is_train=True, epoch=epoch, tokenizer=tokenizer, data_key=args.data_key, floor=floor
            )
@@ -711,8 +827,9 @@ def sample_chunk(chunk, args):
     else:
         raise Exception(f"Invalid sequence length: Sequence length {args.seq_len} > {chunk.shape[1]} Chunk size")
 
-    inputs = chunk[:, start_idx : start_idx + args.seq_len]
-    targets = chunk[:, start_idx + 1 : start_idx + args.seq_len + 1]
+    inputs = chunk[:, start_idx: start_idx + args.seq_len]
+    targets = chunk[:, start_idx + 1: start_idx + args.seq_len + 1]
+
 
     # replace elements to be masked with -100 (pytorch default xent ignore value)
     if args.target_mask_left is not None or args.target_mask_individual is not None:
11 changes: 9 additions & 2 deletions open_lm/datapreprocess/make_assistant_data.py
@@ -70,10 +70,17 @@ def dump_queue_to_buffer():

     with get_item_reader(file_name) as item_reader:
         for item in item_reader:
-            string = item["text"]
             try:
-                tokens = remaining_tokens + enc(string) + [eot_token]
+                # Extract and concatenate the relevant fields
+                tokens = remaining_tokens + \
+                    enc(item["QUESTION"]) +\
+                    enc(item["CONTEXTS"]) +\
+                    enc(item["LONG_ANSWER"]) +\
+                    [eot_token]
+
                 remaining_tokens = []
+                # tokens = torch.tensor(tokens).unsqueeze(0)  # Shape: (1, seq_len + 1)
+
             except:
                 print("Failed to encode string.")
                 continue
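The field names consumed above (QUESTION, CONTEXTS, LONG_ANSWER) suggest PubMedQA-style records, though the source data itself is not shown in this PR. An illustrative record and the token stream the new loop would build from it; enc and eot_token stand in for the script's encoder callable and end-of-text id, and the values are made up:

item = {
    "QUESTION": "Does aspirin reduce fever?",
    "CONTEXTS": "Aspirin is a widely used antipyretic ...",
    "LONG_ANSWER": "Yes, aspirin lowers elevated body temperature ...",
}

enc = lambda text: [ord(c) for c in text]  # stand-in encoder for illustration only
eot_token = 0
remaining_tokens = []

tokens = remaining_tokens + \
    enc(item["QUESTION"]) + \
    enc(item["CONTEXTS"]) + \
    enc(item["LONG_ANSWER"]) + \
    [eot_token]
print(len(tokens))  # length of question + contexts + long answer, plus the EOT token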
20 changes: 17 additions & 3 deletions open_lm/evaluate.py
@@ -55,20 +55,34 @@ def evaluate(model, data, start_epoch, args, writer):
         if i == dataloader.num_batches and not exhaust_loader:
             break
 
-        (texts,) = batch
-        texts = torch.LongTensor(texts).to(device)
 
         data_time_m.update(time.time() - end)
 
         with autocast():
-            inputs, targets = sample_chunk(texts, args)
+            if args.dataset_type == "jsonl":
+                inputs, targets = batch
+                inputs = torch.LongTensor(inputs).to(device)
+                targets = torch.LongTensor(targets).to(device)
+                inputs = inputs[:, :-1]
+                targets = targets[:, 1:]
+                if is_master(args) and i == 0:
+                    for target in targets:
+                        print("decode", target[target!=-100])
+            else:
+                (texts,) = batch
+                texts = torch.LongTensor(texts).to(device)
+                inputs, targets = sample_chunk(texts, args)
 
             out, _, _ = model(inputs)  # [per_gpu_bs, seq_len, vocab_size]
 
             bs, seq_len = targets.shape
 
+
             targets = targets.reshape(-1)
             total_loss = loss(out.reshape(-1, args.vocab_size), targets)  # [bs * seq_len]
+            if is_master(args) and i == 0:
+                print("total_loss", total_loss)
+                print("loss not equal to zero", total_loss[total_loss!=0.0])
 
             # cross entropy ignores -100 values in loss computation
             mask = targets != -100
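A small standalone sketch (not from the PR) of what the jsonl branch above does: dropping the last input column and the first target column aligns position t of the inputs with the token at position t+1, and because cross entropy ignores the -100 entries, only the answer and EOT tokens contribute to the loss. The token values are made up:

import torch
import torch.nn.functional as F

vocab_size = 10
input_ids = torch.tensor([[5, 6, 7, 3, 0, 1]])              # prompt 5,6,7; answer 3; eot 0; pad 1
target_ids = torch.tensor([[-100, -100, -100, 3, 0, -100]])

inputs = input_ids[:, :-1]     # [[5, 6, 7, 3, 0]]
targets = target_ids[:, 1:]    # [[-100, -100, 3, 0, -100]]

logits = torch.randn(1, inputs.shape[1], vocab_size)         # stand-in for the model output
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1),
                       ignore_index=-100, reduction="none")
print(loss)  # zero at every -100 position; nonzero only where targets are 3 and 0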