Skip to content

Commit

Permalink
Add set_candidate_selector().
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Nov 13, 2023
1 parent f9991d5 commit 49b93da
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions spacy_llm/tasks/entity_linker/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type

import jinja2
from spacy import Language
from spacy import Language, Vocab
from spacy.pipeline import EntityLinker
from spacy.tokens import Doc, Span
from spacy.training import Example
Expand Down Expand Up @@ -69,9 +69,15 @@ def initialize(
n_prompt_examples=n_prompt_examples,
fetch_entity_info=self.fetch_entity_info,
)
self.set_candidate_selector(candidate_selector, nlp.vocab)

def set_candidate_selector(
self, candidate_selector: CandidateSelector, vocab: Vocab
) -> None:
"""Sets candidate selector instance."""
self._candidate_selector = candidate_selector
if isinstance(self._candidate_selector, InitializableCandidateSelector):
self._candidate_selector.initialize(nlp.vocab)
self._candidate_selector.initialize(vocab)

def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]:
environment = jinja2.Environment()
Expand Down

0 comments on commit 49b93da

Please sign in to comment.