Skip to content

Commit

Permalink
add printing
Browse files (browse the repository at this point in the history)
  • Loading branch information
AlexTMallen committed Nov 1, 2023
1 parent 651d52a commit 309428b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
4 changes: 4 additions & 0 deletions elk/debug_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ def save_debug_log(datasets: list[DatasetDictWithName], out_dir: Path) -> None:
training issues.
"""

print(f"Saving debug log to {out_dir}/debug.log")
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s %(levelname)s:\n%(message)s",
filename=out_dir / "debug.log",
filemode="w",
)

if len(datasets) == 0:
logging.warning("No datasets found!")

for ds_name, ds in datasets:
logging.info(
"=========================================\n"
Expand Down
10 changes: 9 additions & 1 deletion elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def tokenize_dataset(
for example in prompt_ds:
num_variants = len(example["template_names"])

- # Check if we've yielded enough examples
+ # Check if we've appended enough examples
if len(out_records) >= max_examples * num_variants:
break

Expand Down Expand Up @@ -258,7 +258,15 @@ def tokenize_dataset(
# print an example text to stdout
if len(out_records) == 0:
print(f"Example text: {record_variants[0]['text']}")
neg_id, pos_id = record_variants[0]["answer_ids"]
print(f'\tneg choice token: "{tokenizer.decode(neg_id)}"')
print(f'\tpos choice token: "{tokenizer.decode(pos_id)}"')
out_records.extend(record_variants)
else:
print(
f"WARNING: reached end of dataset {ds_names[0]} before collecting "
f"{max_examples} examples (only got {len(out_records)})."
)

# transpose the list of dicts into a dict of lists
out_records = {k: [d[k] for d in out_records] for k in out_records[0]}
Expand Down
1 change: 1 addition & 0 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def load_prompts(
fewshot_iter = None

if label_column in ds.features and balance:
print(f"Balancing dataset by {label_column}")
ds = BalancedSampler(
ds.to_iterable_dataset(),
set(label_choices),
Expand Down

0 comments on commit 309428b

Please sign in to comment.