Skip to content

Commit

Permalink
Fix Lemma parsing.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Nov 3, 2023
1 parent 98842a2 commit b54a3d9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
21 changes: 14 additions & 7 deletions spacy_llm/tasks/lemma/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,30 @@

def parse_responses_v1(
task: LemmaTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]]
) -> Iterable[Iterable[List[List[str]]]]:
) -> Iterable[List[List[List[str]]]]:
"""Parses LLM responses for spacy.Lemma.v1.
task (LemmaTask): Task instance.
shards (Iterable[Iterable[Doc]]): Doc shards.
responses (Iterable[Iterable[str]]): LLM responses.
RETURNS (Iterable[List[str]]): Lists of 2-lists (token: lemmatized token) per doc/response.
RETURNS (Iterable[List[List[List[str]]]]): Lists of 2-lists per token (token: lemmatized token) and shard/response
and doc.
"""
for responses_for_doc in responses:
results_for_doc: List[List[List[str]]] = []
for response in responses_for_doc:
results_for_shard = [
[pr_part.strip() for pr_part in pr.split(":")]
for pr in response.replace("Lemmatized text:", "")
.replace("'''", "")
.strip()
.split("\n")
]
results_for_doc.append(
# Malformed responses might have a length != 2, in which case they are discarded.
[
[pr_part.strip() for pr_part in pr.split(":")]
for pr in response.replace("Lemmatized text:", "")
.replace("'''", "")
.strip()
.split("\n")
result_for_token
for result_for_token in results_for_shard
if len(result_for_token) == 2
]
)

Expand Down
1 change: 0 additions & 1 deletion spacy_llm/tasks/lemma/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def parse_responses(
self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]]
) -> Iterable[Doc]:
shards_teed = tee(shards, 2)

for shards_for_doc, lemmas_for_doc in zip(
shards_teed[0], self._parse_responses(self, shards_teed[1], responses)
):
Expand Down
4 changes: 2 additions & 2 deletions spacy_llm/tests/tasks/test_lemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def test_lemma_config(cfg_string, request):
@pytest.mark.parametrize(
"cfg_string",
[
"zeroshot_cfg_string",
"fewshot_cfg_string",
# "zeroshot_cfg_string",
# "fewshot_cfg_string",
"ext_template_cfg_string",
],
)
Expand Down

0 comments on commit b54a3d9

Please sign in to comment.