Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support arbitrarily long docs #332

Merged
merged 56 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
3aba660
Add context length info. Refactor BuiltinTask and models to facilitat…
rmitsch Oct 17, 2023
5699773
Merge branch 'develop' into feat/inf-doc-len
rmitsch Oct 17, 2023
4213372
Add token count estimator plumbing.
rmitsch Oct 17, 2023
f440ca4
Add plumbing for mapper and reducer.
rmitsch Oct 17, 2023
e47f762
Add ShardMapper prototype.
rmitsch Oct 18, 2023
89a5510
Integrating mapping into prompt generation workflow.
rmitsch Oct 19, 2023
086dec9
Update response parsing and component to support sharding (WIP).
rmitsch Oct 20, 2023
23718fc
Fix shard & prompt flow.
rmitsch Oct 27, 2023
7ce670d
Fix shard & prompt flow.
rmitsch Oct 27, 2023
0d75ea8
Remove todo comments.
rmitsch Oct 27, 2023
9da7098
Fix Anthropic, Cohere, NoOp model tests.
rmitsch Oct 27, 2023
0cb9afd
Merge branch 'develop' into feat/inf-doc-len
rmitsch Oct 30, 2023
f368412
Fix test_llm_pipe().
rmitsch Oct 31, 2023
b1f111d
Fix type checking test.
rmitsch Nov 3, 2023
44a2787
Fix span parsing tests.
rmitsch Nov 3, 2023
6d8cdc7
Fix internal tests.
rmitsch Nov 3, 2023
e712f41
Fix _CountTask.
rmitsch Nov 3, 2023
985fd68
Fix sentiment and summarization tasks and tests.
rmitsch Nov 3, 2023
98842a2
Fix Azure connection URL. Fix Model test pings.
rmitsch Nov 3, 2023
b54a3d9
Fix Lemma parsing.
rmitsch Nov 3, 2023
9bf365d
Start work on doc-to-shard property copying.
rmitsch Nov 3, 2023
dddfaab
Fix REL doc preprocessing.
rmitsch Nov 6, 2023
3af21b5
Remove comment on doc attribute handling during sharding, as this is …
rmitsch Nov 6, 2023
fee9ca7
Add reducer implementations.
rmitsch Nov 8, 2023
e508499
Implement outstanding task reducers.
rmitsch Nov 14, 2023
3218541
Resolve merge conflicts.
rmitsch Nov 14, 2023
c104387
Add shardable/non-shardable LLM task typing distinction. Add support …
rmitsch Nov 20, 2023
2c6d899
Merge branch 'develop' into feat/inf-doc-len
rmitsch Nov 21, 2023
2502c4d
Fix EL task.
rmitsch Nov 23, 2023
03055c5
Fix EL tokenization and highlighting partially.
rmitsch Nov 23, 2023
4e4a2cd
Fix tokenization and whitespaces for EL task.
rmitsch Nov 24, 2023
865acec
Fix merge conflicts.
rmitsch Nov 24, 2023
694d5da
Add new registry handlers (with context length and arbitrary model na…
rmitsch Nov 24, 2023
5295400
Add sharding test with simple count task.
rmitsch Nov 24, 2023
70e3643
Fix sharding algorithm.
rmitsch Nov 24, 2023
4321483
Add test with simple count task.
rmitsch Nov 27, 2023
ef6e738
Add context length as init arg in HF models.
rmitsch Nov 27, 2023
e3ff37d
Fix tests. Don't stringify IO lists if sharded.
rmitsch Nov 28, 2023
056730a
Fix tests.
rmitsch Nov 29, 2023
196c235
Add NER sharding test.
rmitsch Nov 29, 2023
1f51a4a
Add REL and sentiment sharding tests.
rmitsch Nov 29, 2023
e18b302
Add summary sharding tests.
rmitsch Nov 29, 2023
7c092ca
Add EL sharding task. Fix bug in shard mapper.
rmitsch Nov 29, 2023
358ba72
Fix REL error with RELExample parsing.
rmitsch Nov 29, 2023
0c96fb6
Use regex for punctuation in REL conversion.
rmitsch Nov 29, 2023
dc926bd
Maintain custom doc attributes, incl. test.
rmitsch Dec 1, 2023
5585174
Filter merge warnings in textcat reduction.
rmitsch Dec 1, 2023
1ae710c
Fix custom doc data merging.
rmitsch Dec 4, 2023
e94b356
Update spacy_llm/models/langchain/model.py
rmitsch Dec 7, 2023
e68f5d3
Update spacy_llm/pipeline/llm.py
rmitsch Dec 7, 2023
f40bc88
Incorporate feedback.
rmitsch Dec 7, 2023
ac0559d
Move sharding compatibility warning to component constructor.
rmitsch Dec 7, 2023
1763821
Update spacy_llm/tasks/entity_linker/util.py
rmitsch Dec 7, 2023
ae2e837
Update spacy_llm/models/hf/base.py
rmitsch Dec 7, 2023
63367fa
Incorporate feedback.
rmitsch Dec 7, 2023
e2f3ad8
Fix doc string
svlandeg Dec 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions spacy_llm/models/hf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ def __init__(
name: str,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
"""Initializes HF model instance.
query (Callable[[Any, Iterable[Any]], Iterable[Any]): Callable executing LLM prompts when
supplied with the `integration` object.
name (str): Name of HF model to load (without account name).
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
inference_config (Dict[Any, Any]): HF config for model run.
context_length (Optional[int]): Context length for this model. Necessary for sharding.
"""
self._name = name if self.hf_account in name else f"{self.hf_account}/{name}"
self._context_length = context_length
default_cfg_init, default_cfg_run = self.compile_default_configs()
self._config_init, self._config_run = default_cfg_init, default_cfg_run

Expand Down Expand Up @@ -73,10 +75,10 @@ def __init__(
self._model = self.init_model()

@abc.abstractmethod
def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]:
def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
svlandeg marked this conversation as resolved.
Show resolved Hide resolved
"""Executes prompts on specified API.
prompts (Iterable[Any]): Prompts to execute.
RETURNS (Iterable[Any]): API responses.
prompts (Iterable[Iterable[Any]]): Prompts to execute per doc.
RETURNS (Iterable[Iterable[Any]]): API responses per doc.
"""

def _check_model(self) -> None:
Expand All @@ -93,6 +95,13 @@ def get_model_names(cls) -> Tuple[str, ...]:
"""
return tuple(str(arg) for arg in cls.MODEL_NAMES.__args__) # type: ignore[attr-defined]

@property
def context_length(self) -> Optional[int]:
"""Returns context length in number of tokens for this model.
RETURNS (Optional[int]): Max. number of tokens allowed in prompt for the current model.
"""
return self._context_length

@property
@abc.abstractmethod
def hf_account(self) -> str:
Expand Down
18 changes: 12 additions & 6 deletions spacy_llm/models/hf/dolly.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ def init_model(self) -> Any:
model=self._name, return_full_text=False, **self._config_init
)

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
"""Queries Dolly HF model.
pipeline (transformers.pipeline): Transformers pipeline to query.
prompts (Iterable[str]): Prompts to query Dolly model with.
RETURNS (Iterable[str]): Prompt responses.
prompts (Iterable[Iterable[str]]): Prompts per doc to query Dolly model with.
RETURNS (Iterable[Iterable[str]]): Prompt responses per doc.
"""
return [
self._model(pr, **self._config_run)[0]["generated_text"] for pr in prompts
[
self._model(pr, **self._config_run)[0]["generated_text"]
for pr in prompts_for_doc
]
for prompts_for_doc in prompts
]

@property
Expand All @@ -52,12 +56,14 @@ def dolly_hf(
name: Dolly.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Dolly instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Dolly model. Has to be one of Dolly.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Dolly instance that can execute a set of prompts and return
the raw responses.
"""
return Dolly(name=name, config_init=config_init, config_run=config_run)
return Dolly(
name=name, config_init=config_init, config_run=config_run, context_length=2048
)
25 changes: 19 additions & 6 deletions spacy_llm/models/hf/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ def __init__(
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)
self._config_run["pad_token_id"] = self._tokenizer.pad_token_id
Expand All @@ -45,10 +51,15 @@ def init_model(self) -> Any:
def hf_account(self) -> str:
return "tiiuae"

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
return [
self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"]
for pr in prompts
[
self._model(pr, generation_config=self._hf_config_run)[0][
"generated_text"
]
for pr in prompts_for_doc
]
for prompts_for_doc in prompts
]

@staticmethod
Expand All @@ -68,12 +79,14 @@ def falcon_hf(
name: Falcon.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Falcon instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return
the raw responses.
"""
return Falcon(name=name, config_init=config_init, config_run=config_run)
return Falcon(
name=name, config_init=config_init, config_run=config_run, context_length=2048
)
25 changes: 19 additions & 6 deletions spacy_llm/models/hf/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ def __init__(
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)
# Instantiate GenerationConfig object from config dict.
self._hf_config_run = transformers.GenerationConfig.from_pretrained(
self._name,
Expand All @@ -39,10 +45,15 @@ def init_model(self) -> Any:
def hf_account(self) -> str:
return "meta-llama"

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
return [
self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"]
for pr in prompts
[
self._model(pr, generation_config=self._hf_config_run)[0][
"generated_text"
]
for pr in prompts_for_doc
]
for prompts_for_doc in prompts
]

@staticmethod
Expand All @@ -55,12 +66,14 @@ def llama2_hf(
name: Llama2.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Llama 2 instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Llama 2 model. Has to be one of Llama2.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Llama2 instance that can execute a set of prompts and return
the raw responses.
"""
return Llama2(name=name, config_init=config_init, config_run=config_run)
return Llama2(
name=name, config_init=config_init, config_run=config_run, context_length=4096
)
65 changes: 41 additions & 24 deletions spacy_llm/models/hf/mistral.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional

from confection import SimpleFrozenDict

Expand All @@ -15,10 +15,16 @@ def __init__(
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
self._is_instruct = "instruct" in name
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)

Expand Down Expand Up @@ -48,43 +54,54 @@ def init_model(self) -> Any:
def hf_account(self) -> str:
return "mistralai"

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
assert callable(self._tokenizer)
assert hasattr(self._model, "generate")
assert hasattr(self._tokenizer, "batch_decode")
prompts = list(prompts)

tokenized_input_ids = [
self._tokenizer(
prompt if not self._is_instruct else f"<s>[INST] {prompt} [/INST]",
return_tensors="pt",
).input_ids
for prompt in prompts
]
tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids]

return [
self._tokenizer.decode(
self._model.generate(
input_ids=tok_ii, generation_config=self._hf_config_run
)[:, tok_ii.shape[1] :][0],
skip_special_tokens=True,
responses: List[List[str]] = []

for prompts_for_doc in prompts:
prompts_for_doc = list(prompts_for_doc)

tokenized_input_ids = [
self._tokenizer(
prompt if not self._is_instruct else f"<s>[INST] {prompt} [/INST]",
return_tensors="pt",
).input_ids
for prompt in prompts_for_doc
]
tokenized_input_ids = [
tp.to(self._model.device) for tp in tokenized_input_ids
]

responses.append(
[
self._tokenizer.decode(
self._model.generate(
input_ids=tok_ii, generation_config=self._hf_config_run
)[:, tok_ii.shape[1] :][0],
skip_special_tokens=True,
)
for tok_ii in tokenized_input_ids
]
)
for tok_ii in tokenized_input_ids
]

return responses


@registry.llm_models("spacy.Mistral.v1")
def mistral_hf(
name: Mistral.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates Mistral instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return
the raw responses.
"""
return Mistral(name=name, config_init=config_init, config_run=config_run)
return Mistral(
name=name, config_init=config_init, config_run=config_run, context_length=8000
)
56 changes: 36 additions & 20 deletions spacy_llm/models/hf/openllama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from confection import SimpleFrozenDict

Expand All @@ -20,9 +20,15 @@ def __init__(
name: str,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: Optional[int],
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
super().__init__(name=name, config_init=config_init, config_run=config_run)
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

def init_model(self) -> "transformers.AutoModelForCausalLM":
"""Sets up HF model and needed utilities.
Expand All @@ -43,24 +49,32 @@ def init_model(self) -> "transformers.AutoModelForCausalLM":

return model

def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override]
def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
assert callable(self._tokenizer)
tokenized_input_ids = [
self._tokenizer(prompt, return_tensors="pt").input_ids for prompt in prompts
]
tokenized_input_ids = [
tii.to(self._model.device) for tii in tokenized_input_ids
]

assert hasattr(self._model, "generate")
return [
self._tokenizer.decode(
self._model.generate(input_ids=tii, **self._config_run)[
:, tii.shape[1] :
][0],
responses: List[List[str]] = []

for prompts_for_doc in prompts:
tokenized_input_ids = [
self._tokenizer(prompt, return_tensors="pt").input_ids
for prompt in prompts_for_doc
]
tokenized_input_ids = [
tii.to(self._model.device) for tii in tokenized_input_ids
]

assert hasattr(self._model, "generate")
responses.append(
[
self._tokenizer.decode(
self._model.generate(input_ids=tii, **self._config_run)[
:, tii.shape[1] :
][0],
)
for tii in tokenized_input_ids
]
)
for tii in tokenized_input_ids
]

return responses

@property
def hf_account(self) -> str:
Expand All @@ -83,12 +97,14 @@ def openllama_hf(
name: OpenLLaMA.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[str]], Iterable[str]]:
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Generates OpenLLaMA instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the OpenLLaMA model. Has to be one of OpenLLaMA.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): OpenLLaMA instance that can execute a set of prompts and return
the raw responses.
"""
return OpenLLaMA(name=name, config_init=config_init, config_run=config_run)
return OpenLLaMA(
name=name, config_init=config_init, config_run=config_run, context_length=2048
)
Loading