Skip to content

Commit

Permalink
Support arbitrarily long docs (#332)
Browse files Browse the repository at this point in the history
* Add context length info. Refactor BuiltinTask and models to facilitate this.

* Add token count estimator plumbing.

* Add plumbing for mapper and reducer.

* Add ShardMapper prototype.

* Integrating mapping into prompt generation workflow.

* Update response parsing and component to support sharding (WIP).

* Fix shard & prompt flow.

* Fix shard & prompt flow.

* Remove todo comments.

* Fix Anthropic, Cohere, NoOp model tests.

* Fix test_llm_pipe().

* Fix type checking test.

* Fix span parsing tests.

* Fix internal tests.

* Fix _CountTask.

* Fix sentiment and summarization tasks and tests.

* Fix Azure connection URL. Fix Model test pings.

* Fix Lemma parsing.

* Start work on doc-to-shard property copying.

* Fix REL doc preprocessing.

* Remove comment on doc attribute handling during sharding, as this is done by spaCy's slicing directly.

* Add reducer implementations.

* Implement outstanding task reducers.

* Add shardable/non-shardable LLM task typing distinction. Add support for handling both types of tasks. Update tests.

* Fix EL task.

* Fix EL tokenization and highlighting partially.

* Fix tokenization and whitespaces for EL task.

* Add new registry handlers (with context length and arbitrary model names) for all REST models.

* Add sharding test with simple count task.

* Fix sharding algorithm.

* Add test with simple count task.

* Add context length as init arg in HF models.

* Fix tests. Don't stringify IO lists if sharded.

* Fix tests.

* Add NER sharding test.

* Add REL and sentiment sharding tests.

* Add summary sharding tests.

* Add EL sharding task. Fix bug in shard mapper.

* Fix REL error with RELExample parsing.

* Use regex for punctuation in REL conversion.

* Maintain custom doc attributes, incl. test.

* Filter merge warnings in textcat reduction.

* Fix custom doc data merging.

* Update spacy_llm/models/langchain/model.py

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Update spacy_llm/pipeline/llm.py

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Incorporate feedback.

* Move sharding compatibility warning to component constructor.

* Update spacy_llm/tasks/entity_linker/util.py

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Update spacy_llm/models/hf/base.py

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Incorporate feedback.

* Fix doc string

---------

Co-authored-by: Sofie Van Landeghem <[email protected]>
  • Loading branch information
rmitsch and svlandeg authored Dec 11, 2023
1 parent dbae0c9 commit a6515bf
Show file tree
Hide file tree
Showing 94 changed files with 3,441 additions and 1,113 deletions.
17 changes: 13 additions & 4 deletions spacy_llm/models/hf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ def __init__(
name: str,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
"""Initializes HF model instance.
query (Callable[[Any, Iterable[Any]], Iterable[Any]): Callable executing LLM prompts when
supplied with the `integration` object.
name (str): Name of HF model to load (without account name).
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
inference_config (Dict[Any, Any]): HF config for model run.
context_length (Optional[int]): Context length for this model. Necessary for sharding.
"""
self._name = name if self.hf_account in name else f"{self.hf_account}/{name}"
self._context_length = context_length
default_cfg_init, default_cfg_run = self.compile_default_configs()
self._config_init, self._config_run = default_cfg_init, default_cfg_run

Expand Down Expand Up @@ -73,10 +75,10 @@ def __init__(
self._model = self.init_model()

@abc.abstractmethod
def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]:
def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
"""Executes prompts on specified API.
prompts (Iterable[Any]): Prompts to execute.
RETURNS (Iterable[Any]): API responses.
prompts (Iterable[Iterable[Any]]): Prompts to execute per doc.
RETURNS (Iterable[Iterable[Any]]): API responses per doc.
"""

def _check_model(self) -> None:
Expand All @@ -93,6 +95,13 @@ def get_model_names(cls) -> Tuple[str, ...]:
"""
return tuple(str(arg) for arg in cls.MODEL_NAMES.__args__) # type: ignore[attr-defined]

@property
def context_length(self) -> Optional[int]:
"""Returns context length in number of tokens for this model.
RETURNS (Optional[int]): Max. number of tokens allowed in prompt for the current model.
"""
return self._context_length

@property
@abc.abstractmethod
def hf_account(self) -> str:
Expand Down
18 changes: 12 additions & 6 deletions spacy_llm/models/hf/dolly.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ def init_model(self) -> Any:
model=self._name, return_full_text=False, **self._config_init
)

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
"""Queries Dolly HF model.
pipeline (transformers.pipeline): Transformers pipeline to query.
prompts (Iterable[str]): Prompts to query Dolly model with.
RETURNS (Iterable[str]): Prompt responses.
prompts (Iterable[Iterable[str]]): Prompts per doc to query Dolly model with.
RETURNS (Iterable[Iterable[str]]): Prompt responses per doc.
"""
return [
self._model(pr, **self._config_run)[0]["generated_text"] for pr in prompts
[
self._model(pr, **self._config_run)[0]["generated_text"]
for pr in prompts_for_doc
]
for prompts_for_doc in prompts
]

@property
Expand All @@ -52,12 +56,14 @@ def dolly_hf(
name: Dolly.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Dolly instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Dolly model. Has to be one of Dolly.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Dolly instance that can execute a set of prompts and return
the raw responses.
"""
return Dolly(name=name, config_init=config_init, config_run=config_run)
return Dolly(
name=name, config_init=config_init, config_run=config_run, context_length=2048
)
25 changes: 19 additions & 6 deletions spacy_llm/models/hf/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ def __init__(
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)
self._config_run["pad_token_id"] = self._tokenizer.pad_token_id
Expand All @@ -45,10 +51,15 @@ def init_model(self) -> Any:
def hf_account(self) -> str:
return "tiiuae"

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
return [
self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"]
for pr in prompts
[
self._model(pr, generation_config=self._hf_config_run)[0][
"generated_text"
]
for pr in prompts_for_doc
]
for prompts_for_doc in prompts
]

@staticmethod
Expand All @@ -68,12 +79,14 @@ def falcon_hf(
name: Falcon.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Falcon instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return
the raw responses.
"""
return Falcon(name=name, config_init=config_init, config_run=config_run)
return Falcon(
name=name, config_init=config_init, config_run=config_run, context_length=2048
)
25 changes: 19 additions & 6 deletions spacy_llm/models/hf/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ def __init__(
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)
# Instantiate GenerationConfig object from config dict.
self._hf_config_run = transformers.GenerationConfig.from_pretrained(
self._name,
Expand All @@ -39,10 +45,15 @@ def init_model(self) -> Any:
def hf_account(self) -> str:
return "meta-llama"

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
return [
self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"]
for pr in prompts
[
self._model(pr, generation_config=self._hf_config_run)[0][
"generated_text"
]
for pr in prompts_for_doc
]
for prompts_for_doc in prompts
]

@staticmethod
Expand All @@ -55,12 +66,14 @@ def llama2_hf(
name: Llama2.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Llama 2 instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Llama 2 model. Has to be one of Llama2.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Llama2 instance that can execute a set of prompts and return
the raw responses.
"""
return Llama2(name=name, config_init=config_init, config_run=config_run)
return Llama2(
name=name, config_init=config_init, config_run=config_run, context_length=4096
)
65 changes: 41 additions & 24 deletions spacy_llm/models/hf/mistral.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional

from confection import SimpleFrozenDict

Expand All @@ -15,10 +15,16 @@ def __init__(
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
self._is_instruct = "instruct" in name
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)

Expand Down Expand Up @@ -48,43 +54,54 @@ def init_model(self) -> Any:
def hf_account(self) -> str:
return "mistralai"

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
assert callable(self._tokenizer)
assert hasattr(self._model, "generate")
assert hasattr(self._tokenizer, "batch_decode")
prompts = list(prompts)

tokenized_input_ids = [
self._tokenizer(
prompt if not self._is_instruct else f"<s>[INST] {prompt} [/INST]",
return_tensors="pt",
).input_ids
for prompt in prompts
]
tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids]

return [
self._tokenizer.decode(
self._model.generate(
input_ids=tok_ii, generation_config=self._hf_config_run
)[:, tok_ii.shape[1] :][0],
skip_special_tokens=True,
responses: List[List[str]] = []

for prompts_for_doc in prompts:
prompts_for_doc = list(prompts_for_doc)

tokenized_input_ids = [
self._tokenizer(
prompt if not self._is_instruct else f"<s>[INST] {prompt} [/INST]",
return_tensors="pt",
).input_ids
for prompt in prompts_for_doc
]
tokenized_input_ids = [
tp.to(self._model.device) for tp in tokenized_input_ids
]

responses.append(
[
self._tokenizer.decode(
self._model.generate(
input_ids=tok_ii, generation_config=self._hf_config_run
)[:, tok_ii.shape[1] :][0],
skip_special_tokens=True,
)
for tok_ii in tokenized_input_ids
]
)
for tok_ii in tokenized_input_ids
]

return responses


@registry.llm_models("spacy.Mistral.v1")
def mistral_hf(
name: Mistral.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Mistral instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return
the raw responses.
"""
return Mistral(name=name, config_init=config_init, config_run=config_run)
return Mistral(
name=name, config_init=config_init, config_run=config_run, context_length=8000
)
56 changes: 36 additions & 20 deletions spacy_llm/models/hf/openllama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from confection import SimpleFrozenDict

Expand All @@ -20,9 +20,15 @@ def __init__(
name: str,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

def init_model(self) -> "transformers.AutoModelForCausalLM":
"""Sets up HF model and needed utilities.
Expand All @@ -43,24 +49,32 @@ def init_model(self) -> "transformers.AutoModelForCausalLM":

return model

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
assert callable(self._tokenizer)
tokenized_input_ids = [
self._tokenizer(prompt, return_tensors="pt").input_ids for prompt in prompts
]
tokenized_input_ids = [
tii.to(self._model.device) for tii in tokenized_input_ids
]

assert hasattr(self._model, "generate")
return [
self._tokenizer.decode(
self._model.generate(input_ids=tii, **self._config_run)[
:, tii.shape[1] :
][0],
responses: List[List[str]] = []

for prompts_for_doc in prompts:
tokenized_input_ids = [
self._tokenizer(prompt, return_tensors="pt").input_ids
for prompt in prompts_for_doc
]
tokenized_input_ids = [
tii.to(self._model.device) for tii in tokenized_input_ids
]

assert hasattr(self._model, "generate")
responses.append(
[
self._tokenizer.decode(
self._model.generate(input_ids=tii, **self._config_run)[
:, tii.shape[1] :
][0],
)
for tii in tokenized_input_ids
]
)
for tii in tokenized_input_ids
]

return responses

@property
def hf_account(self) -> str:
Expand All @@ -83,12 +97,14 @@ def openllama_hf(
name: OpenLLaMA.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates OpenLLaMA instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the OpenLLaMA model. Has to be one of OpenLLaMA.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): OpenLLaMA instance that can execute a set of prompts and return
the raw responses.
"""
return OpenLLaMA(name=name, config_init=config_init, config_run=config_run)
return OpenLLaMA(
name=name, config_init=config_init, config_run=config_run, context_length=2048
)
Loading

0 comments on commit a6515bf

Please sign in to comment.