Skip to content

Commit

Permalink
Fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Dec 8, 2023
1 parent e728e2c commit 4070efa
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 12 deletions.
2 changes: 1 addition & 1 deletion spacy_llm/tasks/raw/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def make_shard_reducer() -> ShardReducer:
@registry.llm_tasks("spacy.Raw.v1")
def make_raw_task(
template: str = DEFAULT_RAW_TEMPLATE_V1,
field: str = "reply",
field: str = "llm_reply",
parse_responses: Optional[TaskResponseParser[RawTask]] = None,
prompt_example_type: Optional[Type[FewshotExample]] = None,
examples: ExamplesConfigType = None,
Expand Down
2 changes: 0 additions & 2 deletions spacy_llm/tasks/templates/raw.v1.jinja
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
Read the instructions provided after "Text:" and reply after "Reply:".
{# whitespace #}
{%- if prompt_examples -%}
Below are some examples (only use these as a guide):
{# whitespace #}
Expand Down
15 changes: 6 additions & 9 deletions spacy_llm/tests/tasks/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_raw_predict(cfg_string, request):
cfg = request.getfixturevalue(cfg_string)
orig_config = Config().from_str(cfg)
nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True)
assert nlp("What's the weather like?")._.reply
assert nlp("What's the weather like?")._.llm_reply


@pytest.mark.external
Expand All @@ -169,7 +169,7 @@ def test_raw_io(cfg_string, request):
nlp.to_disk(tmpdir)
nlp2 = spacy.load(tmpdir)
assert nlp2.pipe_names == ["llm"]
assert nlp2("I've watered the plants.")._.reply
assert nlp2("I've watered the plants.")._.llm_reply


def test_jinja_template_rendering_without_examples():
Expand All @@ -188,8 +188,6 @@ def test_jinja_template_rendering_without_examples():
assert (
prompt.strip()
== f"""
Read the instructions provided after "Text:" and reply after "Reply:".
Now follows the text you should read and reply to.
Text:
{text}
Expand Down Expand Up @@ -222,8 +220,6 @@ def test_jinja_template_rendering_with_examples(examples_path):
assert (
prompt.strip()
== f"""
Read the instructions provided after "Text:" and reply after "Reply:".
Now follows the text you should read and reply to.
Text:
{text}
Expand Down Expand Up @@ -254,17 +250,18 @@ def test_external_template_actually_loads():
@pytest.mark.parametrize("n_prompt_examples", [-1, 0, 1, 2])
def test_raw_init(noop_config, n_prompt_examples: int):
config = Config().from_str(noop_config)
nlp = assemble_from_config(config)
with pytest.warns(UserWarning, match="Task supports sharding"):
nlp = assemble_from_config(config)

examples = []
text = "How much wood would a woodchuck chuck if a woodchuck could chuck wood?"
gold_1 = nlp.make_doc(text)
gold_1._.reply = "Plenty"
gold_1._.llm_reply = "Plenty"
examples.append(Example(nlp.make_doc(text), gold_1))

text = "Who sells seashells by the seashore?"
gold_2 = nlp.make_doc(text)
gold_2._.reply = "Shelly"
gold_2._.llm_reply = "Shelly"
examples.append(Example(nlp.make_doc(text), gold_2))

_, llm = nlp.pipeline[0]
Expand Down

0 comments on commit 4070efa

Please sign in to comment.