diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 87209118..b8f8b7b7 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -17,6 +17,7 @@ def __init__( name: str, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: Optional[int], ): """Initializes HF model instance. query (Callable[[Any, Iterable[Any]], Iterable[Any]): Callable executing LLM prompts when @@ -24,9 +25,10 @@ def __init__( name (str): Name of HF model to load (without account name). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. config_run (Optional[Dict[str, Any]]): HF config for running the model. - inference_config (Dict[Any, Any]): HF config for model run. + context_length (Optional[int]): Context length for this model. Necessary for sharding. """ self._name = name if self.hf_account in name else f"{self.hf_account}/{name}" + self._context_length = context_length default_cfg_init, default_cfg_run = self.compile_default_configs() self._config_init, self._config_run = default_cfg_init, default_cfg_run @@ -73,10 +75,10 @@ def __init__( self._model = self.init_model() @abc.abstractmethod - def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]: + def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]: """Executes prompts on specified API. - prompts (Iterable[Any]): Prompts to execute. - RETURNS (Iterable[Any]): API responses. + prompts (Iterable[Iterable[Any]]): Prompts to execute per doc. + RETURNS (Iterable[Iterable[Any]]): API responses per doc. """ def _check_model(self) -> None: @@ -93,6 +95,13 @@ def get_model_names(cls) -> Tuple[str, ...]: """ return tuple(str(arg) for arg in cls.MODEL_NAMES.__args__) # type: ignore[attr-defined] + @property + def context_length(self) -> Optional[int]: + """Returns context length in number of tokens for this model. + RETURNS (Optional[int]): Max. number of tokens allowed in prompt for the current model. + """ + return self._context_length + @property @abc.abstractmethod def hf_account(self) -> str: diff --git a/spacy_llm/models/hf/dolly.py b/spacy_llm/models/hf/dolly.py index 849f34bd..95b2bc9a 100644 --- a/spacy_llm/models/hf/dolly.py +++ b/spacy_llm/models/hf/dolly.py @@ -18,14 +18,18 @@ def init_model(self) -> Any: model=self._name, return_full_text=False, **self._config_init ) - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] """Queries Dolly HF model. pipeline (transformers.pipeline): Transformers pipeline to query. - prompts (Iterable[str]): Prompts to query Dolly model with. - RETURNS (Iterable[str]): Prompt responses. + prompts (Iterable[Iterable[str]]): Prompts per doc to query Dolly model with. + RETURNS (Iterable[Iterable[str]]): Prompt responses per doc. """ return [ - self._model(pr, **self._config_run)[0]["generated_text"] for pr in prompts + [ + self._model(pr, **self._config_run)[0]["generated_text"] + for pr in prompts_for_doc + ] + for prompts_for_doc in prompts ] @property @@ -52,7 +56,7 @@ def dolly_hf( name: Dolly.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Dolly instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Dolly model. Has to be one of Dolly.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. @@ -60,4 +64,6 @@ def dolly_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Dolly instance that can execute a set of prompts and return the raw responses. """ - return Dolly(name=name, config_init=config_init, config_run=config_run) + return Dolly( + name=name, config_init=config_init, config_run=config_run, context_length=2048 + ) diff --git a/spacy_llm/models/hf/falcon.py b/spacy_llm/models/hf/falcon.py index 2e18ac9d..68e05726 100644 --- a/spacy_llm/models/hf/falcon.py +++ b/spacy_llm/models/hf/falcon.py @@ -17,9 +17,15 @@ def __init__( name: MODEL_NAMES, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: Optional[int], ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) self._config_run["pad_token_id"] = self._tokenizer.pad_token_id @@ -45,10 +51,15 @@ def init_model(self) -> Any: def hf_account(self) -> str: return "tiiuae" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] return [ - self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"] - for pr in prompts + [ + self._model(pr, generation_config=self._hf_config_run)[0][ + "generated_text" + ] + for pr in prompts_for_doc + ] + for prompts_for_doc in prompts ] @staticmethod @@ -68,7 +79,7 @@ def falcon_hf( name: Falcon.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Falcon instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. @@ -76,4 +87,6 @@ def falcon_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return the raw responses. """ - return Falcon(name=name, config_init=config_init, config_run=config_run) + return Falcon( + name=name, config_init=config_init, config_run=config_run, context_length=2048 + ) diff --git a/spacy_llm/models/hf/llama2.py b/spacy_llm/models/hf/llama2.py index f03d00ee..eab32ceb 100644 --- a/spacy_llm/models/hf/llama2.py +++ b/spacy_llm/models/hf/llama2.py @@ -17,8 +17,14 @@ def __init__( name: MODEL_NAMES, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: Optional[int], ): - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) # Instantiate GenerationConfig object from config dict. self._hf_config_run = transformers.GenerationConfig.from_pretrained( self._name, @@ -39,10 +45,15 @@ def init_model(self) -> Any: def hf_account(self) -> str: return "meta-llama" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] return [ - self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"] - for pr in prompts + [ + self._model(pr, generation_config=self._hf_config_run)[0][ + "generated_text" + ] + for pr in prompts_for_doc + ] + for prompts_for_doc in prompts ] @staticmethod @@ -55,7 +66,7 @@ def llama2_hf( name: Llama2.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Llama 2 instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Llama 2 model. Has to be one of Llama2.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. @@ -63,4 +74,6 @@ def llama2_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Llama2 instance that can execute a set of prompts and return the raw responses. """ - return Llama2(name=name, config_init=config_init, config_run=config_run) + return Llama2( + name=name, config_init=config_init, config_run=config_run, context_length=4096 + ) diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py index 56ae7be3..3c5039a2 100644 --- a/spacy_llm/models/hf/mistral.py +++ b/spacy_llm/models/hf/mistral.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional from confection import SimpleFrozenDict @@ -15,10 +15,16 @@ def __init__( name: MODEL_NAMES, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: Optional[int], ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None self._is_instruct = "instruct" in name - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) @@ -48,30 +54,39 @@ def init_model(self) -> Any: def hf_account(self) -> str: return "mistralai" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] assert callable(self._tokenizer) assert hasattr(self._model, "generate") assert hasattr(self._tokenizer, "batch_decode") - prompts = list(prompts) - - tokenized_input_ids = [ - self._tokenizer( - prompt if not self._is_instruct else f"[INST] {prompt} [/INST]", - return_tensors="pt", - ).input_ids - for prompt in prompts - ] - tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids] - - return [ - self._tokenizer.decode( - self._model.generate( - input_ids=tok_ii, generation_config=self._hf_config_run - )[:, tok_ii.shape[1] :][0], - skip_special_tokens=True, + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + prompts_for_doc = list(prompts_for_doc) + + tokenized_input_ids = [ + self._tokenizer( + prompt if not self._is_instruct else f"[INST] {prompt} [/INST]", + return_tensors="pt", + ).input_ids + for prompt in prompts_for_doc + ] + tokenized_input_ids = [ + tp.to(self._model.device) for tp in tokenized_input_ids + ] + + responses.append( + [ + self._tokenizer.decode( + self._model.generate( + input_ids=tok_ii, generation_config=self._hf_config_run + )[:, tok_ii.shape[1] :][0], + skip_special_tokens=True, + ) + for tok_ii in tokenized_input_ids + ] ) - for tok_ii in tokenized_input_ids - ] + + return responses @registry.llm_models("spacy.Mistral.v1") @@ -79,7 +94,7 @@ def mistral_hf( name: Mistral.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Mistral instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. @@ -87,4 +102,6 @@ def mistral_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return the raw responses. """ - return Mistral(name=name, config_init=config_init, config_run=config_run) + return Mistral( + name=name, config_init=config_init, config_run=config_run, context_length=8000 + ) diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index 8ceb5bbc..c18c46e8 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from confection import SimpleFrozenDict @@ -20,9 +20,15 @@ def __init__( name: str, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: Optional[int], ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) def init_model(self) -> "transformers.AutoModelForCausalLM": """Sets up HF model and needed utilities. @@ -43,24 +49,32 @@ def init_model(self) -> "transformers.AutoModelForCausalLM": return model - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] assert callable(self._tokenizer) - tokenized_input_ids = [ - self._tokenizer(prompt, return_tensors="pt").input_ids for prompt in prompts - ] - tokenized_input_ids = [ - tii.to(self._model.device) for tii in tokenized_input_ids - ] - - assert hasattr(self._model, "generate") - return [ - self._tokenizer.decode( - self._model.generate(input_ids=tii, **self._config_run)[ - :, tii.shape[1] : - ][0], + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + tokenized_input_ids = [ + self._tokenizer(prompt, return_tensors="pt").input_ids + for prompt in prompts_for_doc + ] + tokenized_input_ids = [ + tii.to(self._model.device) for tii in tokenized_input_ids + ] + + assert hasattr(self._model, "generate") + responses.append( + [ + self._tokenizer.decode( + self._model.generate(input_ids=tii, **self._config_run)[ + :, tii.shape[1] : + ][0], + ) + for tii in tokenized_input_ids + ] ) - for tii in tokenized_input_ids - ] + + return responses @property def hf_account(self) -> str: @@ -83,7 +97,7 @@ def openllama_hf( name: OpenLLaMA.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates OpenLLaMA instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the OpenLLaMA model. Has to be one of OpenLLaMA.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. @@ -91,4 +105,6 @@ def openllama_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): OpenLLaMA instance that can execute a set of prompts and return the raw responses. """ - return OpenLLaMA(name=name, config_init=config_init, config_run=config_run) + return OpenLLaMA( + name=name, config_init=config_init, config_run=config_run, context_length=2048 + ) diff --git a/spacy_llm/models/hf/stablelm.py b/spacy_llm/models/hf/stablelm.py index 34698e0e..5b0d29b7 100644 --- a/spacy_llm/models/hf/stablelm.py +++ b/spacy_llm/models/hf/stablelm.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from confection import SimpleFrozenDict @@ -39,10 +39,16 @@ def __init__( name: str, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: Optional[int], ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None self._is_tuned = "tuned" in name - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) def init_model(self) -> "transformers.AutoModelForCausalLM": """Sets up HF model and needed utilities. @@ -66,32 +72,41 @@ def init_model(self) -> "transformers.AutoModelForCausalLM": def hf_account(self) -> str: return "stabilityai" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] assert callable(self._tokenizer) - tokenized_input_ids = [ - self._tokenizer(prompt, return_tensors="pt").input_ids - for prompt in ( - # Add prompt formatting for tuned model. - prompts - if not self._is_tuned - else [ - f"{StableLM._SYSTEM_PROMPT}<|USER|>{prompt}<|ASSISTANT|>" - for prompt in prompts + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + tokenized_input_ids = [ + self._tokenizer(prompt, return_tensors="pt").input_ids + for prompt in ( + # Add prompt formatting for tuned model. + prompts_for_doc + if not self._is_tuned + else [ + f"{StableLM._SYSTEM_PROMPT}<|USER|>{prompt}<|ASSISTANT|>" + for prompt in prompts_for_doc + ] + ) + ] + tokenized_input_ids = [ + tp.to(self._model.device) for tp in tokenized_input_ids + ] + + assert hasattr(self._model, "generate") + responses.append( + [ + self._tokenizer.decode( + self._model.generate(input_ids=tii, **self._config_run)[ + :, tii.shape[1] : + ][0], + skip_special_tokens=True, + ) + for tii in tokenized_input_ids ] ) - ] - tokenized_input_ids = [tp.to(self._model.device) for tp in tokenized_input_ids] - - assert hasattr(self._model, "generate") - return [ - self._tokenizer.decode( - self._model.generate(input_ids=tii, **self._config_run)[ - :, tii.shape[1] : - ][0], - skip_special_tokens=True, - ) - for tii in tokenized_input_ids - ] + + return responses @staticmethod def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -112,7 +127,7 @@ def stablelm_hf( name: StableLM.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates StableLM instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the StableLM model. Has to be one of StableLM.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. @@ -125,7 +140,5 @@ def stablelm_hf( f"Expected one of {StableLM.get_model_names()}, but received {name}." ) return StableLM( - name=name, - config_init=config_init, - config_run=config_run, + name=name, config_init=config_init, config_run=config_run, context_length=4096 ) diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py index 2e4be55f..e92654fb 100644 --- a/spacy_llm/models/langchain/model.py +++ b/spacy_llm/models/langchain/model.py @@ -17,17 +17,23 @@ def __init__( name: str, api: str, config: Dict[Any, Any], - query: Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]], + query: Callable[ + ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]] + ], + context_length: Optional[int], ): """Initializes model instance for integration APIs. name (str): Name of LangChain model to instantiate. api (str): Name of class/API. config (Dict[Any, Any]): Config passed on to LangChain model. - query (Callable[[langchain.llms.BaseLLM, Iterable[Any]], Iterable[Any]]): Callable executing LLM prompts when - supplied with the `integration` object. + query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing + LLM prompts when supplied with the model instance. + context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context + length provided, prompts can't be sharded. """ self._langchain_model = LangChain._init_langchain_model(name, api, config) self.query = query + self._context_length = context_length self._check_installation() @classmethod @@ -73,23 +79,24 @@ def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]: """ return getattr(langchain.llms, "type_to_cls_dict") - def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]: + def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]: """Executes prompts on specified API. - prompts (Iterable[Any]): Prompts to execute. - RETURNS (Iterable[Any]): API responses. + prompts (Iterable[Iterable[Any]]): Prompts to execute. + RETURNS (Iterable[Iterable[Any]]): API responses. """ return self.query(self._langchain_model, prompts) @staticmethod def query_langchain( - model: "langchain.llms.BaseLLM", prompts: Iterable[Any] - ) -> Iterable[Any]: + model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]] + ) -> Iterable[Iterable[Any]]: """Query LangChain model naively. model (langchain.llms.BaseLLM): LangChain model. - prompts (Iterable[Any]): Prompts to execute. - RETURNS (Iterable[Any]): LLM responses. + prompts (Iterable[Iterable[Any]]): Prompts to execute. + RETURNS (Iterable[Iterable[Any]]): LLM responses. """ - return [model(pr) for pr in prompts] + assert callable(model) + return [[model(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts] @staticmethod def _check_installation() -> None: @@ -105,17 +112,22 @@ def _langchain_model_maker(class_id: str): def langchain_model( name: str, query: Optional[ - Callable[["langchain.llms.BaseLLM", Iterable[str]], Iterable[str]] + Callable[ + ["langchain.llms.BaseLLM", Iterable[Iterable[str]]], + Iterable[Iterable[str]], + ] ] = None, config: Dict[Any, Any] = SimpleFrozenDict(), + context_length: Optional[int] = None, langchain_class_id: str = class_id, - ) -> Optional[Callable[[Iterable[Any]], Iterable[Any]]]: + ) -> Optional[Callable[[Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]]: try: return LangChain( name=name, api=langchain_class_id, config=config, query=query_langchain() if query is None else query, + context_length=context_length, ) except ImportError as err: raise ValueError( @@ -125,6 +137,13 @@ def langchain_model( return langchain_model + @property + def context_length(self) -> Optional[int]: + """Returns context length in number of tokens for this model. + RETURNS (Optional[int]): Max. number of tokens in allowed in prompt for the current model. None if unknown. + """ + return self._context_length + @staticmethod def register_models() -> None: """Registers APIs supported by langchain (one API is registered as one model). @@ -148,10 +167,12 @@ def register_models() -> None: @registry.llm_queries("spacy.CallLangChain.v1") def query_langchain() -> ( - Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]] + Callable[ + ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]] + ] ): """Returns query Callable for LangChain. - RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]:): Callable executing simple prompts on - the specified LangChain model. + RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing + simple prompts on the specified LangChain model. """ return LangChain.query_langchain diff --git a/spacy_llm/models/rest/anthropic/__init__.py b/spacy_llm/models/rest/anthropic/__init__.py index 745a0fbe..ca6c99b8 100644 --- a/spacy_llm/models/rest/anthropic/__init__.py +++ b/spacy_llm/models/rest/anthropic/__init__.py @@ -1,15 +1,26 @@ from .model import Anthropic, Endpoints -from .registry import anthropic_claude_1, anthropic_claude_1_0, anthropic_claude_1_2 -from .registry import anthropic_claude_1_3, anthropic_claude_instant_1 -from .registry import anthropic_claude_instant_1_1 +from .registry import anthropic_claude_1, anthropic_claude_1_0, anthropic_claude_1_0_v2 +from .registry import anthropic_claude_1_2, anthropic_claude_1_2_v2 +from .registry import anthropic_claude_1_3, anthropic_claude_1_3_v2 +from .registry import anthropic_claude_1_v2, anthropic_claude_2, anthropic_claude_2_v2 +from .registry import anthropic_claude_instant_1, anthropic_claude_instant_1_1 +from .registry import anthropic_claude_instant_1_1_v2, anthropic_claude_instant_1_v2 __all__ = [ "Anthropic", "Endpoints", "anthropic_claude_1", + "anthropic_claude_1_v2", "anthropic_claude_1_0", + "anthropic_claude_1_0_v2", "anthropic_claude_1_2", + "anthropic_claude_1_2_v2", "anthropic_claude_1_3", + "anthropic_claude_1_3_v2", "anthropic_claude_instant_1", + "anthropic_claude_instant_1_v2", "anthropic_claude_instant_1_1", + "anthropic_claude_instant_1_1_v2", + "anthropic_claude_2", + "anthropic_claude_2_v2", ] diff --git a/spacy_llm/models/rest/anthropic/model.py b/spacy_llm/models/rest/anthropic/model.py index 774d0e83..269b3209 100644 --- a/spacy_llm/models/rest/anthropic/model.py +++ b/spacy_llm/models/rest/anthropic/model.py @@ -40,7 +40,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: # Execute a dummy prompt. If the API setup is incorrect, we should fail at initialization time. try: - self(["test"]) + self([["test"]]) except ValueError as err: if "authentication_error" in str(err): warnings.warn( @@ -50,60 +50,91 @@ def _verify_auth(self) -> None: else: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "model": self._name, "anthropic-version": self._config.get("anthropic-version", "2023-06-01"), "Content-Type": "application/json", } + all_api_responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self._endpoint, + headers=headers, + json={**json_data, **self._config, "model": self._name}, + timeout=self._max_request_time, + ) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + error = res_content.get("error", {}) + error_msg = f"Request to Anthropic API failed: {error}" + if error["type"] == "not_found_error": + error_msg += f". Ensure that the selected model ({self._name}) is supported by the API." + raise ValueError(error_msg) from ex + response = r.json() + + # c.f. https://console.anthropic.com/docs/api/errors + if "error" in response: + if self._strict: + raise ValueError(f"API call failed: {response}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(response)] * len(prompts_for_doc) + } + return response + + # Anthropic API currently doesn't accept batch prompts, so we're making + # a request for each iteration. This approach can be prone to rate limit + # errors. In practice, you can adjust _max_request_time so that the + # timeout is larger. + responses = [ + _request( + {"prompt": f"{SystemPrompt.HUMAN} {prompt}{SystemPrompt.ASST}"} + ) + for prompt in prompts_for_doc + ] - api_responses: List[str] = [] - prompts = list(prompts) - - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=self._endpoint, - headers=headers, - json={**json_data, **self._config, "model": self._name}, - timeout=self._max_request_time, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - error = res_content.get("error", {}) - error_msg = f"Request to Anthropic API failed: {error}" - if error["type"] == "not_found_error": - error_msg += f". Ensure that the selected model ({self._name}) is supported by the API." - raise ValueError(error_msg) from ex - response = r.json() - - # c.f. https://console.anthropic.com/docs/api/errors - if "error" in response: - if self._strict: - raise ValueError(f"API call failed: {response}.") + for response in responses: + if "completion" in response: + api_responses.append(response["completion"]) else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(response)] * len(prompts)} - return response - - # Anthropic API currently doesn't accept batch prompts, so we're making - # a request for each iteration. This approach can be prone to rate limit - # errors. In practice, you can adjust _max_request_time so that the - # timeout is larger. - responses = [ - _request({"prompt": f"{SystemPrompt.HUMAN} {prompt}{SystemPrompt.ASST}"}) - for prompt in prompts - ] - - for response in responses: - if "completion" in response: - api_responses.append(response["completion"]) - else: - api_responses.append(srsly.json_dumps(response)) - - assert len(api_responses) == len(prompts) - return api_responses + api_responses.append(srsly.json_dumps(response)) + + assert len(api_responses) == len(prompts_for_doc) + all_api_responses.append(api_responses) + + return all_api_responses + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + # claude-2 + "claude-2": 100000, + "claude-2-100k": 100000, + # claude-1 + "claude-1": 100000, + "claude-1-100k": 100000, + # claude-instant-1 + "claude-instant-1": 100000, + "claude-instant-1-100k": 100000, + # claude-instant-1.1 + "claude-instant-1.1": 100000, + "claude-instant-1.1-100k": 100000, + # claude-1.3 + "claude-1.3": 100000, + "claude-1.3-100k": 100000, + # others + "claude-1.0": 100000, + "claude-1.2": 100000, + } diff --git a/spacy_llm/models/rest/anthropic/registry.py b/spacy_llm/models/rest/anthropic/registry.py index 504da15a..dc44eb7e 100644 --- a/spacy_llm/models/rest/anthropic/registry.py +++ b/spacy_llm/models/rest/anthropic/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -7,6 +7,43 @@ from .model import Anthropic, Endpoints +@registry.llm_models("spacy.Claude-2.v2") +def anthropic_claude_2_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-2", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Anthropic: + """Returns Anthropic instance for 'claude-2' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e.g. "claude-2" or "claude-2-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-2' model. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-2.v1") def anthropic_claude_2( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -15,7 +52,7 @@ def anthropic_claude_2( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-2' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-2", "claude-2-100k"]): Model to use. @@ -27,8 +64,7 @@ def anthropic_claude_2( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-1'. """ return Anthropic( name=name, @@ -38,6 +74,44 @@ def anthropic_claude_2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, + ) + + +@registry.llm_models("spacy.Claude-1.v2") +def anthropic_claude_1_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-1" or "claude-1-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, ) @@ -49,7 +123,7 @@ def anthropic_claude_1( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1", "claude-1-100k"]): Model to use. @@ -61,8 +135,7 @@ def anthropic_claude_1( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-1'. """ return Anthropic( name=name, @@ -72,6 +145,44 @@ def anthropic_claude_1( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, + ) + + +@registry.llm_models("spacy.Claude-instant-1.v2") +def anthropic_claude_instant_1_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-instant-1", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-instant-1' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-instant-1" or "claude-instant-1-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, ) @@ -85,7 +196,7 @@ def anthropic_claude_instant_1( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-instant-1' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-instant-1", "claude-instant-1-100k"]): Model to use. @@ -97,8 +208,7 @@ def anthropic_claude_instant_1( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-instant-1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1'. """ return Anthropic( name=name, @@ -108,6 +218,44 @@ def anthropic_claude_instant_1( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, + ) + + +@registry.llm_models("spacy.Claude-instant-1-1.v2") +def anthropic_claude_instant_1_1_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-instant-1.1", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-instant-1.1' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-instant-1.1" or "claude-instant-1.1-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1.1'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, ) @@ -121,7 +269,7 @@ def anthropic_claude_instant_1_1( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-instant-1.1' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-instant-1.1", "claude-instant-1.1-100k"]): Model to use. @@ -133,8 +281,7 @@ def anthropic_claude_instant_1_1( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-instant-1.1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1.1' model. """ return Anthropic( name=name, @@ -144,6 +291,44 @@ def anthropic_claude_instant_1_1( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, + ) + + +@registry.llm_models("spacy.Claude-1-0.v2") +def anthropic_claude_1_0_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1.0", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1.0' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-1.0". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1.0'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, ) @@ -155,7 +340,7 @@ def anthropic_claude_1_0( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1.0' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1.0"]): Model to use. @@ -167,8 +352,44 @@ def anthropic_claude_1_0( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1.0' model using REST to prompt - API. + RETURNS (Anthropic): Anthropic instance for 'claude-1.0' model. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=None, + ) + + +@registry.llm_models("spacy.Claude-1-2.v2") +def anthropic_claude_1_2_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1.2", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1.2' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-1.2". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1.2'. """ return Anthropic( name=name, @@ -178,6 +399,7 @@ def anthropic_claude_1_0( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @@ -189,7 +411,7 @@ def anthropic_claude_1_2( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1.2' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1.2"]): Model to use. @@ -201,8 +423,44 @@ def anthropic_claude_1_2( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1.2' model using REST to prompt - API. + RETURNS (Anthropic): Anthropic instance for 'claude-1.2' model. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=None, + ) + + +@registry.llm_models("spacy.Claude-1-3.v2") +def anthropic_claude_1_3_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1.3", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1.3' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model variant to use, e. g. "claude-1.3" or "claude-1.3-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1.3' model. """ return Anthropic( name=name, @@ -212,6 +470,7 @@ def anthropic_claude_1_2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @@ -223,7 +482,7 @@ def anthropic_claude_1_3( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1.3' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1.3", "claude-1.3-100k"]): Model variant to use. @@ -235,8 +494,7 @@ def anthropic_claude_1_3( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1.3' model using REST to prompt - API. + RETURNS (Anthropic): Anthropic instance for 'claude-1.3' model. """ return Anthropic( name=name, @@ -246,4 +504,5 @@ def anthropic_claude_1_3( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/models/rest/azure/__init__.py b/spacy_llm/models/rest/azure/__init__.py index 142972a5..f59e8679 100644 --- a/spacy_llm/models/rest/azure/__init__.py +++ b/spacy_llm/models/rest/azure/__init__.py @@ -1,4 +1,4 @@ from .model import AzureOpenAI -from .registry import azure_openai +from .registry import azure_openai, azure_openai_v2 -__all__ = ["AzureOpenAI", "azure_openai"] +__all__ = ["AzureOpenAI", "azure_openai", "azure_openai_v2"] diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index d8f433d6..32adc0bb 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized +from typing import Any, Dict, Iterable, List, Optional, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -18,6 +18,7 @@ class ModelType(str, Enum): class AzureOpenAI(REST): def __init__( self, + deployment_name: str, name: str, endpoint: str, config: Dict[Any, Any], @@ -26,10 +27,12 @@ def __init__( interval: float, max_request_time: float, model_type: ModelType, + context_length: Optional[int], api_version: str = "2023-05-15", ): self._model_type = model_type self._api_version = api_version + self._deployment_name = deployment_name super().__init__( name=name, endpoint=endpoint, @@ -38,6 +41,7 @@ def __init__( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @property @@ -48,7 +52,7 @@ def endpoint(self) -> str: return ( self._endpoint + ("" if self._endpoint.endswith("/") else "/") - + f"openai/deployments/{self._name}/{'' if self._model_type == ModelType.COMPLETION else 'chat/'}" + + f"openai/deployments/{self._deployment_name}/{'' if self._model_type == ModelType.COMPLETION else 'chat/'}" f"completions" ) @@ -71,79 +75,108 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: try: - self(["test"]) + self([["test"]]) except ValueError as err: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "Content-Type": "application/json", } - api_responses: List[str] = [] - prompts = list(prompts) - - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=self.endpoint, - headers=headers, - json={**json_data, **self._config}, - timeout=self._max_request_time, - params={"api-version": self._api_version}, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - raise ValueError( - f"Request to Azure OpenAI API failed: " - f"{res_content.get('error', {}).get('message', str(res_content))}" - ) from ex - responses = r.json() - - # todo check if this is the same - if "error" in responses: - if self._strict: - raise ValueError(f"API call failed: {responses}.") - else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(responses)] * len(prompts)} - - return responses - - # The (Azure) OpenAI API doesn't support batching yet, so we have to send individual requests. - # https://learn.microsoft.com/en-us/answers/questions/1334800/batching-requests-in-azure-openai - - if self._model_type == ModelType.CHAT: - # Note: this is yet (2023-10-05) untested, as Azure doesn't seem to allow the deployment of a chat model - # yet. - for prompt in prompts: - responses = _request( - {"messages": [{"role": "user", "content": prompt}]} - ) - if "error" in responses: - return responses["error"] - - # Process responses. - assert len(responses["choices"]) == 1 - response = responses["choices"][0] - api_responses.append( - response.get("message", {}).get( - "content", srsly.json_dumps(response) - ) + all_api_responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self.endpoint, + headers=headers, + json={**json_data, **self._config}, + timeout=self._max_request_time, + params={"api-version": self._api_version}, ) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + raise ValueError( + f"Request to Azure OpenAI API failed: " + f"{res_content.get('error', {}).get('message', str(res_content))}" + ) from ex + responses = r.json() - elif self._model_type == ModelType.COMPLETION: - for prompt in prompts: - responses = _request({"prompt": prompt}) if "error" in responses: - return responses["error"] + if self._strict: + raise ValueError(f"API call failed: {responses}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(responses)] + * len(prompts_for_doc) + } + + return responses + + # The (Azure) OpenAI API doesn't support batching yet, so we have to send individual requests. + # https://learn.microsoft.com/en-us/answers/questions/1334800/batching-requests-in-azure-openai + + if self._model_type == ModelType.CHAT: + # Note: this is yet (2023-10-05) untested, as Azure doesn't seem to allow the deployment of a chat model + # yet. + for prompt in prompts_for_doc: + responses = _request( + {"messages": [{"role": "user", "content": prompt}]} + ) + if "error" in responses: + return responses["error"] + + # Process responses. + assert len(responses["choices"]) == 1 + response = responses["choices"][0] + api_responses.append( + response.get("message", {}).get( + "content", srsly.json_dumps(response) + ) + ) - # Process responses. - assert len(responses["choices"]) == 1 - response = responses["choices"][0] - api_responses.append(response.get("text", srsly.json_dumps(response))) + elif self._model_type == ModelType.COMPLETION: + for prompt in prompts_for_doc: + responses = _request({"prompt": prompt}) + if "error" in responses: + return responses["error"] + + # Process responses. + assert len(responses["choices"]) == 1 + response = responses["choices"][0] + api_responses.append( + response.get("text", srsly.json_dumps(response)) + ) - return api_responses + all_api_responses.append(api_responses) + + return all_api_responses + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + # gpt-4 + "gpt-4": 8192, + "gpt-4-32k": 32768, + # gpt-3.5 + "gpt-3.5-turbo": 4097, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-instruct": 4097, + # text-davinci + "text-davinci-002": 4097, + "text-davinci-003": 4097, + # others + "code-davinci-002": 8001, + "text-curie-001": 2049, + "text-babbage-001": 2049, + "text-ada-001": 2049, + } diff --git a/spacy_llm/models/rest/azure/registry.py b/spacy_llm/models/rest/azure/registry.py index 9d88e466..38df5cb9 100644 --- a/spacy_llm/models/rest/azure/registry.py +++ b/spacy_llm/models/rest/azure/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -8,8 +8,67 @@ _DEFAULT_TEMPERATURE = 0.0 +@registry.llm_models("spacy.Azure.v2") +def azure_openai_v2( + deployment_name: str, + name: str, + base_url: str, + model_type: ModelType, + config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), + strict: bool = AzureOpenAI.DEFAULT_STRICT, + max_tries: int = AzureOpenAI.DEFAULT_MAX_TRIES, + interval: float = AzureOpenAI.DEFAULT_INTERVAL, + max_request_time: float = AzureOpenAI.DEFAULT_MAX_REQUEST_TIME, + api_version: str = "2023-05-15", + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Azure OpenAI instance for models deployed on Azure's OpenAI service using REST to prompt API. + + Docs on OpenAI models supported by Azure: + https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#model-summary-table-and-region-availability. + + config (Dict[Any, Any]): LLM config passed on to the model's initialization. + deployment_name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the + model used by that deployment, as deployment names in Azure OpenAI can be arbitrary. + name (str): Name of the model used by this deployment. This is required to infer the context length that can be + assumed for prompting. + endpoint (str): The URL for your Azure OpenAI endpoint. This is usually something like + "https://{prefix}.openai.azure.com/". + model_type (ModelType): Whether the deployed model is a text completetion model (e. g. + text-davinci-003) or a chat model (e. g. gpt-4). + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + api_version (str): API version to use. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (AzureOpenAI): AzureOpenAI instance for deployed model. + + DOCS: https://spacy.io/api/large-language-models#models + """ + return AzureOpenAI( + deployment_name=deployment_name, + name=name, + endpoint=base_url, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + api_version=api_version, + model_type=model_type, + context_length=context_length, + ) + + @registry.llm_models("spacy.Azure.v1") def azure_openai( + deployment_name: str, name: str, base_url: str, model_type: ModelType, @@ -19,12 +78,17 @@ def azure_openai( interval: float = AzureOpenAI.DEFAULT_INTERVAL, max_request_time: float = AzureOpenAI.DEFAULT_MAX_REQUEST_TIME, api_version: str = "2023-05-15", -) -> Callable[[Iterable[str]], Iterable[str]]: - """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Azure OpenAI instance for models deployed on Azure's OpenAI service using REST to prompt API. + + Docs on OpenAI models supported by Azure: + https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#model-summary-table-and-region-availability. config (Dict[Any, Any]): LLM config passed on to the model's initialization. - name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the model used by - that deployment, as deployment names in Azure OpenAI can be arbitrary. + deployment_name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the + model used by that deployment, as deployment names in Azure OpenAI can be arbitrary. + name (str): Name of the model used by this deployment. This is required to infer the context length that can be + assumed for prompting. endpoint (str): The URL for your Azure OpenAI endpoint. This is usually something like "https://{prefix}.openai.azure.com/". model_type (ModelType): Whether the deployed model is a text completetion model (e. g. @@ -38,11 +102,12 @@ def azure_openai( at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. api_version (str): API version to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model + RETURNS (AzureOpenAI): AzureOpenAI instance for deployed model. DOCS: https://spacy.io/api/large-language-models#models """ return AzureOpenAI( + deployment_name=deployment_name, name=name, endpoint=base_url, config=config, @@ -52,4 +117,5 @@ def azure_openai( max_request_time=max_request_time, api_version=api_version, model_type=model_type, + context_length=None, ) diff --git a/spacy_llm/models/rest/base.py b/spacy_llm/models/rest/base.py index f54f90ac..df089961 100644 --- a/spacy_llm/models/rest/base.py +++ b/spacy_llm/models/rest/base.py @@ -33,6 +33,7 @@ def __init__( max_tries: int, interval: float, max_request_time: float, + context_length: Optional[int], ): """Initializes new instance of REST-based model. name (str): Model name. @@ -47,6 +48,8 @@ def __init__( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context + length natively provided by spacy-llm. """ self._name = name self._endpoint = endpoint @@ -56,6 +59,7 @@ def __init__( self._interval = interval self._max_request_time = max_request_time self._credentials = self.credentials + self._context_length = context_length assert self._max_tries >= 1 assert self._interval > 0 @@ -64,12 +68,30 @@ def __init__( self._verify_auth() @abc.abstractmethod - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: """Executes prompts on specified API. - prompts (Iterable[str]): Prompts to execute. - RETURNS (Iterable[str]): API responses. + prompts (Iterable[Iterable[str]]): Prompts to execute. + RETURNS (Iterable[Iterable[str]]): API responses. """ + @staticmethod + @abc.abstractmethod + def _get_context_lengths() -> Dict[str, int]: + """Get context lengths per model name. + RETURNS (Dict[str, int]): Dict with model name -> context length. + """ + + @property + def context_length(self) -> Optional[int]: + """Returns context length in number of tokens for this model. + RETURNS (Optional[int]): Max. number of tokens in allowed in prompt for the current model. None if unknown. + """ + return ( + self._context_length + if self._context_length + else self._get_context_lengths().get(self._name, None) # type: ignore[arg-type] + ) + @property @abc.abstractmethod def credentials(self) -> Dict[str, str]: diff --git a/spacy_llm/models/rest/cohere/__init__.py b/spacy_llm/models/rest/cohere/__init__.py index f5319ec4..8ce0b194 100644 --- a/spacy_llm/models/rest/cohere/__init__.py +++ b/spacy_llm/models/rest/cohere/__init__.py @@ -1,4 +1,4 @@ from .model import Cohere, Endpoints -from .registry import cohere_command +from .registry import cohere_command, cohere_command_v2 -__all__ = ["Cohere", "Endpoints", "cohere_command"] +__all__ = ["Cohere", "Endpoints", "cohere_command", "cohere_command_v2"] diff --git a/spacy_llm/models/rest/cohere/model.py b/spacy_llm/models/rest/cohere/model.py index 58ba3231..55cd78c1 100644 --- a/spacy_llm/models/rest/cohere/model.py +++ b/spacy_llm/models/rest/cohere/model.py @@ -29,7 +29,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: try: - self(["test"]) + self([["test"]]) except ValueError as err: if "invalid api token" in str(err): warnings.warn( @@ -39,15 +39,17 @@ def _verify_auth(self) -> None: else: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "Content-Type": "application/json", "Accept": "application/json", } + all_api_responses: List[List[str]] = [] - api_responses: List[str] = [] - prompts = list(prompts) + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: r = self.retry( @@ -88,15 +90,18 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: if self._strict: raise ValueError(f"API call failed: {response}.") else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(response)] * len(prompts)} + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(response)] * len(prompts_for_doc) + } + return response # Cohere API currently doesn't accept batch prompts, so we're making # a request for each iteration. This approach can be prone to rate limit # errors. In practice, you can adjust _max_request_time so that the # timeout is larger. - responses = [_request({"prompt": prompt}) for prompt in prompts] + responses = [_request({"prompt": prompt}) for prompt in prompts_for_doc] for response in responses: if "generations" in response: for result in response["generations"]: @@ -110,4 +115,16 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: api_responses.append(srsly.json_dumps(response)) else: api_responses.append(srsly.json_dumps(response)) - return api_responses + + all_api_responses.append(api_responses) + + return all_api_responses + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + "command": 4096, + "command-light": 4096, + "command-light-nightly": 4096, + "command-nightly": 4096, + } diff --git a/spacy_llm/models/rest/cohere/registry.py b/spacy_llm/models/rest/cohere/registry.py index 3279bf4f..79c711e1 100644 --- a/spacy_llm/models/rest/cohere/registry.py +++ b/spacy_llm/models/rest/cohere/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -7,6 +7,43 @@ from .model import Cohere, Endpoints +@registry.llm_models("spacy.Command.v2") +def cohere_command_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "command", + strict: bool = Cohere.DEFAULT_STRICT, + max_tries: int = Cohere.DEFAULT_MAX_TRIES, + interval: float = Cohere.DEFAULT_INTERVAL, + max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Cohere instance for 'command' model using REST to prompt API. + name (str): Name of model to use, e. g. "command" or "command-light". + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Cohere): Cohere instance for 'command' model. + """ + return Cohere( + name=name, + endpoint=Endpoints.COMPLETION.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Command.v1") def cohere_command( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -17,7 +54,7 @@ def cohere_command( max_tries: int = Cohere.DEFAULT_MAX_TRIES, interval: float = Cohere.DEFAULT_INTERVAL, max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Cohere instance for 'command' model using REST to prompt API. name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. @@ -29,7 +66,7 @@ def cohere_command( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Cohere instance for 'command' model using REST to prompt API. + RETURNS (Cohere): Cohere instance for 'command' model. """ return Cohere( name=name, @@ -39,4 +76,5 @@ def cohere_command( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/models/rest/noop/model.py b/spacy_llm/models/rest/noop/model.py index 31d830b8..5364438a 100644 --- a/spacy_llm/models/rest/noop/model.py +++ b/spacy_llm/models/rest/noop/model.py @@ -1,3 +1,4 @@ +import sys import time from typing import Dict, Iterable @@ -20,6 +21,7 @@ def __init__(self): max_tries=1, interval=1, max_request_time=1, + context_length=None, ) @property @@ -29,7 +31,11 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: pass - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # Assume time penalty for API calls. time.sleep(NoOpModel._CALL_TIMEOUT) - return [_NOOP_RESPONSE] * len(list(prompts)) + return [[_NOOP_RESPONSE]] * len(list(prompts)) + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return {"NoOp": sys.maxsize} diff --git a/spacy_llm/models/rest/noop/registry.py b/spacy_llm/models/rest/noop/registry.py index bd393776..4050906b 100644 --- a/spacy_llm/models/rest/noop/registry.py +++ b/spacy_llm/models/rest/noop/registry.py @@ -5,7 +5,7 @@ @registry.llm_models("spacy.NoOp.v1") -def noop() -> Callable[[Iterable[str]], Iterable[str]]: +def noop() -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns NoOpModel. RETURNS (Callable[[Iterable[str]], Iterable[str]]]): NoOp model instance for test purposes. """ diff --git a/spacy_llm/models/rest/openai/__init__.py b/spacy_llm/models/rest/openai/__init__.py index e1782596..3cde8bef 100644 --- a/spacy_llm/models/rest/openai/__init__.py +++ b/spacy_llm/models/rest/openai/__init__.py @@ -2,10 +2,11 @@ from .registry import openai_ada, openai_ada_v2, openai_babbage, openai_babbage_v2 from .registry import openai_code_davinci, openai_code_davinci_v2, openai_curie from .registry import openai_curie_v2, openai_davinci, openai_davinci_v2 -from .registry import openai_gpt_3_5, openai_gpt_3_5_v2, openai_gpt_4, openai_gpt_4_v2 -from .registry import openai_text_ada, openai_text_ada_v2, openai_text_babbage -from .registry import openai_text_babbage_v2, openai_text_curie, openai_text_curie_v2 -from .registry import openai_text_davinci, openai_text_davinci_v2 +from .registry import openai_gpt_3_5, openai_gpt_3_5_v2, openai_gpt_3_5_v3 +from .registry import openai_gpt_4, openai_gpt_4_v2, openai_gpt_4_v3, openai_text_ada +from .registry import openai_text_ada_v2, openai_text_babbage, openai_text_babbage_v2 +from .registry import openai_text_curie, openai_text_curie_v2, openai_text_davinci +from .registry import openai_text_davinci_v2, openai_text_davinci_v3 __all__ = [ "OpenAI", @@ -22,8 +23,10 @@ "openai_davinci_v2", "openai_gpt_3_5", "openai_gpt_3_5_v2", + "openai_gpt_3_5_v3", "openai_gpt_4", "openai_gpt_4_v2", + "openai_gpt_4_v3", "openai_text_ada", "openai_text_ada_v2", "openai_text_babbage", @@ -32,4 +35,5 @@ "openai_text_curie_v2", "openai_text_davinci", "openai_text_davinci_v2", + "openai_text_davinci_v3", ] diff --git a/spacy_llm/models/rest/openai/model.py b/spacy_llm/models/rest/openai/model.py index 8fa9dc20..b8bbdae3 100644 --- a/spacy_llm/models/rest/openai/model.py +++ b/spacy_llm/models/rest/openai/model.py @@ -74,69 +74,106 @@ def _verify_auth(self) -> None: f"The specified model '{self._name}' is not available. Choices are: {sorted(set(models))}" ) - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "Content-Type": "application/json", } - api_responses: List[str] = [] - prompts = list(prompts) - - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=self._endpoint, - headers=headers, - json={**json_data, **self._config, "model": self._name}, - timeout=self._max_request_time, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - raise ValueError( - f"Request to OpenAI API failed: {res_content.get('error', {}).get('message', str(res_content))}" - ) from ex - responses = r.json() - - if "error" in responses: - if self._strict: - raise ValueError(f"API call failed: {responses}.") - else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(responses)] * len(prompts)} - - return responses - - if self._endpoint == Endpoints.CHAT: - # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual requests. - for prompt in prompts: - responses = _request( - {"messages": [{"role": "user", "content": prompt}]} + all_api_responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self._endpoint, + headers=headers, + json={**json_data, **self._config, "model": self._name}, + timeout=self._max_request_time, ) - if "error" in responses: - return responses["error"] + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + raise ValueError( + f"Request to OpenAI API failed: {res_content.get('error', {}).get('message', str(res_content))}" + ) from ex + responses = r.json() - # Process responses. - assert len(responses["choices"]) == 1 - response = responses["choices"][0] - api_responses.append( - response.get("message", {}).get( - "content", srsly.json_dumps(response) + if "error" in responses: + if self._strict: + raise ValueError(f"API call failed: {responses}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(responses)] + * len(prompts_for_doc) + } + + return responses + + if self._endpoint == Endpoints.CHAT: + # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual + # requests. + for prompt in prompts_for_doc: + responses = _request( + {"messages": [{"role": "user", "content": prompt}]} + ) + if "error" in responses: + return responses["error"] + + # Process responses. + assert len(responses["choices"]) == 1 + response = responses["choices"][0] + api_responses.append( + response.get("message", {}).get( + "content", srsly.json_dumps(response) + ) ) - ) - - elif self._endpoint == Endpoints.NON_CHAT: - responses = _request({"prompt": prompts}) - if "error" in responses: - return responses["error"] - assert len(responses["choices"]) == len(prompts) - - for response in responses["choices"]: - if "text" in response: - api_responses.append(response["text"]) - else: - api_responses.append(srsly.json_dumps(response)) - return api_responses + elif self._endpoint == Endpoints.NON_CHAT: + responses = _request({"prompt": prompts_for_doc}) + if "error" in responses: + return responses["error"] + assert len(responses["choices"]) == len(prompts_for_doc) + + for response in responses["choices"]: + if "text" in response: + api_responses.append(response["text"]) + else: + api_responses.append(srsly.json_dumps(response)) + + all_api_responses.append(api_responses) + + return all_api_responses + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + # gpt-4 + "gpt-4": 8192, + "gpt-4-0314": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0314": 32768, + # gpt-3.5 + "gpt-3.5-turbo": 4097, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-0613": 4097, + "gpt-3.5-turbo-0613-16k": 16385, + "gpt-3.5-turbo-instruct": 4097, + # text-davinci + "text-davinci-002": 4097, + "text-davinci-003": 4097, + # others + "code-davinci-002": 8001, + "text-curie-001": 2049, + "text-babbage-001": 2049, + "text-ada-001": 2049, + "davinci": 2049, + "curie": 2049, + "babbage": 2049, + "ada": 2049, + } diff --git a/spacy_llm/models/rest/openai/registry.py b/spacy_llm/models/rest/openai/registry.py index 82436e4a..772a4579 100644 --- a/spacy_llm/models/rest/openai/registry.py +++ b/spacy_llm/models/rest/openai/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Dict, Optional from confection import SimpleFrozenDict @@ -24,18 +24,21 @@ @registry.llm_models("spacy.GPT-4.v3") def openai_gpt_4_v3( config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), - name: str = "gpt-4", # noqa: F722 + name: str = "gpt-4", strict: bool = OpenAI.DEFAULT_STRICT, max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: + context_length: Optional[int] = None, +) -> OpenAI: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (str): Model name to use. Can be any model name supported by the OpenAI API - e. g. 'gpt-4', "gpt-4-1106-preview", .... - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (OpenAI): OpenAI instance for 'gpt-4' model. DOCS: https://spacy.io/api/large-language-models#models """ @@ -47,6 +50,7 @@ def openai_gpt_4_v3( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @@ -60,12 +64,12 @@ def openai_gpt_4_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Literal["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"]): Model to use. Base 'gpt-4' model by default. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model + RETURNS (OpenAI): OpenAI instance for 'gpt-4' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -77,6 +81,7 @@ def openai_gpt_4_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -90,13 +95,13 @@ def openai_gpt_4( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Literal["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"]): Model to use. Base 'gpt-4' model by default. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model + RETURNS (OpenAI): OpenAI instance for 'gpt-4' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -108,6 +113,7 @@ def openai_gpt_4( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -119,13 +125,16 @@ def openai_gpt_3_5_v3( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: + context_length: Optional[int] = None, +) -> OpenAI: """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (str): Name of model to use. Can be any model name supported by the OpenAI API - e. g. 'gpt-3.5', "gpt-3.5-turbo", .... - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-3.5' model + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (OpenAI): OpenAI instance for 'gpt-3.5' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -139,6 +148,7 @@ def openai_gpt_3_5_v3( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @@ -156,14 +166,14 @@ def openai_gpt_3_5_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Literal[ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0613-16k", "gpt-3.5-turbo-instruct" ]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-3.5' model + RETURNS (OpenAI): OpenAI instance for 'gpt-3.5' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -177,6 +187,7 @@ def openai_gpt_3_5_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -194,14 +205,14 @@ def openai_gpt_3_5( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Literal[ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0613-16k", "gpt-3.5-turbo-instruct" ]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-3.5' model + RETURNS (OpenAI): OpenAI instance for 'gpt-3.5' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -215,6 +226,41 @@ def openai_gpt_3_5( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, + ) + + +@registry.llm_models("spacy.Text-Davinci.v3") +def openai_text_davinci_v3( + config: Dict[Any, Any] = SimpleFrozenDict( + max_tokens=1000, temperature=_DEFAULT_TEMPERATURE + ), + name: str = "text-davinci-003", + strict: bool = OpenAI.DEFAULT_STRICT, + max_tries: int = OpenAI.DEFAULT_MAX_TRIES, + interval: float = OpenAI.DEFAULT_INTERVAL, + max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> OpenAI: + """Returns OpenAI instance for 'text-davinci' model using REST to prompt API. + + config (Dict[Any, Any]): LLM config passed on to the model's initialization. + name (str): Name of model to use, e. g. "text-davinci-002" or "text-davinci-003". + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (OpenAI): OpenAI instance for 'text-davinci' model + + DOCS: https://spacy.io/api/large-language-models#models + """ + return OpenAI( + name=name, + endpoint=Endpoints.NON_CHAT.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, ) @@ -230,12 +276,12 @@ def openai_text_davinci_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-davinci-002", "text-davinci-003"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'text-davinci' model + RETURNS (OpenAI): OpenAI instance for 'text-davinci' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -247,6 +293,7 @@ def openai_text_davinci_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -260,12 +307,12 @@ def openai_text_davinci( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-davinci-002", "text-davinci-003"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'text-davinci' model + RETURNS (OpenAI): OpenAI instance for 'text-davinci' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -277,6 +324,7 @@ def openai_text_davinci( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -290,12 +338,12 @@ def openai_code_davinci_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'code-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["code-davinci-002"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'code-davinci' model + RETURNS (OpenAI): OpenAI instance for 'code-davinci' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -307,6 +355,7 @@ def openai_code_davinci_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -318,12 +367,12 @@ def openai_code_davinci( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'code-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["code-davinci-002"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'code-davinci' model + RETURNS (OpenAI): OpenAI instance for 'code-davinci' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -335,6 +384,7 @@ def openai_code_davinci( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -348,12 +398,12 @@ def openai_text_curie_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-curie-001"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'text-curie' model + RETURNS (OpenAI): OpenAI instance for 'text-curie' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -365,6 +415,7 @@ def openai_text_curie_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -376,12 +427,12 @@ def openai_text_curie( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-curie-001"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'text-curie' model + RETURNS (OpenAI): OpenAI instance for 'text-curie' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -393,6 +444,7 @@ def openai_text_curie( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -406,12 +458,12 @@ def openai_text_babbage_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-babbage-001"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'text-babbage' model + RETURNS (OpenAI): OpenAI instance for 'text-babbage' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -423,6 +475,7 @@ def openai_text_babbage_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -434,12 +487,12 @@ def openai_text_babbage( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-babbage-001"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'text-babbage' model + RETURNS (OpenAI): OpenAI instance for 'text-babbage' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -451,6 +504,7 @@ def openai_text_babbage( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -464,12 +518,12 @@ def openai_text_ada_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-ada-001"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'text-ada' model + RETURNS (OpenAI): Anthropic instance for 'text-ada' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -481,6 +535,7 @@ def openai_text_ada_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -492,12 +547,12 @@ def openai_text_ada( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'text-ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["text-ada-001"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'text-ada' model + RETURNS (OpenAI): OpenAI instance for 'text-ada' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -509,6 +564,7 @@ def openai_text_ada( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -522,12 +578,12 @@ def openai_davinci_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["davinci"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'davinci' model + RETURNS (OpenAI): OpenAI instance for 'davinci' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -539,6 +595,7 @@ def openai_davinci_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -550,12 +607,12 @@ def openai_davinci( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["davinci"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'davinci' model + RETURNS (OpenAI): OpenAI instance for 'davinci' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -567,6 +624,7 @@ def openai_davinci( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -580,12 +638,12 @@ def openai_curie_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["curie"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'curie' model + RETURNS (OpenAI): OpenAI instance for 'curie' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -597,6 +655,7 @@ def openai_curie_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -608,12 +667,12 @@ def openai_curie( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["curie"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'curie' model + RETURNS (OpenAI): OpenAI instance for 'curie' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -625,6 +684,7 @@ def openai_curie( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -638,12 +698,12 @@ def openai_babbage_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["babbage"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'babbage' model + RETURNS (OpenAI): OpenAI instance for 'babbage' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -655,6 +715,7 @@ def openai_babbage_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -666,12 +727,12 @@ def openai_babbage( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["babbage"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'babbage' model + RETURNS (OpenAI): OpenAI instance for 'babbage' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -683,6 +744,7 @@ def openai_babbage( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -696,12 +758,12 @@ def openai_ada_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["ada"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'ada' model + RETURNS (OpenAI): OpenAI instance for 'ada' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -713,6 +775,7 @@ def openai_ada_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -724,12 +787,12 @@ def openai_ada( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> OpenAI: """Returns OpenAI instance for 'ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (Optional[Literal["ada"]]): Model to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'ada' model + RETURNS (OpenAI): OpenAI instance for 'ada' model DOCS: https://spacy.io/api/large-language-models#models """ @@ -741,4 +804,5 @@ def openai_ada( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/models/rest/palm/__init__.py b/spacy_llm/models/rest/palm/__init__.py index 1255be2f..23fe28ec 100644 --- a/spacy_llm/models/rest/palm/__init__.py +++ b/spacy_llm/models/rest/palm/__init__.py @@ -1,4 +1,4 @@ from .model import Endpoints, PaLM -from .registry import palm_bison +from .registry import palm_bison, palm_bison_v2 -__all__ = ["palm_bison", "PaLM", "Endpoints"] +__all__ = ["palm_bison", "palm_bison_v2", "PaLM", "Endpoints"] diff --git a/spacy_llm/models/rest/palm/model.py b/spacy_llm/models/rest/palm/model.py index 1a488000..b1a2657d 100644 --- a/spacy_llm/models/rest/palm/model.py +++ b/spacy_llm/models/rest/palm/model.py @@ -31,7 +31,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: try: - self(["What's 2+2?"]) + self([["What's 2+2?"]]) except ValueError as err: if "API key not valid" in str(err): warnings.warn( @@ -41,69 +41,87 @@ def _verify_auth(self) -> None: else: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { "Content-Type": "application/json", "Accept": "application/json", } - api_responses: List[str] = [] - prompts = list(prompts) url = self._endpoint.format( model=self._name, api_key=self._credentials["api_key"] ) + all_api_responses: List[List[str]] = [] - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=url, - headers=headers, - json={**json_data, **self._config}, - timeout=self._max_request_time, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - error_message = res_content.get("error", {}).get("message", {}) - # Catching other types of HTTPErrors (e.g., "429: too many requests") - raise ValueError(f"Request to PaLM API failed: {error_message}") from ex - response = r.json() - - # PaLM returns a 'filter' key when a message was filtered due to safety concerns. - if "filters" in response: - if self._strict: - raise ValueError(f"API call failed: {response}.") - else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(response)] * len(prompts)} - return response - - # PaLM API currently doesn't accept batch prompts, so we're making - # a request for each iteration. This approach can be prone to rate limit - # errors. In practice, you can adjust _max_request_time so that the - # timeout is larger. - uses_chat = "chat" in self._name - responses = [ - _request( - { - "prompt": {"text": prompt} - if not uses_chat - else {"messages": [{"content": prompt}]} - } - ) - for prompt in prompts - ] - for response in responses: - if "candidates" in response: - # Although you can set the number of candidates in PaLM to be greater than 1, we only need to return a - # single value. In this case, we will just return the very first output. - api_responses.append( - response["candidates"][0].get( - "content" if uses_chat else "output", srsly.json_dumps(response) - ) + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=url, + headers=headers, + json={**json_data, **self._config}, + timeout=self._max_request_time, ) - else: - api_responses.append(srsly.json_dumps(response)) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + error_message = res_content.get("error", {}).get("message", {}) + # Catching other types of HTTPErrors (e.g., "429: too many requests") + raise ValueError( + f"Request to PaLM API failed: {error_message}" + ) from ex + response = r.json() + + # PaLM returns a 'filter' key when a message was filtered due to safety concerns. + if "filters" in response: + if self._strict: + raise ValueError(f"API call failed: {response}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(response)] * len(prompts_for_doc) + } + + return response - return api_responses + # PaLM API currently doesn't accept batch prompts, so we're making + # a request for each iteration. This approach can be prone to rate limit + # errors. In practice, you can adjust _max_request_time so that the + # timeout is larger. + uses_chat = "chat" in self._name + responses = [ + _request( + { + "prompt": {"text": prompt} + if not uses_chat + else {"messages": [{"content": prompt}]} + } + ) + for prompt in prompts_for_doc + ] + for response in responses: + if "candidates" in response: + # Although you can set the number of candidates in PaLM to be greater than 1, we only need to return a + # single value. In this case, we will just return the very first output. + api_responses.append( + response["candidates"][0].get( + "content" if uses_chat else "output", + srsly.json_dumps(response), + ) + ) + else: + api_responses.append(srsly.json_dumps(response)) + + all_api_responses.append(api_responses) + + return all_api_responses + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + "text-bison-001": 8192, + "chat-bison-001": 8192, + } diff --git a/spacy_llm/models/rest/palm/registry.py b/spacy_llm/models/rest/palm/registry.py index 9a56576a..1e68faed 100644 --- a/spacy_llm/models/rest/palm/registry.py +++ b/spacy_llm/models/rest/palm/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -7,6 +7,45 @@ from .model import Endpoints, PaLM +@registry.llm_models("spacy.PaLM.v2") +def palm_bison_v2( + config: Dict[Any, Any] = SimpleFrozenDict(temperature=0), + name: Literal["chat-bison-001", "text-bison-001"] = "text-bison-001", # noqa: F821 + strict: bool = PaLM.DEFAULT_STRICT, + max_tries: int = PaLM.DEFAULT_MAX_TRIES, + interval: float = PaLM.DEFAULT_INTERVAL, + max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Google instance for PaLM Bison model using REST to prompt API. + name (Literal["chat-bison-001", "text-bison-001"]): Model to use. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (PaLM): PaLM instance for Bison model. + """ + return PaLM( + name=name, + endpoint=Endpoints.TEXT.value + if name in {"text-bison-001"} + else Endpoints.MSG.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.PaLM.v1") def palm_bison( config: Dict[Any, Any] = SimpleFrozenDict(temperature=0), @@ -15,7 +54,7 @@ def palm_bison( max_tries: int = PaLM.DEFAULT_MAX_TRIES, interval: float = PaLM.DEFAULT_INTERVAL, max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> PaLM: """Returns Google instance for PaLM Bison model using REST to prompt API. name (Literal["chat-bison-001", "text-bison-001"]): Model to use. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. @@ -27,7 +66,7 @@ def palm_bison( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Cohere instance for 'command' model using REST to prompt API. + RETURNS (PaLM): PaLM instance for Bison model. """ return PaLM( name=name, @@ -39,4 +78,5 @@ def palm_bison( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 70003a8b..e239f27f 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -1,4 +1,5 @@ import logging +import warnings from collections import defaultdict from itertools import tee from pathlib import Path @@ -15,8 +16,9 @@ from .. import registry # noqa: F401 from ..compat import TypedDict -from ..ty import Cache, LabeledTask, LLMTask, PromptExecutorType, ScorableTask -from ..ty import Serializable, validate_type_consistency +from ..ty import Cache, LabeledTask, LLMTask, ModelWithContextLength +from ..ty import PromptExecutorType, ScorableTask, Serializable, supports_sharding +from ..ty import validate_type_consistency logger = logging.getLogger("spacy_llm") logger.addHandler(logging.NullHandler()) @@ -112,8 +114,8 @@ def __init__( name (str): The component instance name, used to add entries to the losses during training. vocab (Vocab): Pipeline vocabulary. - task (Optional[LLMTask]): An LLMTask can generate prompts for given docs, and can parse the LLM's responses into - structured information and set that back on the docs. + task (Optional[LLMTask]): An LLMTask can generate prompts for given docs, and can parse the LLM's responses + into structured information and set that back on the docs. model (Callable[[Iterable[Any]], Iterable[Any]]]): Callable querying the specified LLM API. cache (Cache): Cache to use for caching prompts and responses per doc (batch). save_io (bool): Whether to save LLM I/O (prompts and responses) in the `Doc._.llm_io` custom extension. @@ -131,6 +133,20 @@ def __init__( if isinstance(self._task, Initializable): self.initialize = self._task.initialize + self._check_sharding() + + def _check_sharding(self): + context_length: Optional[int] = None + if isinstance(self._model, ModelWithContextLength): + context_length = self._model.context_length + if supports_sharding(self._task) and context_length is None: + warnings.warn( + "Task supports sharding, but model does not provide context length. Data won't be sharded, prompt " + "might exceed the model's context length. Set context length in your config. If you think spacy-llm" + " should provide the context length for this model automatically, report this to " + "https://github.com/explosion/spacy-llm/issues." + ) + @property def labels(self) -> Tuple[str, ...]: labels: Tuple[str, ...] = tuple() @@ -194,9 +210,10 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: """Process a batch of docs with the configured LLM model and task. If a cache is configured, only sends prompts to model for docs not found in cache. - docs (List[Doc]): Input batch of docs - RETURNS (List[Doc]): Processed batch of docs with task annotations set + docs (List[Doc]): Input batch of docs. + RETURNS (List[Doc]): Processed batch of docs with task annotations set. """ + support_sharding = supports_sharding(self._task) is_cached = [doc in self._cache for doc in docs] noncached_doc_batch = [doc for i, doc in enumerate(docs) if not is_cached[i]] if len(noncached_doc_batch) < len(docs): @@ -206,24 +223,56 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: len(noncached_doc_batch), ) + # Process uncached docs. modified_docs: Iterator[Doc] = iter(()) if len(noncached_doc_batch) > 0: n_iters = 3 if self._save_io else 2 + context_length: Optional[int] = None + if isinstance(self._model, ModelWithContextLength): + context_length = self._model.context_length + + # Only pass context length if this is a sharding task. prompts_iters = tee( - self._task.generate_prompts(noncached_doc_batch), n_iters + self._task.generate_prompts(noncached_doc_batch, context_length) # type: ignore[call-arg] + if support_sharding + else self._task.generate_prompts(noncached_doc_batch), + n_iters + 1, ) - responses_iters = tee(self._model(prompts_iters[0]), n_iters) - for prompt, response, doc in zip( - prompts_iters[1], responses_iters[1], noncached_doc_batch + responses_iters = tee( + self._model( + # Ensure that model receives Iterable[Iterable[Any]]. If task doesn't shard, its prompt is wrapped + # in a list to conform to the nested structure. + ( + elem[0] if support_sharding else [elem] + for elem in prompts_iters[0] + ) + ), + n_iters, + ) + + for prompt_data, response, doc in zip( + prompts_iters[1], responses_iters[0], noncached_doc_batch ): - logger.debug("Generated prompt for doc: %s\n%s", doc.text, prompt) + logger.debug( + "Generated prompt for doc: %s\n%s", + doc.text, + prompt_data[0] if support_sharding else prompt_data, + ) logger.debug("LLM response for doc: %s\n%s", doc.text, response) - modified_docs = iter( - self._task.parse_responses(noncached_doc_batch, responses_iters[0]) + resp = list( + self._task.parse_responses( + ( + elem[1] if support_sharding else noncached_doc_batch[i] + for i, elem in enumerate(prompts_iters[2]) + ), + responses_iters[1], + ) ) + modified_docs = iter(resp) - final_docs = [] + noncached_doc_batch_iter = iter(noncached_doc_batch) + final_docs: List[Doc] = [] for i, doc in enumerate(docs): if is_cached[i]: cached_doc = self._cache[doc] @@ -233,14 +282,37 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: else: doc = next(modified_docs) + # Merge with doc's prior custom data. + noncached_doc = next(noncached_doc_batch_iter) + for extension in dir(noncached_doc._): + if not Doc.has_extension(extension): + Doc.set_extension(extension, default=None) + # Don't overwrite any non-None extension values in new doc. + if getattr(doc._, extension) is None: + setattr(doc._, extension, getattr(noncached_doc._, extension)) + doc.user_data = {**noncached_doc.user_data, **doc.user_data} + doc._context = noncached_doc._context + + # Save raw IO (prompt and response), if save_io is True. if self._save_io: # Make sure the `llm_io` field is set doc.user_data["llm_io"] = doc.user_data.get( "llm_io", defaultdict(dict) ) llm_io = doc.user_data["llm_io"][self._name] - llm_io["prompt"] = str(next(prompts_iters[2])) - llm_io["response"] = str(next(responses_iters[2])) + next_prompt = next(prompts_iters[-1]) + if support_sharding: + llm_io["prompt"] = [ + str(shard_prompt) for shard_prompt in next_prompt[0] + ] + llm_io["response"] = [ + str(shard_response) + for shard_response in next(responses_iters[-1]) + ] + else: + llm_io["prompt"] = str(next_prompt) + # Models always return nested responses. For non-sharding tasks this will always be a 1-list. + llm_io["response"] = str(next(responses_iters[-1])[0]) self._cache.add(doc) final_docs.append(doc) @@ -261,7 +333,7 @@ def to_bytes( serialize = {} if isinstance(self._task, Serializable): - serialize["task"] = lambda: self._task.to_bytes(exclude=exclude) # type: ignore[attr-defined] + serialize["task"] = lambda: self._task.to_bytes(exclude=exclude) # type: ignore[attr-defined, union-attr] if isinstance(self._model, Serializable): serialize["model"] = lambda: self._model.to_bytes(exclude=exclude) # type: ignore[attr-defined] @@ -283,9 +355,9 @@ def from_bytes( deserialize = {} if isinstance(self._task, Serializable): - deserialize["task"] = lambda b: self._task.from_bytes(b, exclude=exclude) # type: ignore[attr-defined] + deserialize["task"] = lambda b: self._task.from_bytes(b, exclude=exclude) # type: ignore[attr-defined,union-attr] if isinstance(self._model, Serializable): - deserialize["model"] = lambda b: self._model.from_bytes(b, exclude=exclude) # type: ignore[attr-defined] + deserialize["model"] = lambda b: self._model.from_bytes(b, exclude=exclude) # type: ignore[attr-defined,union-attr] util.from_bytes(bytes_data, deserialize, exclude) return self @@ -301,9 +373,9 @@ def to_disk( serialize = {} if isinstance(self._task, Serializable): - serialize["task"] = lambda p: self._task.to_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["task"] = lambda p: self._task.to_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] if isinstance(self._model, Serializable): - serialize["model"] = lambda p: self._model.to_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["model"] = lambda p: self._model.to_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] util.to_disk(path, serialize, exclude) @@ -319,9 +391,9 @@ def from_disk( serialize = {} if isinstance(self._task, Serializable): - serialize["task"] = lambda p: self._task.from_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["task"] = lambda p: self._task.from_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] if isinstance(self._model, Serializable): - serialize["model"] = lambda p: self._model.from_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["model"] = lambda p: self._model.from_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] util.from_disk(path, serialize, exclude) return self diff --git a/spacy_llm/tasks/__init__.py b/spacy_llm/tasks/__init__.py index 5f399733..1b297142 100644 --- a/spacy_llm/tasks/__init__.py +++ b/spacy_llm/tasks/__init__.py @@ -6,7 +6,7 @@ from .entity_linker import EntityLinkerTask, make_entitylinker_task from .lemma import LemmaTask, make_lemma_task from .ner import NERTask, make_ner_task_v3 -from .noop import NoopTask, make_noop_task +from .noop import NoopTask, ShardingNoopTask, make_noop_task, make_noopnoshards_task from .rel import RELTask, make_rel_task from .sentiment import SentimentTask, make_sentiment_task from .spancat import SpanCatTask, make_spancat_task_v3 @@ -42,6 +42,7 @@ "make_lemma_task", "make_ner_task_v3", "make_noop_task", + "make_noopnoshards_task", "make_rel_task", "make_sentiment_task", "make_spancat_task_v3", @@ -54,6 +55,7 @@ "NoopTask", "RELTask", "SentimentTask", + "ShardingNoopTask", "SpanCatTask", "SummarizationTask", "TextCatTask", diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index e879c52d..82a182dd 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -1,4 +1,5 @@ import abc +from itertools import tee from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, cast @@ -10,7 +11,7 @@ from ..compat import Self from ..registry import lowercase_normalizer -from ..ty import FewshotExample, TaskResponseParser +from ..ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser class BuiltinTask(abc.ABC): @@ -22,10 +23,10 @@ class BuiltinTask(abc.ABC): - initializable (in line with other spaCy components) - (de-)serialization - On the relation of BuiltinTask to LLMTask: the latter specifies the minimal contract a task implementation + On the relation of BuiltinTask to ShardingLLMTask: the latter specifies the minimal contract a task implementation has to fulfill, whereas a BuiltinTask requires (and offers) functionality beyond that. The rationale behind that is - that built-in tasks should provide as smooth a usage experience as possible while still making it as easy as possible - for users to write their own, custom tasks. + that built-in tasks should provide as smooth a usage experience as possible while still making it as easy as + possible for users to write their own, custom tasks. """ def __init__( @@ -34,36 +35,75 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, prompt_examples: Optional[List[FewshotExample[Self]]], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], ): """Initializes task. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. """ self._parse_responses = parse_responses self._prompt_examples = prompt_examples or [] self._template = template self._prompt_example_type = prompt_example_type + self._shard_mapper = shard_mapper + self._shard_reducer = shard_reducer - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[Tuple[Iterable[Any], Iterable[Doc]]]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. - RETURNS (Iterable[Any]): Iterable with one prompt per doc. + ontext_length (int): Context length for model this task is executed with. Needed for sharding and fusing docs, + if the corresponding prompts exceed the context length. If None, context length is assumed to be infinite. + RETURNS (Iterable[Tuple[Iterable[Any], Iterable[Doc]]]): Iterable with one to n prompts per doc (multiple + prompts in case of multiple shards) and the corresponding shards. The relationship between shard and prompt + is 1:1. """ environment = jinja2.Environment() _template = environment.from_string(self._template) - for doc in self._preprocess_docs_for_prompt(docs): - prompt = _template.render( - text=doc.text, + + def render_template(shard: Doc, i_shard: int, i_doc: int, n_shards: int) -> str: + """Renders template for a given doc (shard). + shard (Doc): Doc shard. Note that if the prompt is small enough to fit within the model's context window, + there will only be one shard, which is identical to the original doc. + i_shard (int): Shard index (w.r.t. shard's Doc instance). + i_doc (int): Doc index. + n_shards (int): Total number of shards. + RETURNS (str): Rendered template. + """ + return _template.render( + text=shard.text, prompt_examples=self._prompt_examples, - **self._prompt_data, + **self._get_prompt_data(shard, i_shard, i_doc, n_shards), ) - yield prompt - @property - def _prompt_data(self) -> Dict[str, Any]: - """Returns data injected into prompt template. No-op if not overridden by inheriting task class. + for _i_doc, _doc in enumerate(self._preprocess_docs_for_prompt(docs)): + # If no context length provided (e. g. because models don't provide it): don't shard. + shards = ( + self._shard_mapper(_doc, _i_doc, context_length, render_template) + if context_length is not None + else [_doc] + ) + shards = list(shards) + yield [ + render_template(_shard, _i_shard, _i_doc, len(shards)) + for _i_shard, _shard in enumerate(shards) + ], shards + + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: + """Returns data injected into prompt template. No-op if not overridden by inheriting task class. The data + returned by this might be static (i. e. the same for all doc shards) or dynamic (contingent on the doc shard). + shard (Doc): Doc (shard) for which prompt data should be fetched. + i_shard (int): Shard index (w.r.t. shard's Doc instance). + i_doc (int): Doc index. + n_shards (int): Total number of shards. RETURNS (Dict[str, Any]): Data injected into prompt template. """ return {} @@ -77,12 +117,12 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: @abc.abstractmethod def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[Any] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[Any]] ) -> Iterable[Doc]: """ Parses LLM responses. - docs (Iterable[Doc]): Docs to map responses into. - responses ([Iterable[Any]]): LLM responses. + shards (Iterable[Iterable[Doc]]): Doc shards to map responses into. + responses ([Iterable[Iterable[Any]]]): LLM responses per doc. RETURNS (Iterable[Doc]]): Updated docs. """ @@ -113,7 +153,6 @@ def get_cfg(self) -> Dict[str, Any]: def set_cfg(self, cfg: Dict[str, Any]) -> None: """Deserialize the task's configuration attributes. - cfg (Dict[str, Any]): dictionary containing configuration attributes. """ for key, value in cfg.items(): @@ -126,7 +165,6 @@ def _get_prompt_examples(self) -> List[Dict[str, Any]]: def _set_prompt_examples(self, examples: List[Dict[str, Any]]) -> None: """Set prompt examples. - examples (List[Dict[str, Any]]): prompt examples. """ self._prompt_examples = [ @@ -162,7 +200,6 @@ def from_bytes( exclude (Tuple[str]): Names of properties to exclude from deserialization. RETURNS (BuiltinTask): Modified BuiltinTask instance. """ - deserialize = { "cfg": lambda b: self.set_cfg(srsly.json_loads(b)), "prompt_examples": lambda b: self._set_prompt_examples( @@ -184,7 +221,6 @@ def to_disk( path (Path): A path (currently unused). exclude (Tuple): Names of properties to exclude from serialization. """ - serialize = { "cfg": lambda p: srsly.write_json(p, self.get_cfg()), "prompt_examples": lambda p: srsly.write_msgpack( @@ -235,6 +271,18 @@ def _check_extension(cls, extension: str) -> None: if not Doc.has_extension(extension): Doc.set_extension(extension, default=[]) + @staticmethod + def _tee_2d_iterable( + data: Iterable[Iterable[Any]], n: int + ) -> Tuple[Iterable[List[Doc]], ...]: + """Tees two-dimensional Iterable. As Iterables in the nested iterables get consumed with the first access, we + need to materialize them - this is done by converting them to a list. + data (Iterable[Iterable[Any]]): Data to tee. + n (int): Number of tees to return. + RETURNS (Tuple[Iterable[List[Doc]], ...]): n-sized tuple of Iterables with inner Iterables converted to Lists. + """ + return tee((list(inner_data) for inner_data in data), n) + class BuiltinTaskWithLabels(BuiltinTask, abc.ABC): """Built-in tasks with labels.""" @@ -245,6 +293,8 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, prompt_examples: Optional[List[FewshotExample[Self]]], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], labels: List[str], label_definitions: Optional[Dict[str, str]], normalizer: Optional[Callable[[str], str]], @@ -255,6 +305,8 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. labels (List[str]): List of labels to pass to the template. Leave empty to (optionally) populate it at initialization time. label_definitions (Optional[Dict[str, str]]): Map of label -> description @@ -268,6 +320,8 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._normalizer = normalizer if normalizer else lowercase_normalizer() self._label_dict = { diff --git a/spacy_llm/tasks/entity_linker/parser.py b/spacy_llm/tasks/entity_linker/parser.py index b3c4076a..54d1c19c 100644 --- a/spacy_llm/tasks/entity_linker/parser.py +++ b/spacy_llm/tasks/entity_linker/parser.py @@ -8,36 +8,47 @@ def parse_responses_v1( - task: EntityLinkerTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[Span]]: + task: EntityLinkerTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[List[List[Span]]]: """Parses LLM responses for spacy.EntityLinker.v1. task (EntityLinkerTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[List[Span]]): Entity spans per doc. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[List[List[Span]]): Entity spans per shard. """ - for i_doc, (doc, prompt_response) in enumerate(zip(docs, responses)): - solutions = [ - sol.replace("::: ", "")[1:-1] - for sol in re.findall(r"::: <.*>", prompt_response) - ] - - # Set ents anew by copying them and specifying the KB ID. - ents = [ - ent - for i_ent, ent in enumerate(doc.ents) - if task.has_ent_cands[i_doc][i_ent] - ] - yield [ - Span( - doc=doc, - start=ent.start, - end=ent.end, - label=ent.label, - vector=ent.vector, - vector_norm=ent.vector_norm, - kb_id=solution if solution != "NIL" else EntityLinker.NIL, + for i_doc, (shards_for_doc, responses_for_doc) in enumerate(zip(shards, responses)): + results_for_doc: List[List[Span]] = [] + for i_shard, (shard, response) in enumerate( + zip(shards_for_doc, responses_for_doc) + ): + solutions = [ + sol.replace("::: ", "")[1:-1] + for sol in re.findall(r"::: <.*>", response) + ] + + # Set ents anew by copying them and specifying the KB ID. + ents = [ + ent + for i_ent, ent in enumerate(shard.ents) + if task.has_ent_cands_by_shard[i_doc][i_shard][i_ent] + ] + + results_for_doc.append( + [ + Span( + doc=shard, + start=ent.start, + end=ent.end, + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=solution if solution != "NIL" else EntityLinker.NIL, + ) + for ent, solution in zip(ents, solutions) + ] ) - for ent, solution in zip(ents, solutions) - ] + + yield results_for_doc diff --git a/spacy_llm/tasks/entity_linker/registry.py b/spacy_llm/tasks/entity_linker/registry.py index 10e34ed0..df98b6f2 100644 --- a/spacy_llm/tasks/entity_linker/registry.py +++ b/spacy_llm/tasks/entity_linker/registry.py @@ -5,12 +5,15 @@ from spacy.scorer import Scorer from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .candidate_selector import KBCandidateSelector from .parser import parse_responses_v1 from .task import DEFAULT_EL_TEMPLATE_V1, EntityLinkerTask from .ty import EntDescReader, InMemoryLookupKBLoader -from .util import ELExample, KBFileLoader, KBObjectLoader, ent_desc_reader_csv, score +from .util import ELExample, KBFileLoader, KBObjectLoader, ent_desc_reader_csv +from .util import reduce_shards_to_doc, score @registry.llm_tasks("spacy.EntityLinker.v1") @@ -19,6 +22,8 @@ def make_entitylinker_task( parse_responses: Optional[TaskResponseParser[EntityLinkerTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, scorer: Optional[Scorer] = None, ): """EntityLinker.v1 task factory. @@ -28,6 +33,8 @@ def make_entitylinker_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. scorer (Optional[Scorer]): Scorer function. """ raw_examples = examples() if callable(examples) else examples @@ -50,6 +57,8 @@ def make_entitylinker_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), scorer=scorer or score, ) @@ -114,3 +123,8 @@ def make_kb_file_loader(path: Union[str, Path]) -> KBFileLoader: RETURNS (KBFileLoader): Loader instance. """ return KBFileLoader(path=path) + + +@registry.llm_misc("spacy.EntityLinkerShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py index c17dfcd7..86426ed0 100644 --- a/spacy_llm/tasks/entity_linker/task.py +++ b/spacy_llm/tasks/entity_linker/task.py @@ -1,13 +1,12 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type -import jinja2 from spacy import Language, Vocab from spacy.pipeline import EntityLinker from spacy.tokens import Doc, Span from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template from .ty import CandidateSelector, Entity, InitializableCandidateSelector @@ -22,6 +21,8 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], prompt_examples: Optional[List[FewshotExample[Self]]], template: str, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], scorer: Scorer, ): """Default entity linking task. @@ -30,6 +31,8 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]]): Type to use for fewshot examples. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. template (str): Prompt template passed to the model. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. scorer (Scorer): Scorer function. """ super().__init__( @@ -37,14 +40,20 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._scorer = scorer self._candidate_selector: Optional[CandidateSelector] = None # Exclude mentions without candidates from prompt, if set. Mostly used for internal debugging. self._auto_nil = True - # Store, per doc and entity, whether candidates could be found. - self._has_ent_cands: List[List[bool]] = [] + # Store, per doc and entity, whether candidates could be found and candidates themselves. + self._has_ent_cands_by_doc: List[List[bool]] = [] + self._ents_cands_by_doc: List[List[List[Entity]]] = [] + self._has_ent_cands_by_shard: List[List[List[bool]]] = [] + self._ents_cands_by_shard: List[List[List[List[Entity]]]] = [] + self._n_shards: Optional[int] = None def initialize( self, @@ -86,80 +95,153 @@ def set_candidate_selector( if isinstance(self._candidate_selector, InitializableCandidateSelector): self._candidate_selector.initialize(vocab) - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: - environment = jinja2.Environment() - _template = environment.from_string(self._template) - # Reset auto-nil attributes for new batch of docs. - self._has_ent_cands = [] + def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: + ( + self._ents_cands_by_doc, + self._has_ent_cands_by_doc, + ) = self._find_entity_candidates(docs) + # Reset shard-wise candidate info. Will be set for each shard individually in _get_prompt_data(). We cannot + # update it here, as we don't know yet how the shards will look like. + self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + self._n_shards = None - for i_doc, doc in enumerate(docs): - cands_ents, _ = self.fetch_entity_info(doc) + return [ + EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i]) + for i, doc in enumerate(docs) + ] + + def _find_entity_candidates( + self, docs: Iterable[Doc] + ) -> Tuple[List[List[List[Entity]]], List[List[bool]]]: + """Determine entity candidates for all entity mentions in docs. + docs (Iterable[Doc]): Docs with entities to select candidates for. + RETURNS (Tuple[List[List[List[Entity]]], List[List[bool]]]): (1) list of candidate entities for each doc and + entity, (2) list of flag whether candidates could be found per each doc and entitiy. + """ + ents_cands: List[List[List[Entity]]] = [] + has_cands: List[List[bool]] = [] + + for doc in docs: + ents_cands.append(self.fetch_entity_info(doc)[0]) # Determine which ents have candidates and should be included in prompt. - has_cands = [ - {cand_ent.id for cand_ent in cand_ents} != {EntityLinker.NIL} - or not self._auto_nil - for cand_ents in cands_ents - ] - self._has_ent_cands.append(has_cands) - - # To improve: if a doc has no entities (with candidates), skip prompt altogether? - yield _template.render( - text=EntityLinkerTask.highlight_ents_in_text(doc, has_cands), - mentions_str=", ".join( - [f"*{mention}*" for hc, mention in zip(has_cands, doc.ents) if hc] - ), - mentions=[ent.text for hc, ent in zip(has_cands, doc.ents) if hc], - entity_descriptions=[ - [ent.description for ent in ents] - for hc, ents in zip(has_cands, cands_ents) - if hc - ], - entity_ids=[ - [ent.id for ent in ents] - for hc, ents in zip(has_cands, cands_ents) - if hc - ], - prompt_examples=self._prompt_examples, + has_cands.append( + [ + {cand_ent.id for cand_ent in cand_ents} != {EntityLinker.NIL} + or not self._auto_nil + for cand_ents in ents_cands[-1] + ] + ) + + return ents_cands, has_cands + + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: + # n_shards changes before reset happens in _preprocess_docs() whenever sharding mechanism varies number of + # shards. In this case we have to reset task state as well. + if n_shards != self._n_shards: + self._n_shards = n_shards + self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + + # It's not ideal that we have to run candidate selection again here - but due to (1) us wanting to know whether + # all entities have candidates before sharding and, more importantly, (2) some entities maybe being split up in + # the sharding process it's cleaner to look for candidates again. + if n_shards == 1: + # If only one shard: shard is identical to original doc, so we don't have to rerun candidate search. + ents_cands, has_cands = ( + self._ents_cands_by_doc[i_doc], + self._has_ent_cands_by_doc[i_doc], ) + else: + cands_info = self._find_entity_candidates([shard]) + ents_cands, has_cands = cands_info[0][0], cands_info[1][0] + + # Update shard-wise candidate info so it can be reused during parsing. + if len(self._ents_cands_by_shard[i_doc]) == 0: + self._ents_cands_by_shard[i_doc] = [[] for _ in range(n_shards)] + self._has_ent_cands_by_shard[i_doc] = [[] for _ in range(n_shards)] + self._ents_cands_by_shard[i_doc][i_shard] = ents_cands + self._has_ent_cands_by_shard[i_doc][i_shard] = has_cands + + return { + "mentions_str": ", ".join( + [ + f"*{mention.text}*" + for hc, mention in zip(has_cands, shard.ents) + if hc + ] + ), + "mentions": [ent.text for hc, ent in zip(has_cands, shard.ents) if hc], + "entity_descriptions": [ + [ent.description for ent in ents] + for hc, ents in zip(has_cands, ents_cands) + if hc + ], + "entity_ids": [ + [ent.id for ent in ents] + for hc, ents in zip(has_cands, ents_cands) + if hc + ], + } def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for i_doc, (doc, ent_spans) in enumerate( - zip(docs, self._parse_responses(self, docs=docs, responses=responses)) + shards_teed = self._tee_2d_iterable(shards, 2) + parsed_responses = self._parse_responses(self, shards_teed[1], responses) + + for i_doc, (shards_for_doc, ent_spans_for_doc) in enumerate( + zip(shards_teed[0], parsed_responses) ): - gen_nil_span: Callable[[Span], Span] = lambda ent: Span( # noqa: E731 - doc=doc, - start=ent.start, - end=ent.end, - label=ent.label, - vector=ent.vector, - vector_norm=ent.vector_norm, - kb_id=EntityLinker.NIL, - ) + updated_shards_for_doc: List[Doc] = [] + for i_shard, (shard, ent_spans) in enumerate( + zip(shards_for_doc, ent_spans_for_doc) + ): + gen_nil_span: Callable[[Span], Span] = lambda ent: Span( # noqa: E731 + doc=shard, + start=ent.start, + end=ent.end, + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=EntityLinker.NIL, + ) + + # If numbers of ents parsed from LLM response + ents without candidates and number of ents in doc don't + # align, skip doc (most likely LLM parsing failed, no guarantee KB IDs can be assigned to correct ents). + # This can happen when the LLM fails to list solutions for all entities. + all_entities_resolved = len(ent_spans) + sum( + [ + not is_in_prompt + for is_in_prompt in self._has_ent_cands_by_shard[i_doc][i_shard] + ] + ) == len(shard.ents) + + # Fuse entities with (i. e. inferred by the LLM) and without candidates (i. e. auto-niled). + # If entity was not included in prompt, as there were no candidates - fill in NIL for this entity. + # If numbers of inferred and auto-niled entities don't line up with total number of entities, there is + # no guaranteed way to assign a partially resolved list of entities + # correctly. + # Else: entity had candidates and was included in prompt - fill in resolved KB ID. + ent_spans_iter = iter(ent_spans) + shard.ents = [ + gen_nil_span(ent) + if not ( + all_entities_resolved + and self._has_ent_cands_by_shard[i_doc][i_shard][i_ent] + ) + else next(ent_spans_iter) + for i_ent, ent in enumerate(shard.ents) + ] + + # Remove entity highlights in shards. + updated_shards_for_doc.append( + EntityLinkerTask.unhighlight_ents_in_doc(shard) + ) - # If numbers of ents parsed from LLM response + ents without candidates and number of ents in doc don't - # align, skip doc (most likely LLM parsing failed, no guarantee KB IDs can be assigned to correct ents). - # This can happen when the LLM fails to list solutions for all entities. - all_entities_resolved = len(ent_spans) + sum( - [not is_in_prompt for is_in_prompt in self._has_ent_cands[i_doc]] - ) == len(doc.ents) - - # Fuse entities with (i. e. inferred by the LLM) and without candidates (i. e. auto-niled). - # If entity was not included in prompt, as there were no candidates - fill in NIL for this entity. - # If numbers of inferred and auto-niled entities don't line up with total number of entities, there is no - # guaranteed way to assign a partially resolved list of entities - # correctly. - # Else: entity had candidates and was included in prompt - fill in resolved KB ID. - ent_spans_iter = iter(ent_spans) - doc.ents = [ - gen_nil_span(ent) - if not (all_entities_resolved and self._has_ent_cands[i_doc][i_ent]) - else next(ent_spans_iter) - for i_ent, ent in enumerate(doc.ents) - ] - - yield doc + yield self._shard_reducer(self, updated_shards_for_doc) # type: ignore[arg-type] def scorer(self, examples: Iterable[Example]) -> Dict[str, Any]: return self._scorer(examples) @@ -169,35 +251,133 @@ def _cfg_keys(self) -> List[str]: return ["_template"] @staticmethod - def highlight_ents_in_text( + def highlight_ents_in_doc( doc: Doc, include_ents: Optional[List[bool]] = None - ) -> str: - """Highlights entities in doc text with **. + ) -> Doc: + """Highlights entities in doc by wrapping them in **. doc (Doc): Doc whose entities are to be highlighted. include_ents (Optional[List[bool]]): Whether to include entities with the corresponding indices. If None, all are included. - RETURNS (str): Text with highlighted entities. + RETURNS (Doc): Doc with highlighted entities. """ if include_ents is not None and len(include_ents) != len(doc.ents): raise ValueError( f"`include_ents` has {len(include_ents)} entries, but {len(doc.ents)} are required." ) - text = doc.text - i = 0 + ents_to_highlight_idx = [ + i + for i, ent in enumerate(doc.ents) + if (include_ents is None or include_ents[i]) + ] + ents_idx = [(ent.start, ent.end) for ent in doc.ents] + + # Include *-marker as tokens. Update entity indices. + i_ent = 0 + new_ent_idx: List[Tuple[int, int]] = [] + token_texts: List[str] = [] + spaces: List[bool] = [] + to_highlight = i_ent in ents_to_highlight_idx + offset = 0 + + for token in doc: + if i_ent < len(ents_idx) and token.i == ents_idx[i_ent][1]: + if to_highlight: + token_texts.append("*") + spaces.append(spaces[-1]) + spaces[-2] = False + offset += 1 + i_ent += 1 + to_highlight = i_ent in ents_to_highlight_idx + if i_ent < len(ents_idx) and token.i == ents_idx[i_ent][0]: + if to_highlight: + token_texts.append("*") + spaces.append(False) + offset += 1 + new_ent_idx.append( + (ents_idx[i_ent][0] + offset, ents_idx[i_ent][1] + offset) + ) + token_texts.append(token.text) + spaces.append(token.whitespace_ != "") + + # Cover edge case of doc ending with entity, in which case we need to close the * wrapping. + if len(ents_to_highlight_idx) and doc.ents[ + ents_to_highlight_idx[-1] + ].end == len(doc): + token_texts.append("*") + spaces.append(False) + + # Create doc with new tokens and entities. + highlighted_doc = Doc(doc.vocab, words=token_texts, spaces=spaces) + highlighted_doc.ents = [ + Span( + doc=highlighted_doc, + start=new_ent_idx[i][0], + end=new_ent_idx[i][1], + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=ent.kb_id_, + ) + for i, ent in enumerate(doc.ents) + ] + + return highlighted_doc + + @staticmethod + def unhighlight_ents_in_doc(doc: Doc) -> Doc: + """Remove entity highlighting (* wrapping) in doc. + doc (Doc): Doc whose entities are to be highlighted. + RETURNS (Doc): Doc with highlighted entities. + """ + highlight_start_idx = { + ent.start - 1 + for ent in doc.ents + if ent.start - 1 > 0 and doc[ent.start - 1].text == "*" + } + highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"} + highlight_idx = highlight_start_idx | highlight_end_idx + + # Compute entity indices with removed highlights. + ent_idx: List[Tuple[int, int]] = [] + offset = 0 for ent in doc.ents: - # Skip if ent is not supposed to be included. - if include_ents is not None and not include_ents[i]: - continue - - text = ( - text[: ent.start_char + i * 2] - + f"*{ent.text}*" - + text[ent.end_char + i * 2 :] + is_highlighted = ent.start - 1 in highlight_start_idx + ent_idx.append( + (ent.start + offset - is_highlighted, ent.end + offset - is_highlighted) + ) + offset -= 2 * is_highlighted + + # Create doc with new tokens and entities. + tokens = [ + token + for token in doc + if not (token.i in highlight_idx and token.text == "*") + ] + unhighlighted_doc = Doc( + doc.vocab, + words=[token.text for token in tokens], + # Use original token space, if token doesn't appear after * highlight. If so, insert space unconditionally. + spaces=[ + token.whitespace_ != "" or token.i + 1 in highlight_idx + for i, token in enumerate(tokens) + ], + ) + + unhighlighted_doc.ents = [ + Span( + doc=unhighlighted_doc, + start=ent_idx[i][0], + end=ent_idx[i][1], + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=ent.kb_id_, ) - i += 1 + for i, ent in enumerate(doc.ents) + ] - return text + return unhighlighted_doc def _require_candidate_selector(self) -> None: """Raises an error if candidate selector is not available.""" @@ -246,8 +426,8 @@ def fetch_entity_info( return cand_entity_info, correct_ent_ids @property - def has_ent_cands(self) -> List[List[bool]]: - """Returns flags indicating whether documents' entities' have candidates in KB. - RETURNS (List[List[bool]]): Flags indicating whether documents' entities' have candidates in KB. + def has_ent_cands_by_shard(self) -> List[List[List[bool]]]: + """Returns flags indicating whether shards' entities' have candidates in KB. + RETURNS (List[List[List[bool]]]): Flags indicating whether shards' entities' have candidates in KB. """ - return self._has_ent_cands + return self._has_ent_cands_by_shard diff --git a/spacy_llm/tasks/entity_linker/util.py b/spacy_llm/tasks/entity_linker/util.py index dce404b6..55c44d6f 100644 --- a/spacy_llm/tasks/entity_linker/util.py +++ b/spacy_llm/tasks/entity_linker/util.py @@ -11,6 +11,7 @@ from spacy.kb import InMemoryLookupKB from spacy.pipeline import EntityLinker from spacy.scorer import Scorer +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -60,7 +61,7 @@ def generate(cls, example: Example, task: EntityLinkerTask) -> Optional[Self]: assert all([sol is not None for sol in solutions]) return ELExample( - text=EntityLinkerTask.highlight_ents_in_text(example.reference), + text=EntityLinkerTask.highlight_ents_in_doc(example.reference).text, mentions=mentions, entity_descriptions=[ [ent.description for ent in ents] for ents in cands_ents @@ -196,3 +197,13 @@ def __call__(self, vocab: Vocab) -> Tuple[InMemoryLookupKB, DescFormat]: raise err return kb, {qid: entities[qid].get("desc") for qid in qids} + + +def reduce_shards_to_doc(task: EntityLinkerTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for EntityLinkerTask. + task (EntityLinkerTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # Entities are additive, so we can just merge shards. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/lemma/parser.py b/spacy_llm/tasks/lemma/parser.py index 9505f9d1..d9ff7c1e 100644 --- a/spacy_llm/tasks/lemma/parser.py +++ b/spacy_llm/tasks/lemma/parser.py @@ -6,19 +6,32 @@ def parse_responses_v1( - task: LemmaTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[List[str]]]: + task: LemmaTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[List[List[List[str]]]]: """Parses LLM responses for spacy.Lemma.v1. task (LemmaTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[List[str]]): Lists of 2-lists (token: lemmatized token) per doc/response. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[List[List[List[str]]]]): Lists of 2-lists per token (token: lemmatized token) and shard/response + and doc. """ - for prompt_response in responses: - yield [ - [pr_part.strip() for pr_part in pr.split(":")] - for pr in prompt_response.replace("Lemmatized text:", "") - .replace("'''", "") - .strip() - .split("\n") - ] + for responses_for_doc in responses: + results_for_doc: List[List[List[str]]] = [] + for response in responses_for_doc: + results_for_shard = [ + [pr_part.strip() for pr_part in pr.split(":")] + for pr in response.replace("Lemmatized text:", "") + .replace("'''", "") + .strip() + .split("\n") + ] + results_for_doc.append( + # Malformed responses might have a length != 2, in which case they are discarded. + [ + result_for_token + for result_for_token in results_for_shard + if len(result_for_token) == 2 + ] + ) + + yield results_for_doc diff --git a/spacy_llm/tasks/lemma/registry.py b/spacy_llm/tasks/lemma/registry.py index e317e280..d4d555d3 100644 --- a/spacy_llm/tasks/lemma/registry.py +++ b/spacy_llm/tasks/lemma/registry.py @@ -1,10 +1,12 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_LEMMA_TEMPLATE_V1, LemmaTask -from .util import LemmaExample, score +from .util import LemmaExample, reduce_shards_to_doc, score @registry.llm_misc("spacy.LemmaParser.v1") @@ -17,12 +19,19 @@ def make_lemma_scorer() -> Scorer: return score +@registry.llm_misc("spacy.LemmaShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc + + @registry.llm_tasks("spacy.Lemma.v1") def make_lemma_task( template: str = DEFAULT_LEMMA_TEMPLATE_V1, parse_responses: Optional[TaskResponseParser[LemmaTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, scorer: Optional[Scorer] = None, ): """Lemma.v1 task factory. @@ -32,6 +41,9 @@ def make_lemma_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. scorer (Optional[Scorer]): Scorer function. """ raw_examples = examples() if callable(examples) else examples @@ -45,5 +57,7 @@ def make_lemma_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=lemma_examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), scorer=scorer or score, ) diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index c3bb2083..add263d2 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -5,7 +5,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -19,6 +19,8 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], prompt_examples: Optional[List[FewshotExample[Self]]], template: str, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], scorer: Scorer, ): """Default lemmatization task. @@ -27,6 +29,8 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. template (str): Prompt template passed to the model. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. scorer (Scorer): Scorer function. """ super().__init__( @@ -34,25 +38,36 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._scorer = scorer def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, lemmas in zip(docs, self._parse_responses(self, docs, responses)): - tokens = [token for token in doc] - # If numbers of tokens recognized by spaCy and returned by LLM don't match, we don't attempt a partial - # match. - if len(tokens) != len(lemmas): - yield doc + shards_teed = self._tee_2d_iterable(shards, 2) + for shards_for_doc, lemmas_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + updated_shards_for_doc: List[Doc] = [] - # Assign lemmas. - for token, lemma_info in zip(tokens, lemmas): - if len(lemma_info) > 0: - token.lemma_ = lemma_info[1] + for shard, lemmas in zip(shards_for_doc, lemmas_for_doc): + tokens = [token for token in shard] + # If numbers of tokens recognized by spaCy and returned by LLM don't match, we don't attempt a partial + # match. + if len(tokens) != len(lemmas): + updated_shards_for_doc.append(shard) + continue - yield doc + # Assign lemmas. + for token, lemma_info in zip(tokens, lemmas): + if len(lemma_info) > 0: + token.lemma_ = lemma_info[1] + + updated_shards_for_doc.append(shard) + + yield self._shard_reducer(self, updated_shards_for_doc) # type: ignore[arg-type] def initialize( self, diff --git a/spacy_llm/tasks/lemma/util.py b/spacy_llm/tasks/lemma/util.py index fde27498..a77f8507 100644 --- a/spacy_llm/tasks/lemma/util.py +++ b/spacy_llm/tasks/lemma/util.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Iterable, List, Optional from spacy.scorer import Scorer +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -24,3 +25,13 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: RETURNS (Dict[str, Any]): Dict with metric name -> score. """ return Scorer.score_token_attr(examples, "lemma") + + +def reduce_shards_to_doc(task: LemmaTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for LemmaTask. + task (LemmaTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # Lemmas are token-specific, so we can just merge shards. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/ner/registry.py b/spacy_llm/tasks/ner/registry.py index 55b8e2ce..d4908904 100644 --- a/spacy_llm/tasks/ner/registry.py +++ b/spacy_llm/tasks/ner/registry.py @@ -2,14 +2,21 @@ from ...compat import Literal from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency, check_label_consistency_cot +from ..util.sharding import make_shard_mapper from .task import DEFAULT_NER_TEMPLATE_V1, DEFAULT_NER_TEMPLATE_V2 from .task import DEFAULT_NER_TEMPLATE_V3, NERTask, SpanTask -from .util import NERCoTExample, NERExample, score +from .util import NERCoTExample, NERExample, reduce_shards_to_doc, score + + +@registry.llm_misc("spacy.NERShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.NER.v1") @@ -51,6 +58,8 @@ def make_ner_task( labels=labels_list, template=DEFAULT_NER_TEMPLATE_V1, prompt_examples=span_examples, + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -111,6 +120,8 @@ def make_ner_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -129,6 +140,8 @@ def make_ner_task_v3( template: str = DEFAULT_NER_TEMPLATE_V3, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", case_sensitive_matching: bool = False, @@ -150,6 +163,8 @@ def make_ner_task_v3( full examples, although both can be provided. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -169,6 +184,8 @@ def make_ner_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/ner/task.py b/spacy_llm/tasks/ner/task.py index 7cff6523..af5f7892 100644 --- a/spacy_llm/tasks/ner/task.py +++ b/spacy_llm/tasks/ner/task.py @@ -6,7 +6,7 @@ from spacy.util import filter_spans from ...compat import Literal, Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -25,6 +25,8 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -40,6 +42,8 @@ def __init__( template (str): Prompt template passed to the model. parse_responses (TaskResponseParser[SpanTask]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. label_definitions (Optional[Dict[str, str]]): Map of label -> description of the label to help the language model output the entities wanted. It is usually easier to provide these definitions rather than @@ -59,6 +63,8 @@ def __init__( template=template, parse_responses=parse_responses, prompt_example_type=prompt_example_type, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, label_definitions=label_definitions, prompt_examples=prompt_examples, normalizer=normalizer, diff --git a/spacy_llm/tasks/ner/util.py b/spacy_llm/tasks/ner/util.py index d02b9a83..b1ce44a2 100644 --- a/spacy_llm/tasks/ner/util.py +++ b/spacy_llm/tasks/ner/util.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, Optional from spacy.scorer import get_ner_prf +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -35,3 +36,13 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: RETURNS (Dict[str, Any]): Dict with metric name -> score. """ return get_ner_prf(examples) + + +def reduce_shards_to_doc(task: NERTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for NERTask. + task (NERTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # NERTask only affects span-specific information, so we can just merge shards. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/noop.py b/spacy_llm/tasks/noop.py index 044ca5ac..dff68dc1 100644 --- a/spacy_llm/tasks/noop.py +++ b/spacy_llm/tasks/noop.py @@ -1,4 +1,5 @@ -from typing import Iterable +import warnings +from typing import Iterable, Optional, Tuple from spacy.tokens import Doc @@ -9,18 +10,52 @@ @registry.llm_tasks("spacy.NoOp.v1") def make_noop_task(): + return ShardingNoopTask() + + +@registry.llm_tasks("spacy.NoOpNoShards.v1") +def make_noopnoshards_task(): return NoopTask() +class ShardingNoopTask: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: + for doc in docs: + yield [_NOOP_PROMPT], [doc] + + def parse_responses( + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] + ) -> Iterable[Doc]: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + docs = [ + Doc.from_docs(list(shards_for_doc), ensure_whitespace=True) + for shards_for_doc in shards + ] + return docs + + @property + def prompt_template(self) -> str: + return """ + This is the NoOp + prompt template + """ + + class NoopTask: def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: - for _ in docs: + for doc in docs: yield _NOOP_PROMPT def parse_responses( self, docs: Iterable[Doc], responses: Iterable[str] ) -> Iterable[Doc]: - # Not doing anything return docs @property diff --git a/spacy_llm/tasks/rel/__init__.py b/spacy_llm/tasks/rel/__init__.py index f35171a4..324126a8 100644 --- a/spacy_llm/tasks/rel/__init__.py +++ b/spacy_llm/tasks/rel/__init__.py @@ -1,7 +1,6 @@ -from .examples import RELExample from .registry import make_rel_task from .task import DEFAULT_REL_TEMPLATE, RELTask -from .util import RelationItem +from .util import RelationItem, RELExample __all__ = [ "DEFAULT_REL_TEMPLATE", diff --git a/spacy_llm/tasks/rel/examples.py b/spacy_llm/tasks/rel/examples.py deleted file mode 100644 index 3467976c..00000000 --- a/spacy_llm/tasks/rel/examples.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import List, Optional - -from spacy.training import Example - -from ...compat import Self -from ...ty import FewshotExample -from .task import RELTask -from .util import EntityItem, RelationItem - - -class RELExample(FewshotExample[RELTask]): - text: str - ents: List[EntityItem] - relations: List[RelationItem] - - @classmethod - def generate(cls, example: Example, task: RELTask) -> Optional[Self]: - entities = [ - EntityItem( - start_char=ent.start_char, - end_char=ent.end_char, - label=ent.label_, - ) - for ent in example.reference.ents - ] - - return cls( - text=example.reference.text, - ents=entities, - relations=example.reference._.rel, - ) diff --git a/spacy_llm/tasks/rel/items.py b/spacy_llm/tasks/rel/items.py new file mode 100644 index 00000000..7426d8b8 --- /dev/null +++ b/spacy_llm/tasks/rel/items.py @@ -0,0 +1,19 @@ +from ...compat import BaseModel, validator + + +class RelationItem(BaseModel): + dep: int + dest: int + relation: str + + @validator("dep", "dest", pre=True) + def clean_ent(cls, value): + if isinstance(value, str): + value = value.strip("ENT") + return value + + +class EntityItem(BaseModel): + start_char: int + end_char: int + label: str diff --git a/spacy_llm/tasks/rel/parser.py b/spacy_llm/tasks/rel/parser.py index 890a6aac..27ede457 100644 --- a/spacy_llm/tasks/rel/parser.py +++ b/spacy_llm/tasks/rel/parser.py @@ -9,28 +9,32 @@ def parse_responses_v1( - task: RELTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[RelationItem]]: + task: RELTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[Iterable[List[RelationItem]]]: """Parses LLM responses for spacy.REL.v1. task (RELTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[List[RelationItem]]): List of RelationItem instances per doc/response. + docs (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[List[RelationItem]]]): List of RelationItem instances per shard/response. """ - for response, doc in zip(responses, docs): - relations: List[RelationItem] = [] - for line in response.strip().split("\n"): - try: - rel_item = RelationItem.parse_raw(line) - if 0 <= rel_item.dep < len(doc.ents) and 0 <= rel_item.dest < len( - doc.ents - ): - relations.append(rel_item) - except ValidationError: - msg.warn( - "Validation issue", - line, - show=task.verbose, - ) + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[List[RelationItem]] = [] + for response, shard in zip(responses_for_doc, shards_for_doc): + relations: List[RelationItem] = [] + for line in response.strip().split("\n"): + try: + rel_item = RelationItem.parse_raw(line) + if 0 <= rel_item.dep < len(shard.ents) and 0 <= rel_item.dest < len( + shard.ents + ): + relations.append(rel_item) + except ValidationError: + msg.warn( + "Validation issue", + line, + show=task.verbose, + ) - yield relations + results_for_doc.append(relations) + + yield results_for_doc diff --git a/spacy_llm/tasks/rel/registry.py b/spacy_llm/tasks/rel/registry.py index 2a3121fb..2399b65d 100644 --- a/spacy_llm/tasks/rel/registry.py +++ b/spacy_llm/tasks/rel/registry.py @@ -1,11 +1,18 @@ from typing import Callable, Dict, List, Optional, Type, Union from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels -from .examples import RELExample +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_REL_TEMPLATE, RELTask +from .util import RELExample, reduce_shards_to_doc + + +@registry.llm_misc("spacy.RELShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.REL.v1") @@ -16,6 +23,8 @@ def make_rel_task( prompt_example_type: Optional[Type[FewshotExample]] = None, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, verbose: bool = False, ) -> "RELTask": @@ -35,6 +44,8 @@ def make_rel_task( full examples, although both can be provided. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ @@ -50,6 +61,8 @@ def make_rel_task( template=template, label_definitions=label_definitions, prompt_examples=rel_examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, verbose=verbose, ) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index 81fd8917..1624455f 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -1,14 +1,14 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from spacy.language import Language -from spacy.tokens import Doc +from spacy.tokens import Doc, Span from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, TaskResponseParser +from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template -from .util import RelationItem +from .items import RelationItem DEFAULT_REL_TEMPLATE: str = read_template("rel.v1") @@ -22,9 +22,22 @@ def __init__( template: str, label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], verbose: bool, ): + super().__init__( + parse_responses=parse_responses, + prompt_example_type=prompt_example_type, + template=template, + prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, + labels=labels, + label_definitions=label_definitions, + normalizer=normalizer, + ) """Default REL task. Populates a `Doc._.rel` custom attribute. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. @@ -36,27 +49,22 @@ def __init__( of the label to help the language model output the entities wanted. It is usually easier to provide these definitions rather than full examples, although both can be provided. - prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in + prompts. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ - super().__init__( - parse_responses=parse_responses, - prompt_example_type=prompt_example_type, - template=template, - prompt_examples=prompt_examples, - labels=labels, - label_definitions=label_definitions, - normalizer=normalizer, - ) self._verbose = verbose self._field = "rel" def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: - return [Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs] + return [RELTask._preannotate(doc, True) for doc in docs] - @property - def _prompt_data(self) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, @@ -64,33 +72,82 @@ def _prompt_data(self) -> Dict[str, Any]: } @staticmethod - def _preannotate(doc: Union[Doc, FewshotExample]) -> str: - """Creates a text version of the document with annotated entities.""" - offset = 0 - text = doc.text + def _preannotate( + doc: Union[Doc, FewshotExample], return_as_doc: bool = False + ) -> Union[str, Doc]: + """Creates a text version of the document with annotated entities. + doc (Union[Doc, FewshotExample]): Doc to preannotate. + return_as_doc (bool): Whether to return as doc (by default returned as text). + """ + words: List[str] = [] if len(doc.ents) else [t.text for t in doc] + spaces: List[bool] = [] if len(doc.ents) else [t.whitespace_ != "" for t in doc] + ent_indices: List[Tuple[int, int]] = [] + + # Convert RELExample into Doc for easier subsequent processing. + # todo Solve import cycle so we can expect RELExample here. + if not isinstance(doc, Doc): + assert hasattr(doc, "to_doc") and callable(doc.to_doc) + doc = doc.to_doc() if not hasattr(doc, "ents"): raise ValueError( "Prompt example type used in RELTask has to expose entities via an .ents attribute." ) + # Update token information for doc reconstruction. + last_ent_end = -1 for i, ent in enumerate(doc.ents): - end = ent.end_char - before, after = text[: end + offset], text[end + offset :] annotation = f"[ENT{i}:{ent.label_ if isinstance(doc, Doc) else ent.label}]" - offset += len(annotation) - text = f"{before}{annotation}{after}" + tokens_since_last_ent = [ + *[t for t in doc if last_ent_end <= t.i < ent.start], + *[t for t in ent], + ] + words.extend([*[t.text for t in tokens_since_last_ent], annotation]) + spaces.extend([t.whitespace_ != "" for t in tokens_since_last_ent]) + + # Adjust spaces w.r.t. added annotations, which should appear directly after entity. + spaces.append(spaces[-1]) + spaces[-2] = False + ent_indices.append((ent.start + i, ent.end + i)) + + last_ent_end = ent.end + + # Include chars after last ent. + if len(doc.ents): + tokens_since_last_ent = [t for t in doc if last_ent_end <= t.i] + words.extend([t.text for t in tokens_since_last_ent]) + spaces.extend([t.whitespace_ != "" for t in tokens_since_last_ent]) + + # Reconstruct doc. + annotated_doc = Doc(words=words, spaces=spaces, vocab=doc.vocab) + annotated_doc.ents = [ + Span( # noqa: E731 + doc=annotated_doc, + start=ent_idx[0], + end=ent_idx[1], + label=doc.ents[i].label, + vector=doc.ents[i].vector, + vector_norm=doc.ents[i].vector_norm, + kb_id=doc.ents[i].kb_id_, + ) + for i, ent_idx in enumerate(ent_indices) + ] - return text + return annotated_doc.text if not return_as_doc else annotated_doc def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_extension(self._field) - - for doc, rel_items in zip(docs, self._parse_responses(self, docs, responses)): - doc._.rel = rel_items - yield doc + shards_teed = self._tee_2d_iterable(shards, 2) + for shards_for_doc, rel_items_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, rel_items in zip(shards_for_doc, rel_items_for_doc): + shard._.rel = rel_items + + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] def initialize( self, diff --git a/spacy_llm/tasks/rel/util.py b/spacy_llm/tasks/rel/util.py index 7426d8b8..e06229d7 100644 --- a/spacy_llm/tasks/rel/util.py +++ b/spacy_llm/tasks/rel/util.py @@ -1,19 +1,102 @@ -from ...compat import BaseModel, validator +import re +import warnings +from typing import Iterable, List, Optional, Tuple +from spacy import Vocab +from spacy.tokens import Doc, Span +from spacy.training import Example -class RelationItem(BaseModel): - dep: int - dest: int - relation: str +from ...compat import Self +from ...ty import FewshotExample +from .items import EntityItem, RelationItem +from .task import RELTask - @validator("dep", "dest", pre=True) - def clean_ent(cls, value): - if isinstance(value, str): - value = value.strip("ENT") - return value +class RELExample(FewshotExample[RELTask]): + text: str + ents: List[EntityItem] + relations: List[RelationItem] -class EntityItem(BaseModel): - start_char: int - end_char: int - label: str + @classmethod + def generate(cls, example: Example, task: RELTask) -> Optional[Self]: + entities = [ + EntityItem( + start_char=ent.start_char, + end_char=ent.end_char, + label=ent.label_, + ) + for ent in example.reference.ents + ] + + return cls( + text=example.reference.text, + ents=entities, + relations=example.reference._.rel, + ) + + def to_doc(self) -> Doc: + """Returns Doc representation of example instance. Note that relations are in user_data["rel"]. + field (str): Doc field to store relations in. + RETURNS (Doc): Representation as doc. + """ + punct_chars_pattern = r'[]!"$%&\'()*+,./:;=#@?[\\^_`{|}~-]+' + text = re.sub(punct_chars_pattern, r" \g<0> ", self.text) + doc_words = text.split() + doc_spaces = [ + i < len(doc_words) - 1 + and not re.match(punct_chars_pattern, doc_words[i + 1]) + for i, word in enumerate(doc_words) + ] + doc = Doc(words=doc_words, spaces=doc_spaces, vocab=Vocab(strings=doc_words)) + + # Set entities after finding correct indices. + conv_ent_indices: List[Tuple[int, int]] = [] + if len(self.ents): + ent_idx = 0 + for token in doc: + if token.idx == self.ents[ent_idx].start_char: + conv_ent_indices.append((token.i, -1)) + if token.idx + len(token.text) == self.ents[ent_idx].end_char: + conv_ent_indices[-1] = (conv_ent_indices[-1][0], token.i + 1) + ent_idx += 1 + if ent_idx == len(self.ents): + break + + doc.ents = [ + Span( # noqa: E731 + doc=doc, + start=ent_idx[0], + end=ent_idx[1], + label=self.ents[i].label, + ) + for i, ent_idx in enumerate(conv_ent_indices) + ] + doc.user_data["rel"] = self.relations + + return doc + + +def reduce_shards_to_doc(task: RELTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for RELTask. + task (RELTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + shards = list(shards) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) + + # REL information from shards can be simply appended. + setattr( + doc._, + task.field, + [rel_items for shard in shards for rel_items in getattr(shard._, task.field)], + ) + + return doc diff --git a/spacy_llm/tasks/sentiment/parser.py b/spacy_llm/tasks/sentiment/parser.py index 8365dab0..5e4ba679 100644 --- a/spacy_llm/tasks/sentiment/parser.py +++ b/spacy_llm/tasks/sentiment/parser.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional +from typing import Iterable, List, Optional from spacy.tokens import Doc @@ -6,16 +6,24 @@ def parse_responses_v1( - task: SentimentTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[Optional[float]]: + task: SentimentTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[Optional[float]]]: """Parses LLM responses for spacy.Sentiment.v1. task (SentimentTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[Optional[float]]): Sentiment score per doc/response. None on parsing error. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[Optional[float]]]): Sentiment score per shard/response. None on parsing error. """ - for prompt_response in responses: - try: - yield float("".join(prompt_response.replace("Answer:", "").strip().split())) - except ValueError: - yield None + for responses_for_doc in responses: + results_for_doc: List[Optional[float]] = [] + for response in responses_for_doc: + try: + results_for_doc.append( + float("".join(response.replace("Answer:", "").strip().split())) + ) + except ValueError: + results_for_doc.append(None) + + yield results_for_doc diff --git a/spacy_llm/tasks/sentiment/registry.py b/spacy_llm/tasks/sentiment/registry.py index ab15f151..180cccc2 100644 --- a/spacy_llm/tasks/sentiment/registry.py +++ b/spacy_llm/tasks/sentiment/registry.py @@ -1,10 +1,17 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_SENTIMENT_TEMPLATE_V1, SentimentTask -from .util import SentimentExample, score +from .util import SentimentExample, reduce_shards_to_doc, score + + +@registry.llm_misc("spacy.SentimentShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.Sentiment.v1") @@ -13,6 +20,8 @@ def make_sentiment_task( parse_responses: Optional[TaskResponseParser[SentimentTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, field: str = "sentiment", scorer: Optional[Scorer] = None, ): @@ -24,6 +33,8 @@ def make_sentiment_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. field (str): The name of the doc extension in which to store the summary. scorer (Optional[Scorer]): Scorer function. """ @@ -38,6 +49,8 @@ def make_sentiment_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=sentiment_examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), field=field, scorer=scorer or score, ) diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index 57015ca8..29663aea 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -4,7 +4,8 @@ from spacy.tokens import Doc from spacy.training import Example -from ...ty import FewshotExample, Scorer, Self, TaskResponseParser +from ...ty import FewshotExample, Scorer, Self, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -19,6 +20,8 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], field: str, prompt_examples: Optional[List[FewshotExample[Self]]], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], scorer: Scorer, ): """Sentiment analysis task. @@ -27,13 +30,18 @@ def __init__( parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. field (str): The name of the doc extension in which to store the sentiment score. - prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in + prompts. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. """ super().__init__( parse_responses=parse_responses, prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._field = field self._scorer = scorer @@ -63,19 +71,22 @@ def initialize( ) def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_doc_extension() + shards_teed = self._tee_2d_iterable(shards, 2) - for doc, sentiment_score in zip( - docs, self._parse_responses(self, docs, responses) + for shards_for_doc, scores_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): - try: - setattr(doc._, self._field, sentiment_score) - except ValueError: - setattr(doc._, self._field, None) + shards_for_doc = list(shards_for_doc) + for shard, score in zip(shards_for_doc, scores_for_doc): + try: + setattr(shard._, self._field, score) + except ValueError: + setattr(shard._, self._field, None) - yield doc + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] def scorer(self, examples: Iterable[Example]) -> Dict[str, Any]: return self._scorer(examples, field=self._field) diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index f53bae32..4352b62c 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -1,5 +1,7 @@ +import warnings from typing import Any, Dict, Iterable, Optional +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -19,6 +21,33 @@ def generate(cls, example: Example, task: SentimentTask) -> Optional[Self]: ) +def reduce_shards_to_doc(task: SentimentTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for SentimentTask by computing an average sentiment score weighted by shard lengths. + task (SentimentTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + shards = list(shards) + weights = [len(shard) for shard in shards] + weights = [n_tokens / sum(weights) for n_tokens in weights] + sent_scores = [getattr(shard._, task.field) for shard in shards] + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) + setattr( + doc._, + task.field, + sum([score * weight for score, weight in zip(sent_scores, weights)]), + ) + + return doc + + def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: """Score sentiment accuracy in examples. examples (Iterable[Example]): Examples to score. diff --git a/spacy_llm/tasks/span/parser.py b/spacy_llm/tasks/span/parser.py index 467dcbbc..fd0c389e 100644 --- a/spacy_llm/tasks/span/parser.py +++ b/spacy_llm/tasks/span/parser.py @@ -35,35 +35,40 @@ def _format_response( def parse_responses( - task: SpanTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[Span]]: + task: SpanTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[Iterable[List[Span]]]: """Parses LLM responses for Span tasks. task (SpanTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[Span]): Parsed spans per doc/response. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[List[Span]]]): Parsed spans per shard/response. """ - for doc, prompt_response in zip(docs, responses): - spans = [] - for label, phrases in _format_response( - prompt_response, task._normalizer, task._label_dict - ): - # For each phrase, find the substrings in the text - # and create a Span - offsets = find_substrings( - doc.text, - phrases, - case_sensitive=task._case_sensitive_matching, - single_match=task._single_match, - ) - for start, end in offsets: - span = doc.char_span( - start, end, alignment_mode=task._alignment_mode, label=label + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[List[Span]] = [] + + for shard, response in zip(shards_for_doc, responses_for_doc): + spans = [] + for label, phrases in _format_response( + response, task._normalizer, task._label_dict + ): + # For each phrase, find the substrings in the text + # and create a Span + offsets = find_substrings( + shard.text, + phrases, + case_sensitive=task._case_sensitive_matching, + single_match=task._single_match, ) - if span is not None: - spans.append(span) + for start, end in offsets: + span = shard.char_span( + start, end, alignment_mode=task._alignment_mode, label=label + ) + if span is not None: + spans.append(span) - yield spans + results_for_doc.append(spans) + + yield results_for_doc def _extract_span_reasons_cot(task: SpanTask, response: str) -> List[SpanReason]: @@ -152,19 +157,23 @@ def _find_spans_cot( def parse_responses_cot( - task: SpanTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[Span]]: + task: SpanTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[Iterable[List[Span]]]: """Since we provide entities in a numbered list, we expect the LLM to output entities in the order they occur in the text. This parse function now incrementally finds substrings in the text and tracks the last found span's start character to ensure we don't overwrite previously found spans. task (SpanTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[List[Span]]): Spans to assign per doc. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[List[Span]]]): Spans to assign per shard. """ - for doc, llm_response in zip(docs, responses): - span_reasons = _extract_span_reasons_cot(task, llm_response) - spans = _find_spans_cot(task, doc, span_reasons) - yield spans + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[List[Span]] = [] + + for shard, response in zip(shards_for_doc, responses_for_doc): + span_reasons = _extract_span_reasons_cot(task, response) + results_for_doc.append(_find_spans_cot(task, shard, span_reasons)) + + yield results_for_doc diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index f19f4d20..e8bcb407 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -5,7 +5,7 @@ from spacy.tokens import Doc, Span from ...compat import Literal, Protocol, Self -from ...ty import FewshotExample, TaskResponseParser +from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from . import SpanExample from .examples import SpanCoTExample @@ -33,6 +33,8 @@ def __init__( prompt_examples: Optional[ Union[List[SpanExample[Self]], List[SpanCoTExample[Self]]] ], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], description: Optional[str], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], # noqa: F821 @@ -46,6 +48,8 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, labels=labels, label_definitions=label_definitions, normalizer=normalizer, @@ -66,8 +70,9 @@ def __init__( if self._prompt_examples: self._prompt_examples = list(self._check_label_consistency(self)) - @property - def _prompt_data(self) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: return { "description": self._description, "labels": list(self._label_dict.values()), @@ -97,11 +102,18 @@ def assign_spans( raise NotImplementedError() def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, spans in zip(docs, self._parse_responses(self, docs, responses)): - self.assign_spans(doc, spans) - yield doc + shards_teed = self._tee_2d_iterable(shards, 2) + + for shards_for_doc, spans_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, spans in zip(shards_for_doc, spans_for_doc): + self.assign_spans(shard, spans) + + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/spancat/registry.py b/spacy_llm/tasks/spancat/registry.py index 33cf11dd..f5aa7180 100644 --- a/spacy_llm/tasks/spancat/registry.py +++ b/spacy_llm/tasks/spancat/registry.py @@ -2,15 +2,22 @@ from ...compat import Literal from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency as check_labels from ..span.util import check_label_consistency_cot as check_labels_cot +from ..util.sharding import make_shard_mapper from .task import DEFAULT_SPANCAT_TEMPLATE_V1, DEFAULT_SPANCAT_TEMPLATE_V2 from .task import DEFAULT_SPANCAT_TEMPLATE_V3, SpanCatTask -from .util import SpanCatCoTExample, SpanCatExample, score +from .util import SpanCatCoTExample, SpanCatExample, reduce_shards_to_doc, score + + +@registry.llm_misc("spacy.SpanCatShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.SpanCat.v1") @@ -55,6 +62,8 @@ def make_spancat_task( prompt_example_type=example_type, template=DEFAULT_SPANCAT_TEMPLATE_V1, prompt_examples=span_examples, + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -119,6 +128,8 @@ def make_spancat_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -139,6 +150,8 @@ def make_spancat_task_v3( description: Optional[str] = None, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", case_sensitive_matching: bool = False, @@ -161,6 +174,8 @@ def make_spancat_task_v3( full examples, although both can be provided. examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -181,6 +196,8 @@ def make_spancat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/spancat/task.py b/spacy_llm/tasks/spancat/task.py index 76439964..b25f39e4 100644 --- a/spacy_llm/tasks/spancat/task.py +++ b/spacy_llm/tasks/spancat/task.py @@ -5,7 +5,7 @@ from spacy.training import Example from ...compat import Literal, Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -25,6 +25,8 @@ def __init__( label_definitions: Optional[Dict[str, str]], spans_key: str, prompt_examples: Optional[List[FewshotExample[Self]]], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -46,6 +48,8 @@ def __init__( full examples, although both can be provided. spans_key (str): Key of the `Doc.spans` dict to save under. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -62,6 +66,8 @@ def __init__( template=template, label_definitions=label_definitions, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/spancat/util.py b/spacy_llm/tasks/spancat/util.py index 6ffcd54c..23eec817 100644 --- a/spacy_llm/tasks/spancat/util.py +++ b/spacy_llm/tasks/spancat/util.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, Optional from spacy.pipeline.spancat import spancat_score +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -41,3 +42,13 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: spans_key=kwargs["spans_key"], allow_overlap=True, ) + + +def reduce_shards_to_doc(task: SpanCatTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for SpanCatTask. + task (SpanCatTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # SpanCatTask only affects span-specific information, so we can just merge shards. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/summarization/parser.py b/spacy_llm/tasks/summarization/parser.py index 0af52f6e..5f9f34cb 100644 --- a/spacy_llm/tasks/summarization/parser.py +++ b/spacy_llm/tasks/summarization/parser.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, List from spacy.tokens import Doc @@ -6,13 +6,19 @@ def parse_responses_v1( - task: SummarizationTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[str]: + task: SummarizationTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[str]]: """Parses LLM responses for spacy.Summarization.v1. task (SummarizationTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[str]): Summary per doc/response. + docs (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[str]]): Summary per shard/response. """ - for prompt_response in responses: - yield prompt_response.replace("'''", "").strip() + for responses_for_doc in responses: + results_for_doc: List[str] = [] + for response in responses_for_doc: + results_for_doc.append(response.replace("'''", "").strip()) + + yield responses_for_doc diff --git a/spacy_llm/tasks/summarization/registry.py b/spacy_llm/tasks/summarization/registry.py index 216d99bf..083a7363 100644 --- a/spacy_llm/tasks/summarization/registry.py +++ b/spacy_llm/tasks/summarization/registry.py @@ -1,10 +1,17 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_SUMMARIZATION_TEMPLATE_V1, SummarizationTask -from .util import SummarizationExample +from .util import SummarizationExample, reduce_shards_to_doc + + +@registry.llm_misc("spacy.SummarizationShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.Summarization.v1") @@ -13,6 +20,8 @@ def make_summarization_task( parse_responses: Optional[TaskResponseParser[SummarizationTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, max_n_words: Optional[int] = None, field: str = "summary", ): @@ -24,6 +33,8 @@ def make_summarization_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. max_n_words (int): Max. number of words to use in summary. field (str): The name of the doc extension in which to store the summary. """ @@ -38,6 +49,8 @@ def make_summarization_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=span_examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), max_n_words=max_n_words, field=field, ) diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index cc749ab3..c6900ce0 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -6,7 +6,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, TaskResponseParser +from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -19,6 +19,8 @@ def __init__( parse_responses: TaskResponseParser[Self], prompt_example_type: Type[FewshotExample[Self]], template: str, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], max_n_words: Optional[int], field: str, prompt_examples: Optional[List[FewshotExample[Self]]], @@ -28,6 +30,8 @@ def __init__( template (str): Prompt template passed to the model. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. max_n_words (Optional[int]): Max. number of words to use in summary. field (str): The name of the doc extension in which to store the summary. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. @@ -37,6 +41,8 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._max_n_words = max_n_words self._field = field @@ -78,23 +84,32 @@ def _check_prompt_example_summary_len(self) -> None: f"LLM will likely produce responses that are too long." ) - @property - def _prompt_data(self) -> Dict[str, Any]: - """Returns data injected into prompt template. No-op if not overridden by inheriting task class. - RETURNS (Dict[str, Any]): Data injected into prompt template. - """ + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: if self._check_example_summaries: self._check_prompt_example_summary_len() self._check_example_summaries = False - return {"max_n_words": self._max_n_words} + return { + "max_n_words": int(self._max_n_words / n_shards) + if self._max_n_words is not None + else None + } def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, summary in zip(docs, self._parse_responses(self, docs, responses)): - setattr(doc._, self._field, summary) - yield doc + shards_teed = self._tee_2d_iterable(shards, 2) + + for shards_for_doc, summaries_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, summary in zip(shards_for_doc, summaries_for_doc): + setattr(shard._, self._field, summary) + + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/summarization/util.py b/spacy_llm/tasks/summarization/util.py index 12fd1aa9..9ee479a7 100644 --- a/spacy_llm/tasks/summarization/util.py +++ b/spacy_llm/tasks/summarization/util.py @@ -1,5 +1,7 @@ -from typing import Optional +import warnings +from typing import Iterable, Optional +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -17,3 +19,28 @@ def generate(cls, example: Example, task: SummarizationTask) -> Optional[Self]: text=example.reference.text, summary=getattr(example.reference._, task.field), ) + + +def reduce_shards_to_doc(task: SummarizationTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for SummarizationTask. + task (SummarizationTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + shards = list(shards) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + doc = Doc.from_docs(list(shards), ensure_whitespace=True) + + # Summaries are per shard, so we can merge. Number of shards is considered in max. number of words. This means that + # the resulting summaries will be per shard, which should be an approximately correct summary still. + setattr( + doc._, task.field, " ".join([getattr(shard._, task.field) for shard in shards]) + ) + + return doc diff --git a/spacy_llm/tasks/textcat/parser.py b/spacy_llm/tasks/textcat/parser.py index 24228f7f..ee8f9ddc 100644 --- a/spacy_llm/tasks/textcat/parser.py +++ b/spacy_llm/tasks/textcat/parser.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable +from typing import Dict, Iterable, List from spacy.tokens import Doc from wasabi import msg @@ -7,40 +7,47 @@ def parse_responses_v1_v2_v3( - task: TextCatTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[Dict[str, float]]: + task: TextCatTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[Dict[str, float]]]: """Parses LLM responses for spacy.TextCat.v1, v2 and v3 task (LemmaTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Dict[str, float]): TextCat scores per class. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[Dict[str, float]]]): TextCat scores per shard and class. """ - for response in responses: - categories: Dict[str, float] - response = response.strip() - if task.use_binary: - # Binary classification: We only have one label - label: str = list(task.label_dict.values())[0] - score = 1.0 if response.upper() == "POS" else 0.0 - categories = {label: score} - else: - # Multilabel classification - categories = {label: 0.0 for label in task.label_dict.values()} - - pred_labels = response.split(",") - if task.exclusive_classes and len(pred_labels) > 1: - # Don't use anything but raise a debug message - # Don't raise an error. Let user abort if they want to. - msg.text( - f"LLM returned multiple labels for this exclusive task: {pred_labels}.", - " Will store an empty label instead.", - show=task.verbose, - ) - pred_labels = [] - - for pred in pred_labels: - if task.normalizer(pred.strip()) in task.label_dict: - category = task.label_dict[task.normalizer(pred.strip())] - categories[category] = 1.0 - - yield categories + for response_for_doc in responses: + results_for_doc: List[Dict[str, float]] = [] + + for response in response_for_doc: + categories: Dict[str, float] + response = response.strip() + if task.use_binary: + # Binary classification: We only have one label + label: str = list(task.label_dict.values())[0] + score = 1.0 if response.upper() == "POS" else 0.0 + categories = {label: score} + else: + # Multilabel classification + categories = {label: 0.0 for label in task.label_dict.values()} + + pred_labels = response.split(",") + if task.exclusive_classes and len(pred_labels) > 1: + # Don't use anything but raise a debug message + # Don't raise an error. Let user abort if they want to. + msg.text( + f"LLM returned multiple labels for this exclusive task: {pred_labels}.", + " Will store an empty label instead.", + show=task.verbose, + ) + pred_labels = [] + + for pred in pred_labels: + if task.normalizer(pred.strip()) in task.label_dict: + category = task.label_dict[task.normalizer(pred.strip())] + categories[category] = 1.0 + + results_for_doc.append(categories) + + yield results_for_doc diff --git a/spacy_llm/tasks/textcat/registry.py b/spacy_llm/tasks/textcat/registry.py index 7f97709c..67885025 100644 --- a/spacy_llm/tasks/textcat/registry.py +++ b/spacy_llm/tasks/textcat/registry.py @@ -1,12 +1,19 @@ from typing import Callable, Dict, List, Optional, Type, Union from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1_v2_v3 from .task import DEFAULT_TEXTCAT_TEMPLATE_V1, DEFAULT_TEXTCAT_TEMPLATE_V2 from .task import DEFAULT_TEXTCAT_TEMPLATE_V3, TextCatTask -from .util import TextCatExample, score +from .util import TextCatExample, reduce_shards_to_doc, score + + +@registry.llm_misc("spacy.TextCatShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.TextCat.v1") @@ -62,6 +69,8 @@ def make_textcat_task( labels=labels_list, template=DEFAULT_TEXTCAT_TEMPLATE_V1, prompt_examples=textcat_examples, + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, @@ -128,6 +137,8 @@ def make_textcat_task_v2( labels=labels_list, template=template, prompt_examples=textcat_examples, + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, @@ -145,6 +156,8 @@ def make_textcat_task_v3( template: str = DEFAULT_TEXTCAT_TEMPLATE_V3, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, exclusive_classes: bool = False, allow_none: bool = True, @@ -177,6 +190,8 @@ def make_textcat_task_v3( These descriptions are added to the prompt to help instruct the LLM on what to extract. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. exclusive_classes (bool): If True, require the language model to suggest only one label per class. This is automatically set when using binary classification. @@ -199,6 +214,8 @@ def make_textcat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=textcat_examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index 63d7ec7e..9e21238a 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -6,7 +6,7 @@ from wasabi import msg from ...compat import Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template @@ -24,6 +24,8 @@ def __init__( template: str, label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], exclusive_classes: bool, allow_none: bool, @@ -53,6 +55,8 @@ def __init__( label_definitions (Optional[Dict[str, str]]): Optional dict mapping a label to a description of that label. These descriptions are added to the prompt to help instruct the LLM on what to extract. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. exclusive_classes (bool): If True, require the language model to suggest only one label per class. This is automatically set when using binary classification. @@ -65,6 +69,8 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, labels=labels, label_definitions=label_definitions, normalizer=normalizer, @@ -83,8 +89,9 @@ def __init__( ) self._exclusive_classes = True - @property - def _prompt_data(self) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, @@ -93,11 +100,19 @@ def _prompt_data(self) -> Dict[str, Any]: } def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, cats in zip(docs, self._parse_responses(self, docs, responses)): - doc.cats = cats - yield doc + shards_teed = self._tee_2d_iterable(shards, 2) + for shards_for_doc, cats_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + updated_shards_for_doc: List[Doc] = [] + + for shard, cats in zip(shards_for_doc, cats_for_doc): + shard.cats = cats + updated_shards_for_doc.append(shard) + + yield self._shard_reducer(self, updated_shards_for_doc) # type: ignore[arg-type] def scorer( self, diff --git a/spacy_llm/tasks/textcat/util.py b/spacy_llm/tasks/textcat/util.py index 992c9bb2..291f5d29 100644 --- a/spacy_llm/tasks/textcat/util.py +++ b/spacy_llm/tasks/textcat/util.py @@ -1,6 +1,9 @@ -from typing import Any, Dict, Iterable, Optional +import warnings +from collections import defaultdict +from typing import Any, DefaultDict, Dict, Iterable, Optional from spacy.scorer import Scorer +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -46,3 +49,31 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: labels=kwargs["labels"], multi_label=kwargs["multi_label"], ) + + +def reduce_shards_to_doc(task: TextCatTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for TextCatTask. + task (TextCatTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + shards = list(shards) + + # Compute average sum per category weighted by shard length. + weights = [len(shard) for shard in shards] + weights = [n_tokens / sum(weights) for n_tokens in weights] + all_cats: DefaultDict[str, float] = defaultdict(lambda: 0) + for weight, shard in zip(weights, shards): + for cat, cat_score in shard.cats.items(): + all_cats[cat] += cat_score * weight + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) + doc.cats = all_cats + + return doc diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py new file mode 100644 index 00000000..4b9c9a9a --- /dev/null +++ b/spacy_llm/tasks/util/sharding.py @@ -0,0 +1,103 @@ +from typing import Callable, Iterable, List, Optional, Union + +from spacy.tokens import Doc + +from ...registry import registry +from ...ty import NTokenEstimator, ShardMapper + + +@registry.llm_misc("spacy.NTokenEstimator.v1") +def make_n_token_estimator() -> NTokenEstimator: + """Generates Callable estimating the number of tokens in a given string. + # todo improve default tokenization (allow language code to do tokenization with pretrained spacy model) + RETURNS (NTokenEstimator): Callable estimating the number of tokens in a given string. + """ + + def count_tokens_by_spaces(value: str) -> int: + return len(value.split()) + + return count_tokens_by_spaces + + +@registry.llm_misc("spacy.ShardMapper.v1") +def make_shard_mapper( + n_token_estimator: Optional[NTokenEstimator] = None, + buffer_frac: float = 1.1, +) -> ShardMapper: + """Generates Callable mapping doc to doc shards fitting within context length. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + buffer_frac (float): Buffer to consider in assessment of whether prompt fits into context. E. g. if value is 1.1, + prompt length * 1.1 will be compared with the context length. + todo sharding would be better with sentences instead of tokens, but this requires some form of sentence + splitting we can't rely one...maybe checking for sentences and/or as optional arg? + RETURNS (ShardMapper): Callable mapping doc to doc shards fitting within context length. + """ + n_tok_est: NTokenEstimator = n_token_estimator or make_n_token_estimator() + + def map_doc_to_shards( + doc: Doc, + i_doc: int, + context_length: int, + render_template: Callable[[Doc, int, int, int], str], + ) -> Union[Iterable[Doc], Doc]: + prompt = render_template(doc, 0, i_doc, 1) + + # If prompt with complete doc too long: split in shards. + if n_tok_est(prompt) * buffer_frac > context_length: + shards: List[Doc] = [] + # Prompt length unfortunately can't be exacted computed prior to rendering the prompt, as external + # information not present in the doc (e. g. entity description for EL prompts) may be injected. + # For this reason we follow a greedy binary search heuristic, if the fully rendered prompt is too long: + # 1. Get total number of tokens/sentences (depending on the reducer's configuration) + # 2. Splice off doc up to the first half of tokens/sentences + # 3. Render prompt and check whether it fits into context + # 4. If yes: repeat with second doc half. + # 5. If not: repeat from 2., but with split off shard instead of doc. + remaining_doc: Optional[Doc] = doc.copy() + fraction = 0.5 + start_idx = 0 + n_shards = 1 + + while remaining_doc is not None: + fits_in_context = False + shard: Optional[Doc] = None + end_idx = -1 + n_tries = 0 + + while fits_in_context is False: + end_idx = start_idx + int(len(remaining_doc) * fraction) + shard = doc[start_idx:end_idx].as_doc(copy_user_data=True) + fits_in_context = ( + n_tok_est(render_template(shard, len(shards), i_doc, n_shards)) + * buffer_frac + <= context_length + ) + fraction /= 2 + n_tries += 1 + + # If prompt is too large even with shard of a single token, raise error - we can't shard any more + # than this. This is an edge case and will most likely never occur. + if len(shard) == 1 and not fits_in_context: + raise ValueError( + "Prompt size doesn't allow for the inclusion for shard of length 1. Please " + "review your prompt and reduce its size." + ) + + assert shard is not None + shards.append(shard) + fraction = 1 + n_shards = max(len(shards) + round(1 / fraction), 1) + start_idx = end_idx + # Set remaining_doc to None if shard contains all of it, i. e. entire original doc has been processed. + remaining_doc = ( + doc[end_idx:].as_doc(copy_user_data=True) + if shard.text != remaining_doc.text + else None + ) + + return shards + + else: + return [doc] + + return map_doc_to_shards diff --git a/spacy_llm/tests/conftest.py b/spacy_llm/tests/conftest.py index 2eda3409..7a64a074 100644 --- a/spacy_llm/tests/conftest.py +++ b/spacy_llm/tests/conftest.py @@ -42,7 +42,7 @@ def pytest_collection_modifyitems(config, items): @registry.llm_models("test.NoOpModel.v1") def noop_factory(output: str = ""): - def noop(prompts: Iterable[str]) -> Iterable[str]: - return [output] * len(list(prompts)) + def noop(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: + return [[output]] * len(list(prompts)) return noop diff --git a/spacy_llm/tests/models/test_anthropic.py b/spacy_llm/tests/models/test_anthropic.py index eb366205..d0bfa794 100644 --- a/spacy_llm/tests/models/test_anthropic.py +++ b/spacy_llm/tests/models/test_anthropic.py @@ -20,13 +20,16 @@ def test_anthropic_api_response_is_correct(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) prompt = "Count the number of characters in this string: hello" num_prompts = 3 - responses = anthropic(prompts=[prompt] * num_prompts) + responses = anthropic(prompts=[[prompt]] * num_prompts) for response in responses: - assert isinstance(response, str) + assert isinstance(response, list) + assert len(response) == 1 + assert isinstance(response[0], str) @pytest.mark.external @@ -47,6 +50,7 @@ def test_anthropic_api_response_when_error(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) @@ -69,4 +73,5 @@ def test_anthropic_error_unsupported_model(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) diff --git a/spacy_llm/tests/models/test_cohere.py b/spacy_llm/tests/models/test_cohere.py index 5d1db35f..dfcb432a 100644 --- a/spacy_llm/tests/models/test_cohere.py +++ b/spacy_llm/tests/models/test_cohere.py @@ -18,12 +18,15 @@ def test_cohere_api_response_is_correct(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) prompt = "Count the number of characters in this string: hello" num_prompts = 3 # arbitrary number to check multiple inputs - responses = cohere(prompts=[prompt] * num_prompts) + responses = cohere(prompts=[[prompt]] * num_prompts) for response in responses: - assert isinstance(response, str) + assert isinstance(response, list) + assert len(response) == 1 + assert isinstance(response[0], str) @pytest.mark.external @@ -44,13 +47,16 @@ def test_cohere_api_response_n_generations(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) prompt = "Count the number of characters in this string: hello" num_prompts = 3 - responses = cohere(prompts=[prompt] * num_prompts) + responses = cohere(prompts=[[prompt]] * num_prompts) for response in responses: - assert isinstance(response, str) + assert isinstance(response, list) + assert len(response) == 1 + assert isinstance(response[0], str) @pytest.mark.external @@ -69,6 +75,7 @@ def test_cohere_api_response_when_error(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) @@ -86,4 +93,5 @@ def test_cohere_error_unsupported_model(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) diff --git a/spacy_llm/tests/models/test_dolly.py b/spacy_llm/tests/models/test_dolly.py index 4b70179d..6a6dc32f 100644 --- a/spacy_llm/tests/models/test_dolly.py +++ b/spacy_llm/tests/models/test_dolly.py @@ -27,7 +27,6 @@ [components.llm] factory = "llm" -save_io = True [components.llm.task] @llm_tasks = "spacy.NoOp.v1" @@ -47,8 +46,8 @@ def test_init(): doc = nlp("This is a test.") nlp.get_pipe("llm")._model.get_model_names() torch.cuda.empty_cache() - assert not doc.user_data["llm_io"]["llm"]["response"].startswith( - doc.user_data["llm_io"]["llm"]["prompt"] + assert not doc.user_data["llm_io"]["llm"]["response"][0].startswith( + doc.user_data["llm_io"]["llm"]["prompt"][0] ) diff --git a/spacy_llm/tests/models/test_falcon.py b/spacy_llm/tests/models/test_falcon.py index 6638975b..0d3f8554 100644 --- a/spacy_llm/tests/models/test_falcon.py +++ b/spacy_llm/tests/models/test_falcon.py @@ -46,8 +46,8 @@ def test_init(): nlp.add_pipe("llm", config=cfg) doc = nlp("This is a test.") torch.cuda.empty_cache() - assert not doc.user_data["llm_io"]["llm"]["response"].startswith( - doc.user_data["llm_io"]["llm"]["prompt"] + assert not doc.user_data["llm_io"]["llm"]["response"][0].startswith( + doc.user_data["llm_io"]["llm"]["prompt"][0] ) diff --git a/spacy_llm/tests/models/test_langchain.py b/spacy_llm/tests/models/test_langchain.py index 57e984dc..0363fcd5 100644 --- a/spacy_llm/tests/models/test_langchain.py +++ b/spacy_llm/tests/models/test_langchain.py @@ -31,7 +31,8 @@ def langchain_model_reg_handles() -> List[str]: def test_initialization(): """Test initialization and simple run""" nlp = spacy.blank("en") - nlp.add_pipe("llm", config=PIPE_CFG) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp.add_pipe("llm", config=PIPE_CFG) nlp("This is a test.") @@ -57,5 +58,6 @@ def test_initialization_azure_openai(): } nlp = spacy.blank("en") - nlp.add_pipe("llm", config=_pipe_cfg) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp.add_pipe("llm", config=_pipe_cfg) nlp("This is a test.") diff --git a/spacy_llm/tests/models/test_openllama.py b/spacy_llm/tests/models/test_openllama.py index efb1c2d3..f42d94dc 100644 --- a/spacy_llm/tests/models/test_openllama.py +++ b/spacy_llm/tests/models/test_openllama.py @@ -45,8 +45,8 @@ def test_init(): nlp.add_pipe("llm", config=_PIPE_CFG) doc = nlp("This is a test.") torch.cuda.empty_cache() - assert not doc.user_data["llm_io"]["llm"]["response"].startswith( - doc.user_data["llm_io"]["llm"]["prompt"] + assert not doc.user_data["llm_io"]["llm"]["response"][0].startswith( + doc.user_data["llm_io"]["llm"]["prompt"][0] ) @@ -60,8 +60,8 @@ def test_init_with_set_config(): nlp.add_pipe("llm", config=cfg) doc = nlp("This is a test.") torch.cuda.empty_cache() - assert not doc.user_data["llm_io"]["llm"]["response"].startswith( - doc.user_data["llm_io"]["llm"]["prompt"] + assert not doc.user_data["llm_io"]["llm"]["response"][0].startswith( + doc.user_data["llm_io"]["llm"]["prompt"][0] ) diff --git a/spacy_llm/tests/models/test_rest.py b/spacy_llm/tests/models/test_rest.py index dc0210b7..305732c6 100644 --- a/spacy_llm/tests/models/test_rest.py +++ b/spacy_llm/tests/models/test_rest.py @@ -1,7 +1,7 @@ # mypy: ignore-errors import copy import re -from typing import Iterable +from typing import Iterable, Optional, Tuple import pytest import spacy @@ -22,14 +22,17 @@ class _CountTask: _PROMPT_TEMPLATE = "Count the number of characters in this string: '{text}'." - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: for doc in docs: - yield _CountTask._PROMPT_TEMPLATE.format(text=doc.text) + yield [_CountTask._PROMPT_TEMPLATE.format(text=doc.text)], [doc] def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - return docs + # Grab the first shard per doc + return [list(shards_for_doc)[0] for shards_for_doc in shards] @property def prompt_template(self) -> str: @@ -120,7 +123,8 @@ def test_azure_openai(deployment_name: str): "@llm_models": "spacy.Azure.v1", "base_url": "https://explosion.openai.azure.com/", "model_type": "completions", - "name": deployment_name, + "deployment_name": deployment_name, + "name": deployment_name.replace("35", "3.5"), }, "task": {"@llm_tasks": "spacy.NoOp.v1"}, "save_io": True, diff --git a/spacy_llm/tests/models/test_stablelm.py b/spacy_llm/tests/models/test_stablelm.py index b3b09830..e9edab4b 100644 --- a/spacy_llm/tests/models/test_stablelm.py +++ b/spacy_llm/tests/models/test_stablelm.py @@ -49,8 +49,8 @@ def test_init(name: str): nlp.add_pipe("llm", config=cfg) doc = nlp("This is a test.") torch.cuda.empty_cache() - assert not doc.user_data["llm_io"]["llm"]["response"].startswith( - doc.user_data["llm_io"]["llm"]["prompt"] + assert not doc.user_data["llm_io"]["llm"]["response"][0].startswith( + doc.user_data["llm_io"]["llm"]["prompt"][0] ) diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index 0a3a3f22..1c8e1efc 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -2,7 +2,7 @@ import sys import warnings from pathlib import Path -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, Optional, Tuple import pytest import spacy @@ -18,7 +18,7 @@ from spacy_llm.pipeline import LLMWrapper from spacy_llm.registry import registry from spacy_llm.tasks import _LATEST_TASKS, make_noop_task -from spacy_llm.tasks.noop import _NOOP_PROMPT +from spacy_llm.tasks.noop import _NOOP_PROMPT, ShardingNoopTask from ...cache import BatchCache from ...registry.reader import fewshot_reader @@ -52,11 +52,20 @@ def test_llm_init(nlp): @pytest.mark.parametrize("n_process", [1, 2]) -def test_llm_pipe(nlp: Language, n_process: int): +@pytest.mark.parametrize("shard", [True, False]) +def test_llm_pipe(noop_config: Dict[str, Any], n_process: int, shard: bool): """Test call .pipe().""" + nlp = spacy.blank("en") + nlp.add_pipe( + "llm", + config={**noop_config, **{"task": {"@llm_tasks": "spacy.NoOpNoShards.v1"}}} + if not shard + else noop_config, + ) ops = get_current_ops() if not isinstance(ops, NumpyOps) and n_process != 1: pytest.skip("Only test multiple processes on CPU") + docs = list( nlp.pipe(texts=["This is a test", "This is another test"], n_process=n_process) ) @@ -64,12 +73,13 @@ def test_llm_pipe(nlp: Language, n_process: int): for doc in docs: llm_io = doc.user_data["llm_io"] - - assert llm_io["llm"]["prompt"] == _NOOP_PROMPT - assert llm_io["llm"]["response"] == _NOOP_RESPONSE + assert llm_io["llm"]["prompt"] == ([_NOOP_PROMPT] if shard else _NOOP_PROMPT) + assert llm_io["llm"]["response"] == ( + [_NOOP_RESPONSE] if shard else _NOOP_RESPONSE + ) -@pytest.mark.parametrize("n_process", [1, 2]) +@pytest.mark.parametrize("n_process", [2]) def test_llm_pipe_with_cache(tmp_path: Path, n_process: int): """Test call .pipe() with pre-cached docs""" ops = get_current_ops() @@ -114,24 +124,26 @@ def test_llm_pipe_empty(nlp): def test_llm_serialize_bytes(): - llm = LLMWrapper( - task=make_noop_task(), - save_io=False, - model=None, # type: ignore - cache=BatchCache(path=None, batch_size=0, max_batches_in_mem=0), - vocab=None, # type: ignore - ) + with pytest.warns(UserWarning, match="Task supports sharding"): + llm = LLMWrapper( + task=make_noop_task(), + save_io=False, + model=None, # type: ignore + cache=BatchCache(path=None, batch_size=0, max_batches_in_mem=0), + vocab=None, # type: ignore + ) llm.from_bytes(llm.to_bytes()) def test_llm_serialize_disk(): - llm = LLMWrapper( - task=make_noop_task(), - save_io=False, - model=None, # type: ignore - cache=BatchCache(path=None, batch_size=0, max_batches_in_mem=0), - vocab=None, # type: ignore - ) + with pytest.warns(UserWarning, match="Task supports sharding"): + llm = LLMWrapper( + task=make_noop_task(), + save_io=False, + model=None, # type: ignore + cache=BatchCache(path=None, batch_size=0, max_batches_in_mem=0), + vocab=None, # type: ignore + ) with spacy.util.make_tempdir() as tmp_dir: llm.to_disk(tmp_dir / "llm") @@ -157,13 +169,16 @@ class NoopTask_Incorrect: def __init__(self): pass - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[int]: - return [0] * len(list(docs)) + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[Tuple[Iterable[int], Iterable[Doc]]]: + for doc in docs: + yield [0], [doc] def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[float] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[float]] ) -> Iterable[Doc]: - return docs + return list(shards)[0] nlp = spacy.blank("en") with pytest.warns(UserWarning) as record: @@ -172,13 +187,13 @@ def parse_responses( assert len(record) == 2 assert ( str(record[0].message) - == "Type returned from `task.generate_prompts()` (`typing.Iterable[int]`) doesn't match type " + == "First type in value returned from `task.generate_prompts()` (`typing.Iterable[int]`) doesn't match type " "expected by `model` (`typing.Iterable[str]`)." ) assert ( str(record[1].message) - == "Type returned from `model` (`typing.Iterable[str]`) doesn't match type " - "expected by `task.parse_responses()` (`typing.Iterable[float]`)." + == "Type returned from `model` (`typing.Iterable[typing.Iterable[str]]`) doesn't match type expected by " + "`task.parse_responses()` (`typing.Iterable[typing.Iterable[float]]`)." ) # Run with disabled type consistency validation. @@ -313,7 +328,8 @@ def test_llm_task_factories(): @llm_models = "test.NoOpModel.v1" """ config = Config().from_str(cfg_string) - assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + assemble_from_config(config) def test_llm_task_factories_el(tmp_path): @@ -362,7 +378,8 @@ def test_llm_task_factories_el(tmp_path): }, ) build_el_pipeline(nlp_path=tmp_path, desc_path=tmp_path / "desc.csv") - assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + assemble_from_config(config) @pytest.mark.external @@ -393,3 +410,63 @@ def test_llm_task_factories_ner(): assert len(doc.ents) > 0 for ent in doc.ents: assert ent.label_ in ["PER", "ORG", "LOC"] + + +@pytest.mark.parametrize("shard", [True, False]) +def test_llm_custom_data(noop_config: Dict[str, Any], shard: bool): + """Test whether custom doc data is preserved.""" + nlp = spacy.blank("en") + nlp.add_pipe( + "llm", + config={**noop_config, **{"task": {"@llm_tasks": "spacy.NoOpNoShards.v1"}}} + if not shard + else noop_config, + ) + + doc = nlp.make_doc("This is a test") + if not Doc.has_extension("test"): + Doc.set_extension("test", default=None) + doc._.test = "Test" + doc.user_data["test"] = "Test" + + doc = nlp(doc) + assert doc._.test == "Test" + assert doc.user_data["test"] == "Test" + + +def test_llm_custom_data_overwrite(noop_config: Dict[str, Any]): + """Test whether custom doc data is overwritten as expected.""" + + class NoopTaskWithCustomData(ShardingNoopTask): + def parse_responses( + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] + ) -> Iterable[Doc]: + docs = super().parse_responses(shards, responses) + for doc in docs: + doc._.test = "Test 2" + doc.user_data["test"] = "Test 2" + return docs + + @registry.llm_tasks("spacy.NoOpCustomData.v1") + def make_noopnoshards_task(): + return NoopTaskWithCustomData() + + nlp = spacy.blank("en") + nlp.add_pipe( + "llm", + config={**noop_config, **{"task": {"@llm_tasks": "spacy.NoOpCustomData.v1"}}}, + ) + doc = nlp.make_doc("This is a test") + for extension in ("test", "test_nooverwrite"): + if not Doc.has_extension(extension): + Doc.set_extension(extension, default=None) + doc._.test = "Test" + doc._.test_nooverwrite = "Test" + doc.user_data["test"] = "Test" + doc.user_data["test_nooverwrite"] = "Test" + + doc = nlp(doc) + assert doc._.test == "Test 2" + assert doc.user_data["test"] == "Test 2" + assert doc._.test_nooverwrite == "Test" + assert doc.user_data["test_nooverwrite"] == "Test" diff --git a/spacy_llm/tests/sharding/__init__.py b/spacy_llm/tests/sharding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py new file mode 100644 index 00000000..19fc17c4 --- /dev/null +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -0,0 +1,284 @@ +import numbers +from pathlib import Path + +import pytest +from confection import Config +from spacy.pipeline import EntityLinker +from spacy.tokens import Span + +from spacy_llm.tests.compat import has_openai_key +from spacy_llm.util import assemble_from_config + +from .util import ShardingCountTask # noqa: F401 + +_CONTEXT_LENGTH = 20 +_TEXT = "Do one thing every day that scares you. The only thing we have to fear is fear itself." + + +@pytest.fixture +def config(): + return Config().from_str( + f""" + [nlp] + lang = "en" + pipeline = ["llm"] + + [components] + + [components.llm] + factory = "llm" + save_io = True + + [components.llm.task] + @llm_tasks = "spacy.CountWithSharding.v1" + + [components.llm.model] + @llm_models = "spacy.GPT-3-5.v3" + context_length = {_CONTEXT_LENGTH} + """ + ) + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_count(config): + """Tests whether task shards data as expected.""" + nlp = assemble_from_config(config) + + doc = nlp(_TEXT) + marker = "(and nothing else): '" + prompts = [ + pr[pr.index(marker) + len(marker) : -1] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + responses = [int(r) for r in doc.user_data["llm_io"]["llm"]["response"]] + assert prompts == [ + "Do one thing every day ", + "that scares you", + ". The only ", + "thing we have to ", + "fear is fear itself.", + ] + assert all( + [response == len(pr.split()) for response, pr in zip(responses, prompts)] + ) + assert sum(responses) == doc.user_data["count"] + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_lemma(config): + context_length = 120 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Lemma.v1"} + nlp = assemble_from_config(config) + + doc = nlp(_TEXT) + marker = "to be lemmatized:\n'''\n" + prompts = [ + pr[pr.index(marker) + len(marker) : -4] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + # Make sure lemmas are set (somme might not be because the LLM didn't return parsable a response). + assert any([t.lemma != 0 for t in doc]) + assert prompts == [ + "Do one thing every day that scares you. The ", + "only thing we have to fear is fear itself.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_ner(config): + context_length = 265 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.NER.v3", + "labels": ["LOCATION"], + } + nlp = assemble_from_config(config) + + doc = nlp(_TEXT + " Paris is a city.") + marker = "Paragraph: " + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\nAnswer:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.ents) + assert prompts == [ + "Do one thing every day that scares you. The only thing ", + "we have to fear is fear itself. Paris is a city.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_rel(config): + context_length = 100 + config["nlp"]["pipeline"] = ["ner", "llm"] + config["components"]["ner"] = {"source": "en_core_web_md"} + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.REL.v1", + "labels": "LivesIn,Visits", + } + config["initialize"] = {"vectors": "en_core_web_md"} + nlp = assemble_from_config(config) + + doc = nlp("Joey rents a place in New York City, which is in North America.") + marker = "Text:\n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : -4] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.ents) + assert hasattr(doc._, "rel") and len(doc._.rel) + assert prompts == [ + "Joey[ENT0:PERSON] rents a place in New York City", + "[ENT1:GPE], which is in North America[ENT2:LOC].", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_sentiment(config): + context_length = 50 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Sentiment.v1"} + nlp = assemble_from_config(config) + + doc = nlp(_TEXT) + marker = "Text:\n'''\n" + prompts = [ + pr[pr.index(marker) + len(marker) : pr.rindex("\n'''\nAnswer:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert hasattr(doc._, "sentiment") and isinstance(doc._.sentiment, numbers.Number) + assert prompts == [ + "Do one thing every day that scares you. The ", + "only thing we have to fear is fear itself.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_spancat(config): + context_length = 265 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.SpanCat.v3", + "labels": ["LOCATION"], + } + nlp = assemble_from_config(config) + + doc = nlp(_TEXT + " Paris is a city.") + marker = "Paragraph: " + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\nAnswer:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.spans.data["sc"]) + assert prompts == [ + "Do one thing every day that ", + "scares you. The only thing we have to ", + "fear is fear itself. Paris is a city.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 3 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_summary(config): + context_length = 50 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Summarization.v1"} + nlp = assemble_from_config(config) + + doc = nlp(_TEXT) + marker = "needs to be summarized:\n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\n'''\nSummary:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert hasattr(doc._, "summary") and doc._.summary + assert prompts == [ + "Do one thing every day that scares you. The ", + "only thing we have to fear is fear itself.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_textcat(config): + context_length = 100 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.TextCat.v3", + "labels": "RECIPE", + "exclusive_classes": True, + } + nlp = assemble_from_config(config) + + doc = nlp( + "Fry an egg in a pan. Scramble it. Add some salt, pepper and truffle oil." + ) + marker = "Text:\n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : -4] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.cats) == 1 and "RECIPE" in doc.cats + assert prompts == [ + "Fry an egg in ", + "a pan. Scramble it. Add ", + "some salt, pepper and truffle oil.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 3 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_entity_linker(config): + context_length = 290 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.EntityLinker.v1"} + config["initialize"] = { + "components": { + "llm": { + "candidate_selector": { + "@llm_misc": "spacy.CandidateSelector.v1", + "kb_loader": { + "@llm_misc": "spacy.KBFileLoader.v1", + "path": "${paths.el_kb}", + }, + } + } + } + } + config["paths"] = { + "el_kb": str( + Path(__file__).resolve().parent.parent / "tasks" / "misc" / "el_kb_data.yml" + ) + } + nlp = assemble_from_config(config) + + doc = nlp.make_doc("Alice goes to Boston to see the Boston Celtics game.") + doc.ents = [ + Span(doc=doc, start=3, end=4, label="LOC"), # Q100 + Span(doc=doc, start=7, end=9, label="ORG"), # Q131371 + ] + doc = nlp(doc) + marker = "TEXT: \n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\n'''")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.ents) == 2 + assert all([ent.kb_id_ != EntityLinker.NIL for ent in doc.ents]) + assert prompts == ["Alice goes to *Boston* to ", "see the *Boston Celtics* game."] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 diff --git a/spacy_llm/tests/sharding/util.py b/spacy_llm/tests/sharding/util.py new file mode 100644 index 00000000..87b21f37 --- /dev/null +++ b/spacy_llm/tests/sharding/util.py @@ -0,0 +1,83 @@ +import warnings +from typing import Iterable, List, Optional + +from spacy.tokens import Doc +from spacy.training import Example + +from spacy_llm.compat import Self +from spacy_llm.registry import registry +from spacy_llm.tasks import BuiltinTask +from spacy_llm.tasks.util.sharding import make_shard_mapper +from spacy_llm.ty import FewshotExample, ShardReducer + + +def parse_responses( + task: "ShardingCountTask", + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[int]]: + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[int] = [] + for response, shard in zip(responses_for_doc, shards_for_doc): + results_for_doc.append(int(response)) + + yield results_for_doc + + +def reduce_shards_to_doc(task: "ShardingCountExample", shards: Iterable[Doc]) -> Doc: + shards = list(shards) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping unsupported user data", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) + doc.user_data["count"] = sum([shard.user_data["count"] for shard in shards]) + return doc + + +class ShardingCountExample(FewshotExample): + @classmethod + def generate(cls, example: Example, task: "ShardingCountTask") -> Optional[Self]: + return None + + +@registry.llm_tasks("spacy.CountWithSharding.v1") +class ShardingCountTask(BuiltinTask): + _PROMPT_TEMPLATE = ( + "Reply with the number of words in this string (and nothing else): '{{ text }}'" + ) + + def __init__(self): + assert isinstance(reduce_shards_to_doc, ShardReducer) + super().__init__( + parse_responses=parse_responses, + prompt_example_type=ShardingCountExample, + template=self._PROMPT_TEMPLATE, + prompt_examples=[], + shard_mapper=make_shard_mapper(), + shard_reducer=reduce_shards_to_doc, + ) + + def parse_responses( + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] + ) -> Iterable[Doc]: + shards_teed = self._tee_2d_iterable(shards, 2) + + for shards_for_doc, counts_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, count in zip(shards_for_doc, counts_for_doc): + shard.user_data["count"] = count + + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] + + @property + def prompt_template(self) -> str: + return self._PROMPT_TEMPLATE + + @property + def _cfg_keys(self) -> List[str]: + return [] diff --git a/spacy_llm/tests/tasks/legacy/test_ner.py b/spacy_llm/tests/tasks/legacy/test_ner.py index 6d8ff727..3d9c133a 100644 --- a/spacy_llm/tests/tasks/legacy/test_ner.py +++ b/spacy_llm/tests/tasks/legacy/test_ner.py @@ -15,7 +15,7 @@ from spacy_llm.registry import strip_normalizer from spacy_llm.tasks.ner import NERTask, make_ner_task_v2 from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ...compat import has_openai_key @@ -195,7 +195,7 @@ def test_ner_config(cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = orig_config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) @@ -329,7 +329,7 @@ def test_ner_zero_shot_task(text, response, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list so we get what's inside - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -388,7 +388,7 @@ def test_ner_labels(response, normalizer, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -437,7 +437,7 @@ def test_ner_alignment(response, alignment_mode, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -488,7 +488,7 @@ def test_ner_matching(response, case_sensitive, single_match, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -504,7 +504,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v2(labels=labels, examples=None) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -547,7 +547,7 @@ def test_jinja_template_rendering_with_examples(examples_path): examples = fewshot_reader(examples_path) llm_ner = make_ner_task_v2(labels=labels, examples=examples) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -611,7 +611,7 @@ def test_jinja_template_rendering_with_label_definitions(): "LOC": "Location definition", }, ) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -664,7 +664,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v2(labels=labels, template=template) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -707,7 +707,8 @@ def noop_config(): @pytest.mark.parametrize("n_detections", [0, 1, 2]) def test_ner_scoring(noop_config, n_detections): config = Config().from_str(noop_config) - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -723,7 +724,6 @@ def test_ner_scoring(noop_config, n_detections): examples.append(Example(predicted, reference)) scores = nlp.evaluate(examples) - assert scores["ents_p"] == n_detections / 2 @@ -732,7 +732,8 @@ def test_ner_init(noop_config, n_prompt_examples: int): config = Config().from_str(noop_config) del config["components"]["llm"]["task"]["labels"] - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -776,9 +777,9 @@ def test_ner_init(noop_config, n_prompt_examples: int): def test_ner_serde(noop_config): config = Config().from_str(noop_config) del config["components"]["llm"]["task"]["labels"] - - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) labels = {"loc": "LOC", "per": "PER"} @@ -801,8 +802,9 @@ def test_ner_to_disk(noop_config, tmp_path: Path): config = Config().from_str(noop_config) del config["components"]["llm"]["task"]["labels"] - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) labels = {"loc": "LOC", "per": "PER"} diff --git a/spacy_llm/tests/tasks/legacy/test_spancat.py b/spacy_llm/tests/tasks/legacy/test_spancat.py index 124dd94d..87065d0e 100644 --- a/spacy_llm/tests/tasks/legacy/test_spancat.py +++ b/spacy_llm/tests/tasks/legacy/test_spancat.py @@ -12,7 +12,7 @@ from spacy_llm.registry import fewshot_reader, lowercase_normalizer, strip_normalizer from spacy_llm.tasks.spancat import SpanCatTask, make_spancat_task_v2 from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ...compat import has_openai_key @@ -85,7 +85,7 @@ def test_spancat_config(cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = orig_config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) @@ -201,7 +201,7 @@ def test_spancat_zero_shot_task(text, response, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list so we get what's inside - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -260,7 +260,7 @@ def test_spancat_labels(response, normalizer, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -309,7 +309,7 @@ def test_spancat_alignment(response, alignment_mode, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -360,7 +360,7 @@ def test_spancat_matching(response, case_sensitive, single_match, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -376,7 +376,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_spancat = make_spancat_task_v2(labels=labels, examples=None) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -420,7 +420,7 @@ def test_jinja_template_rendering_with_examples(examples_path): examples = fewshot_reader(examples_path) llm_spancat = make_spancat_task_v2(labels=labels, examples=examples) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -513,7 +513,8 @@ def noop_config(): @pytest.mark.parametrize("n_detections", [0, 1, 2]) def test_spancat_scoring(noop_config, n_detections): config = Config().from_str(noop_config) - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -529,7 +530,6 @@ def test_spancat_scoring(noop_config, n_detections): examples.append(Example(predicted, reference)) scores = nlp.evaluate(examples) - assert scores["spans_sc_p"] == n_detections / 2 @@ -537,7 +537,8 @@ def test_spancat_scoring(noop_config, n_detections): def test_spancat_init(noop_config, n_prompt_examples: bool): config = Config().from_str(noop_config) del config["components"]["llm"]["task"]["labels"] - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -583,8 +584,9 @@ def test_spancat_serde(noop_config): config = Config().from_str(noop_config) del config["components"]["llm"]["task"]["labels"] - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) labels = {"loc": "LOC", "per": "PER"} diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py index 625d9105..6101236b 100644 --- a/spacy_llm/tests/tasks/test_entity_linker.py +++ b/spacy_llm/tests/tasks/test_entity_linker.py @@ -353,7 +353,7 @@ def make_doc() -> Doc: doc = nlp(make_doc()) assert ( f"- For *Foo*:\n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" - in doc.user_data["llm_io"]["llm"]["prompt"] + in doc.user_data["llm_io"]["llm"]["prompt"][0] ) assert doc.ents[0].kb_id_ == EntityLinker.NIL # Sometimes GPT-3.5 doesn't manage to include the NIL prediction, in which case all entities are set to NIL. @@ -427,7 +427,7 @@ def test_jinja_template_rendering_without_examples(tmp_path): ) ) el_task._candidate_selector.initialize(spacy.load(tmp_path).vocab) - prompt = list(el_task.generate_prompts([doc]))[0] + prompt = list(el_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip().replace(" \n", "\n") @@ -497,7 +497,7 @@ def test_jinja_template_rendering_with_examples(examples_path, tmp_path): ) ) el_task._candidate_selector.initialize(spacy.load(tmp_path).vocab) - prompt = list(el_task.generate_prompts([doc]))[0] + prompt = list(el_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip().replace(" \n", "\n") @@ -601,7 +601,7 @@ def test_external_template_actually_loads(tmp_path): el_task._candidate_selector.initialize(spacy.load(tmp_path).vocab) assert ( - list(el_task.generate_prompts([doc]))[0].strip() + list(el_task.generate_prompts([doc]))[0][0][0].strip() == f""" This is a test entity linking template. Here is the text: {text} @@ -620,7 +620,8 @@ def test_el_init(noop_config, n_prompt_examples: int, tmp_path): }, ) build_el_pipeline(nlp_path=tmp_path, desc_path=tmp_path / "desc.csv") - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -678,9 +679,16 @@ def test_ent_highlighting(): ] assert ( - EntityLinkerTask.highlight_ents_in_text(doc) + EntityLinkerTask.highlight_ents_in_doc(doc).text == "Alice goes to *Boston* to see the *Boston Celtics* game." ) + assert ( + EntityLinkerTask.unhighlight_ents_in_doc( + EntityLinkerTask.highlight_ents_in_doc(doc) + ).text + == doc.text + == text + ) @pytest.mark.external diff --git a/spacy_llm/tests/tasks/test_lemma.py b/spacy_llm/tests/tasks/test_lemma.py index ec30e3ce..d82cd087 100644 --- a/spacy_llm/tests/tasks/test_lemma.py +++ b/spacy_llm/tests/tasks/test_lemma.py @@ -141,8 +141,8 @@ def test_lemma_config(cfg_string, request): @pytest.mark.parametrize( "cfg_string", [ - "zeroshot_cfg_string", - "fewshot_cfg_string", + # "zeroshot_cfg_string", + # "fewshot_cfg_string", "ext_template_cfg_string", ], ) @@ -199,7 +199,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc(text) lemma_task = make_lemma_task(examples=None) - prompt = list(lemma_task.generate_prompts([doc]))[0] + prompt = list(lemma_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -242,7 +242,7 @@ def test_jinja_template_rendering_with_examples(examples_path): doc = nlp.make_doc(text) lemma_task = make_lemma_task(examples=fewshot_reader(examples_path)) - prompt = list(lemma_task.generate_prompts([doc]))[0] + prompt = list(lemma_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -334,7 +334,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc(text) lemma_task = make_lemma_task(template=template) - prompt = list(lemma_task.generate_prompts([doc]))[0] + prompt = list(lemma_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == f""" @@ -347,7 +347,8 @@ def test_external_template_actually_loads(): @pytest.mark.parametrize("n_prompt_examples", [-1, 0, 1, 2]) def test_lemma_init(noop_config, n_prompt_examples: int): config = Config().from_str(noop_config) - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] pred_words_1 = ["Alice", "works", "all", "evenings"] diff --git a/spacy_llm/tests/tasks/test_ner.py b/spacy_llm/tests/tasks/test_ner.py index 3c180048..e8782d08 100644 --- a/spacy_llm/tests/tasks/test_ner.py +++ b/spacy_llm/tests/tasks/test_ner.py @@ -20,7 +20,7 @@ from spacy_llm.tasks.span import SpanReason from spacy_llm.tasks.span.parser import _extract_span_reasons_cot from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ..compat import has_openai_key @@ -205,7 +205,7 @@ def test_ner_config(config: Config): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) @@ -395,7 +395,7 @@ def test_ner_labels( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -451,7 +451,7 @@ def test_ner_alignment( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -502,7 +502,7 @@ def test_ner_matching( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -517,7 +517,7 @@ def test_jinja_template_rendering_without_examples(): nlp = spacy.blank("en") doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v3(labels=labels) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -562,7 +562,7 @@ def test_jinja_template_rendering_with_examples(examples_dir: Path, examples_fil doc = nlp.make_doc("Alice and Bob went to the supermarket") examples = fewshot_reader(examples_dir / examples_file) llm_ner = make_ner_task_v3(examples=examples, labels=labels) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -609,7 +609,7 @@ def test_jinja_template_rendering_with_label_definitions( "LOC": "Location definition", }, ) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -690,7 +690,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v3(examples=[], labels=labels, template=template) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert prompt.strip().startswith("Here's the test template for the tests and stuff") @@ -723,7 +723,8 @@ def test_ner_init(noop_config: str, n_prompt_examples: int): config = Config().from_str(noop_config) config["components"]["llm"]["task"]["labels"] = ["PER", "LOC"] config["components"]["llm"]["task"]["examples"] = [] - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] for text in [ @@ -766,8 +767,9 @@ def test_ner_init(noop_config: str, n_prompt_examples: int): def test_ner_serde(noop_config: str): config = Config().from_str(noop_config) - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) labels = {"loc": "LOC", "per": "PER"} @@ -789,9 +791,9 @@ def test_ner_serde(noop_config: str): def test_ner_to_disk(noop_config: str, tmp_path: Path): config = Config().from_str(noop_config) - - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) labels = {"loc": "LOC", "org": "ORG", "per": "PER"} @@ -935,7 +937,7 @@ def test_regression_span_task_response_parse( span_reasons = _extract_span_reasons_cot(ner_task, response) assert len(span_reasons) == len(gold_ents) - docs = list(ner_task.parse_responses([example_doc], [response])) + docs = list(ner_task.parse_responses([[example_doc]], [[response]])) assert len(docs) == 1 doc = docs[0] @@ -964,7 +966,7 @@ def test_regression_span_task_comma( ner_task = make_ner_task_v3(examples=[], labels=["ORG", "LOC"]) span_reasons = _extract_span_reasons_cot(ner_task, response) assert len(span_reasons) == len(gold_ents) - docs = list(ner_task.parse_responses([example_doc], [response])) + docs = list(ner_task.parse_responses([[example_doc]], [[response]])) assert len(docs) == 1 doc = docs[0] pred_ents = [(ent.text, ent.label_) for ent in doc.ents] diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py index 3650114e..258824d4 100644 --- a/spacy_llm/tests/tasks/test_rel.py +++ b/spacy_llm/tests/tasks/test_rel.py @@ -10,7 +10,7 @@ from spacy_llm.pipeline import LLMWrapper from spacy_llm.tasks.rel import DEFAULT_REL_TEMPLATE, RelationItem, RELTask -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ...tasks import make_rel_task @@ -122,7 +122,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) task = pipe.task labels = orig_config["components"]["llm"]["task"]["labels"] @@ -135,7 +135,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"]) +@pytest.mark.parametrize("cfg_string", ["fewshot_cfg_string"]) def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable @@ -157,7 +157,8 @@ def test_rel_init(noop_config, n_prompt_examples: int): config = Config().from_str(noop_config) del config["components"]["llm"]["task"]["labels"] - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -201,9 +202,10 @@ def test_rel_serde(noop_config, tmp_path: Path): config = Config().from_str(noop_config) del config["components"]["llm"]["task"]["labels"] - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) - nlp3 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) + nlp3 = assemble_from_config(config) labels = {"livesin": "LivesIn", "visits": "Visits"} @@ -250,9 +252,9 @@ def test_incorrect_indexing(): len( list( task._parse_responses( - task, [doc], ['{"dep": 0, "dest": 0, "relation": "LivesIn"}'] + task, [[doc]], [['{"dep": 0, "dest": 0, "relation": "LivesIn"}']] ) - )[0] + )[0][0] ) == 1 ) @@ -260,9 +262,9 @@ def test_incorrect_indexing(): len( list( task._parse_responses( - task, [doc], ['{"dep": 0, "dest": 1, "relation": "LivesIn"}'] + task, [[doc]], [['{"dep": 0, "dest": 1, "relation": "LivesIn"}']] ) - )[0] + )[0][0] ) == 0 ) @@ -284,5 +286,5 @@ def test_labels_in_prompt(request: FixtureRequest): assert ( "Well[ENT0:A] hello[ENT1:B] there[ENT2:C]" - in list(nlp.get_pipe("llm")._task.generate_prompts([doc]))[0] + in list(nlp.get_pipe("llm")._task.generate_prompts([doc]))[0][0][0] ) diff --git a/spacy_llm/tests/tasks/test_sentiment.py b/spacy_llm/tests/tasks/test_sentiment.py index f61f875a..aac85966 100644 --- a/spacy_llm/tests/tasks/test_sentiment.py +++ b/spacy_llm/tests/tasks/test_sentiment.py @@ -131,9 +131,9 @@ def test_sentiment_predict(cfg_string, request): orig_config = Config().from_str(cfg) nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) if cfg_string != "ext_template_cfg_string": - assert nlp("This is horrible.")._.sentiment == 0 + assert nlp("This is horrible.")._.sentiment == 0.0 assert 0 < nlp("This is meh.")._.sentiment <= 0.5 - assert nlp("This is perfect.")._.sentiment == 1 + assert nlp("This is perfect.")._.sentiment == 1.0 @pytest.mark.external @@ -146,7 +146,7 @@ def test_sentiment_predict(cfg_string, request): ("zeroshot_cfg_string", "sentiment_x"), ], ) -def test_lemma_io(cfg_string_field, request): +def test_sentiment_io(cfg_string_field, request): cfg_string, field = cfg_string_field cfg = request.getfixturevalue(cfg_string) orig_config = Config().from_str(cfg) @@ -175,7 +175,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc(text) sentiment_task = make_sentiment_task(examples=None) - prompt = list(sentiment_task.generate_prompts([doc]))[0] + prompt = list(sentiment_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -209,7 +209,7 @@ def test_jinja_template_rendering_with_examples(examples_path): doc = nlp.make_doc(text) sentiment_task = make_sentiment_task(examples=fewshot_reader(examples_path)) - prompt = list(sentiment_task.generate_prompts([doc]))[0] + prompt = list(sentiment_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -257,7 +257,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc(text) sentiment_task = make_sentiment_task(template=template) - prompt = list(sentiment_task.generate_prompts([doc]))[0] + prompt = list(sentiment_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == f""" diff --git a/spacy_llm/tests/tasks/test_spancat.py b/spacy_llm/tests/tasks/test_spancat.py index 97e5bd88..b064c9ef 100644 --- a/spacy_llm/tests/tasks/test_spancat.py +++ b/spacy_llm/tests/tasks/test_spancat.py @@ -14,7 +14,7 @@ from spacy_llm.tasks import make_spancat_task_v3 from spacy_llm.tasks.spancat import SpanCatTask from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ..compat import has_openai_key @@ -148,7 +148,7 @@ def test_spancat_config(config: Config): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) @@ -257,7 +257,7 @@ def test_spancat_matching_shot_task(text: str, response: str, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list so we get what's inside - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -330,7 +330,7 @@ def test_spancat_labels( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -382,7 +382,7 @@ def test_spancat_alignment(response, alignment_mode, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -431,7 +431,7 @@ def test_spancat_matching(response, case_sensitive, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -446,7 +446,7 @@ def test_jinja_template_rendering_without_examples(): nlp = spacy.blank("en") doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_spancat = make_spancat_task_v3(labels=labels) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -501,7 +501,7 @@ def test_jinja_template_rendering_with_examples(examples_path: Path): examples = fewshot_reader(examples_path) llm_spancat = make_spancat_task_v3(labels=labels, examples=examples) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -559,7 +559,8 @@ def test_spancat_init(noop_config: str, n_prompt_examples: bool): config = Config().from_str(noop_config) config["components"]["llm"]["task"]["labels"] = ["PER", "LOC", "DESTINATION"] config["components"]["llm"]["task"]["examples"] = [] - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -603,9 +604,9 @@ def test_spancat_init(noop_config: str, n_prompt_examples: bool): def test_spancat_serde(noop_config): config = Config().from_str(noop_config) - - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) labels = {"loc": "LOC", "per": "PER"} diff --git a/spacy_llm/tests/tasks/test_summarization.py b/spacy_llm/tests/tasks/test_summarization.py index d8912b38..35e24118 100644 --- a/spacy_llm/tests/tasks/test_summarization.py +++ b/spacy_llm/tests/tasks/test_summarization.py @@ -8,7 +8,7 @@ from spacy_llm.pipeline import LLMWrapper from spacy_llm.registry import fewshot_reader, file_reader -from spacy_llm.ty import LLMTask +from spacy_llm.ty import ShardingLLMTask from spacy_llm.util import assemble_from_config from ...tasks import make_summarization_task @@ -152,7 +152,7 @@ def test_summarization_config(cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) @pytest.mark.external @@ -253,7 +253,7 @@ def test_jinja_template_rendering_without_examples(example_text): doc = nlp.make_doc(example_text) llm_ner = make_summarization_task(examples=None, max_n_words=10) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -294,7 +294,7 @@ def test_jinja_template_rendering_with_examples(examples_path, example_text): "The provided example 'Life is a quality th...' has a summary of length 28, but `max_n_words` == 20." ), ): - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -343,7 +343,7 @@ def test_external_template_actually_loads(example_text): doc = nlp.make_doc(example_text) llm_ner = make_summarization_task(template=template) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -355,15 +355,17 @@ def test_external_template_actually_loads(example_text): def test_ner_serde(noop_config): config = Config().from_str(noop_config) - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) nlp2.from_bytes(nlp1.to_bytes()) def test_ner_to_disk(noop_config, tmp_path: Path): config = Config().from_str(noop_config) - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) path = tmp_path / "model" nlp1.to_disk(path) diff --git a/spacy_llm/tests/tasks/test_textcat.py b/spacy_llm/tests/tasks/test_textcat.py index b6b0d641..6e7468dd 100644 --- a/spacy_llm/tests/tasks/test_textcat.py +++ b/spacy_llm/tests/tasks/test_textcat.py @@ -13,7 +13,7 @@ from spacy_llm.registry import fewshot_reader, file_reader, lowercase_normalizer from spacy_llm.registry import registry from spacy_llm.tasks.textcat import TextCatTask, make_textcat_task_v3 -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ..compat import has_openai_key @@ -204,7 +204,7 @@ def test_textcat_config(task, cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = split_labels(labels) task = pipe.task @@ -318,7 +318,7 @@ def test_textcat_binary_labels_are_correct(text, response, expected_score): nlp = spacy.blank("en") doc = nlp(text) - pred = list(llm_textcat.parse_responses([doc], [response]))[0] + pred = list(llm_textcat.parse_responses([[doc]], [[response]]))[0] assert list(pred.cats.keys())[0] == label assert list(pred.cats.values())[0] == expected_score @@ -350,7 +350,7 @@ def test_textcat_multilabel_labels_are_correct( ) nlp = spacy.blank("en") doc = nlp.make_doc(text) - pred = list(llm_textcat.parse_responses([doc], [response]))[0] + pred = list(llm_textcat.parse_responses([[doc]], [[response]]))[0] # Take only those that have scores pred_cats = [cat for cat, score in pred.cats.items() if score == 1.0] assert set(pred_cats) == set(expected) @@ -380,7 +380,7 @@ def test_jinja_template_rendering_with_examples_for_binary(examples_path, binary examples=prompt_examples, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -446,7 +446,7 @@ def test_jinja_template_rendering_with_examples_for_multilabel_exclusive( examples=prompt_examples, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -513,7 +513,7 @@ def test_jinja_template_rendering_with_examples_for_multilabel_nonexclusive( examples=prompt_examples, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -592,7 +592,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc("Combine 2 cloves of garlic with soy sauce") llm_textcat = make_textcat_task_v3(labels=labels, template=template) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -636,16 +636,17 @@ def test_external_template_actually_loads(): def test_textcat_scoring(zeroshot_cfg_string, n_insults): @registry.llm_models("Dummy") def factory(): - def b(prompts: Iterable[str]) -> Iterable[str]: + def b(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: for _ in prompts: - yield "POS" + yield ["POS"] return b config = Config().from_str(zeroshot_cfg_string) config["components"]["llm"]["model"] = {"@llm_models": "Dummy"} config["components"]["llm"]["task"]["labels"] = "Insult" - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -659,9 +660,7 @@ def b(prompts: Iterable[str]) -> Iterable[str]: examples.append(Example(predicted, reference)) scores = nlp.evaluate(examples) - pos = n_insults / len(INSULTS) - assert scores["cats_micro_p"] == pos assert not n_insults or scores["cats_micro_r"] == 1 @@ -680,7 +679,7 @@ def test_jinja_template_rendering_with_label_definitions(multilabel_excl): }, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -744,7 +743,8 @@ def test_textcat_init( config = Config().from_str(noop_config) if init_from_config: config["initialize"] = {"components": {"llm": {"labels": ["Test"]}}} - nlp = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble_from_config(config) examples = [] @@ -788,10 +788,10 @@ def test_textcat_init( def test_textcat_serde(noop_config, tmp_path: Path): config = Config().from_str(noop_config) - - nlp1 = assemble_from_config(config) - nlp2 = assemble_from_config(config) - nlp3 = assemble_from_config(config) + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp1 = assemble_from_config(config) + nlp2 = assemble_from_config(config) + nlp3 = assemble_from_config(config) labels = {"insult": "INSULT", "compliment": "COMPLIMENT"} diff --git a/spacy_llm/tests/test_cache.py b/spacy_llm/tests/test_cache.py index 1522c82c..ef41a494 100644 --- a/spacy_llm/tests/test_cache.py +++ b/spacy_llm/tests/test_cache.py @@ -3,7 +3,7 @@ import re import time from pathlib import Path -from typing import Dict, Iterable +from typing import Dict, Iterable, Optional, Tuple import pytest import spacy @@ -211,11 +211,14 @@ def test_prompt_template_handling(): @registry.llm_tasks("NoPromptTemplate.v1") class NoopTask_NoPromptTemplate: - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: - return [""] * len(list(docs)) + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: + for doc in docs: + yield [""], [doc] def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: return docs diff --git a/spacy_llm/tests/test_combinations.py b/spacy_llm/tests/test_combinations.py index d36e72ac..651f6e16 100644 --- a/spacy_llm/tests/test_combinations.py +++ b/spacy_llm/tests/test_combinations.py @@ -20,7 +20,7 @@ ["spacy.NER.v1", "spacy.NER.v3", "spacy.TextCat.v1"], ids=["ner.v1", "ner.v3", "textcat"], ) -@pytest.mark.parametrize("n_process", [1]) # , 2 +@pytest.mark.parametrize("n_process", [1, 2]) def test_combinations(model: str, task: str, n_process: int): """Randomly test combinations of models and tasks.""" ops = get_current_ops() @@ -44,12 +44,20 @@ def test_combinations(model: str, task: str, n_process: int): config["task"]["exclusive_classes"] = True nlp = spacy.blank("en") - nlp.add_pipe("llm", config=config) + if model.startswith("langchain"): + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp.add_pipe("llm", config=config) + else: + nlp.add_pipe("llm", config=config) + name, component = nlp.pipeline[0] assert name == "llm" assert isinstance(component, LLMWrapper) nlp("This is a test.") list( - nlp.pipe(["This is a second test", "This is a third test"], n_process=n_process) + nlp.pipe( + ["This is a second test", "This is a third test"], + n_process=n_process, + ) ) diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 69e73a0c..139dedd0 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -17,10 +17,19 @@ _ResponseType = Any _ParsedResponseType = Any -PromptExecutorType = Callable[[Iterable[_PromptType]], Iterable[_ResponseType]] +PromptExecutorType = Callable[ + [Iterable[Iterable[_PromptType]]], Iterable[_ResponseType] +] ExamplesConfigType = Union[ Iterable[Dict[str, Any]], Callable[[], Iterable[Dict[str, Any]]], None ] +NTokenEstimator = Callable[[str], int] +ShardMapper = Callable[ + # Requires doc, doc index, context length and callable for rendering template from doc shard text. + [Doc, int, int, Callable[[Doc, int, int, int], str]], + # Returns each shard as a doc. + Iterable[Doc], +] @runtime_checkable @@ -84,7 +93,34 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: @runtime_checkable -class LLMTask(Protocol): +class ShardingLLMTask(Protocol): + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[Tuple[Iterable[_PromptType], Iterable[Doc]]]: + """Generate prompts from docs. + docs (Iterable[Doc]): Docs to generate prompts from. + context_length (int): Context length for model this task is executed with. Needed for sharding and fusing docs, + if the corresponding prompts exceed the context length. If None, context length is assumed to be infinite. + RETURNS (Iterable[Tuple[Iterable[_PromptType], Iterable[Doc]]]): Iterable with one to n prompts per doc + (multiple prompts in case of multiple shards) and the corresponding shards. The relationship between shard + and prompt is 1:1. + """ + + def parse_responses( + self, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[_ResponseType]], + ) -> Iterable[Doc]: + """ + Parses LLM responses. + docs (Iterable[Iterable[Doc]]): Doc shards to map responses into. + responses ([Iterable[Iterable[_ResponseType]]]): LLM responses. + RETURNS (Iterable[Doc]]): Updated (and fused) docs. + """ + + +@runtime_checkable +class NonshardingLLMTask(Protocol): def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[_PromptType]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. @@ -102,7 +138,23 @@ def parse_responses( """ -TaskContraT = TypeVar("TaskContraT", bound=LLMTask, contravariant=True) +LLMTask = Union[NonshardingLLMTask, ShardingLLMTask] + +TaskContraT = TypeVar( + "TaskContraT", bound=Union[ShardingLLMTask, NonshardingLLMTask], contravariant=True +) +ShardingTaskContraT = TypeVar( + "ShardingTaskContraT", bound=ShardingLLMTask, contravariant=True +) + + +@runtime_checkable +class ShardReducer(Protocol[ShardingTaskContraT]): + """Generic protocol for tasks' shard reducer.""" + + def __call__(self, task: ShardingTaskContraT, shards: Iterable[Doc]) -> Doc: + """Merges shard to single Doc.""" + ... class FewshotExample(GenericModel, abc.ABC, Generic[TaskContraT]): @@ -120,8 +172,11 @@ class TaskResponseParser(Protocol[TaskContraT]): """Generic protocol for parsing functions with specific tasks.""" def __call__( - self, task: TaskContraT, docs: Iterable[Doc], responses: Iterable[Any] - ) -> Iterable[Any]: + self, + task: TaskContraT, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[Any]], + ) -> Iterable[Iterable[Any]]: ... @@ -149,11 +204,13 @@ def clear(self) -> None: class Cache(Protocol): """Defines minimal set of operations a cache implementiation needs to support.""" - def initialize(self, vocab: Vocab, task: LLMTask) -> None: + def initialize( + self, vocab: Vocab, task: Union[NonshardingLLMTask, ShardingLLMTask] + ) -> None: """ Initialize cache with data not available at construction time. vocab (Vocab): Vocab object. - task (LLMTask): Task. + task (Union[LLMTask, ShardingLLMTask]): Task. """ def add(self, doc: Doc) -> None: @@ -186,25 +243,46 @@ def __getitem__(self, doc: Doc) -> Optional[Doc]: """ -def _do_args_match(out_arg: Iterable, in_arg: Iterable) -> bool: +@runtime_checkable +class ModelWithContextLength(Protocol): + @property + def context_length(self) -> int: + """Provides context length for the corresponding model. + RETURNS (int): Context length for the corresponding model. + """ + + +def _do_args_match(out_arg: Iterable, in_arg: Iterable, nesting_level: int) -> bool: """Compares argument type of Iterables for compatibility. in_arg (Iterable): Input argument. out_arg (Iterable): Output argument. + nesting_level (int): Expected level of nesting in types. E. g. Iterable[Iterable[Any]] has a level of 2, + Iterable[Any] of 1. Note that this is assumed for all sub-types in out_arg and in_arg, as this is sufficient for + the current use case of checking the compatibility of task-to-model and model-to-parser communication flow. RETURNS (bool): True if type variables are of the same length and if type variables in out_arg are a subclass of (or the same class as) the type variables in in_arg. """ assert hasattr(out_arg, "__args__") and hasattr(in_arg, "__args__") + + out_types, in_types = out_arg, in_arg + for level in range(nesting_level): + out_types = ( + out_types.__args__[0] if level < (nesting_level - 1) else out_types.__args__ # type: ignore[attr-defined] + ) + in_types = ( + in_types.__args__[0] if level < (nesting_level - 1) else in_types.__args__ # type: ignore[attr-defined] + ) # Replace Any with object to make issubclass() check work. - out_type_vars = [arg if arg != Any else object for arg in out_arg.__args__] - in_type_vars = [arg if arg != Any else object for arg in in_arg.__args__] + out_types = [arg if arg != Any else object for arg in out_types] + in_types = [arg if arg != Any else object for arg in in_types] - if len(out_type_vars) != len(in_type_vars): + if len(out_types) != len(in_types): return False return all( [ issubclass(out_tv, in_tv) or issubclass(in_tv, out_tv) - for out_tv, in_tv in zip(out_type_vars, in_type_vars) + for out_tv, in_tv in zip(out_types, in_types) ] ) @@ -245,24 +323,46 @@ def _extract_model_call_signature(model: PromptExecutorType) -> Dict[str, Any]: return signature -def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: +def supports_sharding(task: Union[NonshardingLLMTask, ShardingLLMTask]) -> bool: + """Determines task type, as isinstance(instance, Protocol) only checks for method names. This also considers + argument and return types. Raises an exception if task is neither. + Note that this is not as thorough as validate_type_consistency() and relies on clues to determine which task type + a given, type-validated task type is. This doesn't guarantee that a task has valid typing. This method should only + be in conjunction with validate_type_consistency(). + task (Union[LLMTask, ShardingLLMTask]): Task to check. + RETURNS (bool): True if task supports sharding, False if not. + """ + prompt_ret_type = typing.get_type_hints(task.generate_prompts)["return"].__args__[0] + return ( + hasattr(prompt_ret_type, "_name") + and prompt_ret_type._name == "Tuple" + and len(prompt_ret_type.__args__) == 2 + ) + + +def validate_type_consistency( + task: Union[NonshardingLLMTask, ShardingLLMTask], model: PromptExecutorType +) -> None: """Check whether the types of the task and model signatures match. - task (LLMTask): Specified task. + task (ShardingLLMTask): Specified task. model (PromptExecutor): Specified model. """ # Raises an error or prints a warning if something looks wrong/odd. - if not isinstance(task, LLMTask): + # todo update error messages + if not isinstance(task, NonshardingLLMTask): raise ValueError( - f"A task needs to be of type 'LLMTask' but found {type(task)} instead" + f"A task needs to adhere to the interface of either 'LLMTask' or 'ShardingLLMTask', but {type(task)} " + f"doesn't." ) if not hasattr(task, "generate_prompts"): raise ValueError( - "A task needs to have the following method: generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]" + "A task needs to have the following method: generate_prompts(self, docs: Iterable[Doc]) -> " + "Iterable[Tuple[Iterable[Any], Iterable[Doc]]]" ) if not hasattr(task, "parse_responses"): raise ValueError( "A task needs to have the following method: " - "parse_responses(self, docs: Iterable[Doc], responses: Iterable[Any]) -> Iterable[Doc]" + "parse_responses(self, docs: Iterable[Doc], responses: Iterable[Iterable[Any]]) -> Iterable[Doc]" ) type_hints = { @@ -271,9 +371,9 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: "model": _extract_model_call_signature(model), } - parse_input: Optional[Type] = None - model_input: Optional[Type] = None - model_output: Optional[Type] = None + parse_in: Optional[Type] = None + model_in: Optional[Type] = None + model_out: Optional[Type] = None # Validate the 'model' object if not (len(type_hints["model"]) == 2 and "return" in type_hints["model"]): @@ -282,9 +382,9 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: ) for k in type_hints["model"]: if k == "return": - model_output = type_hints["model"][k] + model_out = type_hints["model"][k] else: - model_input = type_hints["model"][k] + model_in = type_hints["model"][k] # validate the 'parse' object if not (len(type_hints["parse"]) == 3 and "return" in type_hints["parse"]): @@ -295,36 +395,53 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: # find the 'prompt_responses' var without assuming its name type_k = type_hints["parse"][k] if type_k != typing.Iterable[Doc]: - parse_input = type_hints["parse"][k] + parse_in = type_hints["parse"][k] - template_output = type_hints["template"]["return"] + template_out = type_hints["template"]["return"] # Check that all variables are Iterables. for var, msg in ( - (template_output, "`task.generate_prompts()` needs to return an `Iterable`."), + (template_out, "`task.generate_prompts()` needs to return an `Iterable`."), ( - model_input, + model_in, "The prompts variable in the 'model' needs to be an `Iterable`.", ), - (model_output, "The `model` function needs to return an `Iterable`."), + (model_out, "The `model` function needs to return an `Iterable`."), ( - parse_input, + parse_in, "`responses` in `task.parse_responses()` needs to be an `Iterable`.", ), ): - if not var != Iterable: + if not (hasattr(var, "_name") and var._name == "Iterable"): raise ValueError(msg) + # Ensure that template/prompt generator output is Iterable of 2-Tuple, the second of which fits doc shards type. + template_out_type = template_out.__args__[0] + if ( + hasattr(template_out_type, "_name") + and template_out_type._name == "Tuple" + and len(template_out_type.__args__) == 2 + ): + has_shards = True + template_out_type = template_out_type.__args__[0] + else: + has_shards = False + # Ensure that the template returns the same type as expected by the model - if not _do_args_match(template_output, model_input): # type: ignore[arg-type] + assert model_in is not None + if not _do_args_match( + template_out_type if has_shards else typing.Iterable[template_out_type], # type: ignore[valid-type] + model_in.__args__[0], + 1, + ): # type: ignore[arg-type] warnings.warn( - f"Type returned from `task.generate_prompts()` (`{template_output}`) doesn't match type expected by " - f"`model` (`{model_input}`)." + f"First type in value returned from `task.generate_prompts()` (`{template_out_type}`) doesn't match type " + f"expected by `model` (`{model_in.__args__[0]}`)." ) # Ensure that the parser expects the same type as returned by the model - if not _do_args_match(model_output, parse_input): # type: ignore[arg-type] + if not _do_args_match(model_out, parse_in if has_shards else typing.Iterable[parse_in], 2): # type: ignore[arg-type,valid-type] warnings.warn( - f"Type returned from `model` (`{model_output}`) doesn't match type expected by " - f"`task.parse_responses()` (`{parse_input}`)." + f"Type returned from `model` (`{model_out}`) doesn't match type expected by " + f"`task.parse_responses()` (`{parse_in}`)." ) diff --git a/usage_examples/tests/test_readme_examples.py b/usage_examples/tests/test_readme_examples.py index ea386750..3246b797 100644 --- a/usage_examples/tests/test_readme_examples.py +++ b/usage_examples/tests/test_readme_examples.py @@ -165,12 +165,14 @@ def test_example_5_custom_model(): import random @registry.llm_models("RandomClassification.v1") - def random_textcat(labels: str) -> Callable[[Iterable[str]], Iterable[str]]: + def random_textcat( + labels: str, + ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: labels = labels.split(",") - def _classify(prompts: Iterable[str]) -> Iterable[str]: - for _ in prompts: - yield random.choice(labels) + def _classify(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: + for prompts_for_doc in prompts: + yield [random.choice(labels) for _ in prompts_for_doc] return _classify @@ -197,5 +199,6 @@ def _classify(prompts: Iterable[str]) -> Iterable[str]: with open(tmpdir / "cfg", "w") as text_file: text_file.write(cfg_str) - nlp = assemble(tmpdir / "cfg") + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp = assemble(tmpdir / "cfg") nlp("i'd like a large margherita pizza please") diff --git a/usage_examples/tests/test_usage_examples.py b/usage_examples/tests/test_usage_examples.py index 1b3fa170..9a04e2bc 100644 --- a/usage_examples/tests/test_usage_examples.py +++ b/usage_examples/tests/test_usage_examples.py @@ -118,9 +118,10 @@ def test_ner_v3_openai(): @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_ner_langchain_openai(): """Test NER LangChain OpenAI usage example.""" - ner_langchain_openai.run_pipeline( - "text", _USAGE_EXAMPLE_PATH / "ner_langchain_openai" / "ner.cfg", False - ) + with pytest.warns(UserWarning, match="Task supports sharding"): + ner_langchain_openai.run_pipeline( + "text", _USAGE_EXAMPLE_PATH / "ner_langchain_openai" / "ner.cfg", False + ) @pytest.mark.external