Skip to content

Commit

Permalink
cleanup
Browse files: browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Oct 27, 2023
1 parent 3bc99b2 commit dd745f4
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
1 change: 0 additions & 1 deletion elk/extraction/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ def maybe_unsqueeze(v):
inputs_cuda = pytree_map(
lambda v: maybe_unsqueeze(v.to(device)), input_record
)
# TODO: have model kwargs so we don't have to duplicate kwargs at each row
outputs = model(**inputs_cuda, **model_kwargs)

if callable(closure):
Expand Down
7 changes: 6 additions & 1 deletion elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def load_prompts(
ds = assert_type(Dataset, ds_dict[split_name])
if "row_id" not in ds.column_names:
ds = ds.add_column("row_id", range(len(ds))) # type: ignore
else:
print("Found row_id column, using it as the prompt ID")
ds = ds.shuffle(seed=seed)

prompter, using_blank = get_prompter(ds_name, config_name, template_path)
Expand Down Expand Up @@ -135,7 +137,10 @@ def _convert_to_prompts(

for template in templates:
statement = template.apply(example)
prompt_counter[statement] += 1

choices = template.get_fixed_answer_choices_list()
choices = tuple(choices) if choices is not None else None
prompt_counter[(statement, choices)] += 1

if fewshot_iter is not None:
# Infinite iterator so we don't need to worry about StopIteration
Expand Down
3 changes: 2 additions & 1 deletion tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def map_fn(ex: dict) -> dict:
input_ids = tokenizer(ex["text"], add_special_tokens=True)["input_ids"]
out_record["input_ids"] = [input_ids + suffix_tokens] # type: ignore
answer_ids = [
tokenizer.encode(s, add_special_tokens=False)[0] for s in ["False", "True"]
tokenizer.encode(s, add_special_tokens=False)[0]
for s in [" False", " True"]
]
out_record["answer_ids"] = answer_ids
return out_record
Expand Down

0 comments on commit dd745f4

Please sign in to comment.