Fix (llm): fix device issue for eval when not using default device
fabianandresgrob committed Apr 30, 2024
1 parent 0c52c9a commit fb191f7
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/brevitas_examples/llm/llm_quant/eval.py
@@ -21,11 +21,11 @@
 from tqdm import tqdm


-def create_validation_dataloader(data, seqlen):
+def create_validation_dataloader(data, seqlen, device):
     nsamples = data['input_ids'].numel() // seqlen
     val_dataloader = []
     for i in tqdm(range(nsamples)):
-        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda()
+        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].to(device)
         attention_mask = torch.ones_like(batch)
         val_dataloader.append({'input_ids': batch, 'attention_mask': attention_mask})
     return val_dataloader
@@ -41,7 +41,7 @@ def model_eval(model, valenc, seqlen):
     for inps in valenc:
         lm_logits = model(**inps)['logits']
         shift_logits = lm_logits[:, :-1, :].contiguous()
-        shift_labels = inps['input_ids'][:, 1:].cuda()
+        shift_labels = inps['input_ids'][:, 1:].to(model.device)
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
         neg_log_likelihood = loss.float() * seqlen
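For reference, a minimal sketch of how the two patched functions in eval.py fit together after this change. Only the lines visible in the diff above are taken from the commit; the torch.no_grad wrapper, the nlls accumulation, and the final perplexity computation in model_eval are assumptions filled in for illustration.

```python
import torch
from torch import nn
from tqdm import tqdm


def create_validation_dataloader(data, seqlen, device):
    nsamples = data['input_ids'].numel() // seqlen
    val_dataloader = []
    for i in tqdm(range(nsamples)):
        # Move each validation batch to the requested device instead of
        # hard-coding .cuda(), so eval works on CPU or a non-default GPU.
        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].to(device)
        attention_mask = torch.ones_like(batch)
        val_dataloader.append({'input_ids': batch, 'attention_mask': attention_mask})
    return val_dataloader


@torch.no_grad()  # assumption: eval runs without gradients
def model_eval(model, valenc, seqlen):
    nlls = []  # assumption: per-batch negative log-likelihoods are accumulated
    for inps in valenc:
        lm_logits = model(**inps)['logits']
        shift_logits = lm_logits[:, :-1, :].contiguous()
        # Labels follow the model's device rather than assuming CUDA.
        shift_labels = inps['input_ids'][:, 1:].to(model.device)
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * seqlen
        nlls.append(neg_log_likelihood)
    # Assumption: perplexity over the whole validation set.
    ppl = torch.exp(torch.stack(nlls).sum() / (len(nlls) * seqlen))
    return ppl
```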
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/main.py
@@ -278,7 +278,7 @@ def main():
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0)
     val_data = get_wikitext2(
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0)
-    val_data = create_validation_dataloader(val_data, args.seqlen)
+    val_data = create_validation_dataloader(val_data, args.seqlen, model.device)
     print("Data loaded.")

     # Apply LN affine merging before inserting MHA layers
