diff --git a/integration-tests/scripts/dynamic_adapter_loading.py b/integration-tests/scripts/dynamic_adapter_loading.py
index 6ec3f43bc..b2502953a 100644
--- a/integration-tests/scripts/dynamic_adapter_loading.py
+++ b/integration-tests/scripts/dynamic_adapter_loading.py
@@ -52,7 +52,7 @@ def query_lorax(args):
         "details": True,
     }
     if adapter_id is not None:
-        request_params["adapter_source"] = "local"
+        # request_params["adapter_source"] = "local"
         request_params["adapter_id"] = adapter_id
 
     print("request_params", request_params)
@@ -113,14 +113,14 @@ def main():
     #     ]
 
     # Mistral
-    # prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
-    # adapters = [
-    #     "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
-    # ]
+    prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
+    adapters = [
+        "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
+    ]
 
     # GPT2
-    prompt = "Brand Name : First Aid Beauty ; Product Name : Ultra Repair Cream Intense Hydration ; Review Title :"
-    adapters = ["/data/adapters/9789adb7-cd03-4862-91d5-b41b6746682e_ludwig/model_weights"]
+    # prompt = "Brand Name : First Aid Beauty ; Product Name : Ultra Repair Cream Intense Hydration ; Review Title :"
+    # adapters = ["/data/adapters/9789adb7-cd03-4862-91d5-b41b6746682e_ludwig/model_weights"]
 
     adapters += [None]
     # adapters = [None]
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index a93c2db3b..432785c65 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -544,6 +544,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         read_offsets = []
 
         next_token_chooser_parameters = []
+        sequence_processors = []
         stopping_criterias = []
 
         # Cumulative length
@@ -601,6 +602,11 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             read_offsets.extend(batch.read_offsets)
 
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+            if batch.next_token_chooser.schema_processor is not None:
+                sequence_processors.extend(batch.next_token_chooser.schema_processor.sequence_processors)
+            else:
+                # No sequence processors, so pad with Nones
+                sequence_processors.extend([None for _ in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
 
         # Update
@@ -614,9 +620,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             tokenizers=[],
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            sequence_processors=sequence_processors,
         )
-        next_token_chooser.schema_processor = HeterogeneousSchemaLogitsProcessor.concatenate(
-            [b.next_token_chooser.schema_processor for b in batches])
 
         adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
 
diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py
index 24a6ee0cd..c5f58a274 100644
--- a/server/lorax_server/utils/logits_process.py
+++ b/server/lorax_server/utils/logits_process.py
@@ -419,31 +419,20 @@ def filter(self, indices):
 
 
 class HeterogeneousSchemaLogitsProcessor(LogitsProcessor):
-    r"""
+    """
     [`LogitsWarper`] for JSON schema enforcement.
     This version uses Outlines to perform the constrained decoding.
 
     Args:
-        schemas (`List[Optional[str]]`):
-            The JSON encoded schemas to enforce. `None` means no enforcement.
-        tokenizers (`List[Optional[PreTrainedTokenizerBase]]`):
-            The tokenizers to use for each request.
+        sequence_processors (`List[Optional[OutlinesLogitsProcessor]]`):
+            The Outlines processors to use for each request.
     """
 
     def __init__(
         self,
-        schemas: Optional[List[Optional[str]]] = None,
-        tokenizers: Optional[List[Optional[PreTrainedTokenizerBase]]] = None,
+        sequence_processors: List[Optional["OutlinesLogitsProcessor"]],
     ):
-        if schemas is None:
-            schemas = []
-        if tokenizers is None:
-            tokenizers = []
-
-        self.sequence_processors = [
-            OutlinesLogitsProcessor(schema, tokenizer) if schema and tokenizer else None
-            for schema, tokenizer in zip(schemas, tokenizers)
-        ]
+        self.sequence_processors = sequence_processors
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         for i, processor in enumerate(self.sequence_processors):
@@ -456,16 +445,30 @@ def filter(self, indices):
         if any([x is not None for x in self.sequence_processors]):
             return self
         return None
-    
+
     @classmethod
-    def concatenate(
+    def from_schemas(
         cls,
-        processors: List["HeterogeneousSchemaLogitsProcessor"]
+        schemas: List[Optional[str]],
+        tokenizers: List[Optional[PreTrainedTokenizerBase]],
     ) -> "HeterogeneousSchemaLogitsProcessor":
-        ret = HeterogeneousSchemaLogitsProcessor()
-        for p in processors:
-            ret.sequence_processors.extend(p.sequence_processors)
-        return ret
+        """
+        Args:
+            schemas (`List[Optional[str]]`):
+                The JSON encoded schemas to enforce. `None` means no enforcement.
+            tokenizers (`List[Optional[PreTrainedTokenizerBase]]`):
+                The tokenizers to use for each request.
+        """
+        if schemas is None:
+            schemas = []
+        if tokenizers is None:
+            tokenizers = []
+
+        sequence_processors = [
+            OutlinesLogitsProcessor(schema, tokenizer) if schema and tokenizer else None
+            for schema, tokenizer in zip(schemas, tokenizers)
+        ]
+        return cls(sequence_processors)
 
 
 # Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py
index 7a8f53487..6d37bd3d8 100644
--- a/server/lorax_server/utils/tokens.py
+++ b/server/lorax_server/utils/tokens.py
@@ -251,6 +251,7 @@ def __init__(
         do_sample: List[bool],
         seeds: List[int],
         tokenizers: List[PreTrainedTokenizerBase],
+        sequence_processors: Optional[List[OutlinesLogitsProcessor]] = None,
     ):
         warpers = []
 
@@ -274,11 +275,19 @@ def __init__(
             else None
         )
 
-        self.schema_processor = (
-            HeterogeneousSchemaLogitsProcessor(schemas, tokenizers)
-            if any(schemas)
-            else None
-        )
+        if sequence_processors is not None:
+            # Reuse the state from the previous generation steps
+            self.schema_processor = (
+                HeterogeneousSchemaLogitsProcessor(sequence_processors)
+                if any(sequence_processors)
+                else None
+            )
+        else:
+            self.schema_processor = (
+                HeterogeneousSchemaLogitsProcessor.from_schemas(schemas, tokenizers)
+                if any(schemas)
+                else None
+            )
 
         if any([x != 1.0 for x in temperature]):
             do_sample = [
@@ -384,6 +393,7 @@ def from_pb(
         tokenizers: List[PreTrainedTokenizerBase],
         dtype: torch.dtype,
         device: torch.device,
+        sequence_processors: Optional[List[OutlinesLogitsProcessor]] = None,
     ) -> "HeterogeneousNextTokenChooser":
         """
         Creates a `HeterogeneousNextTokenChooser` instance from the given protocol buffer.
@@ -393,6 +403,7 @@ def from_pb(
             tokenizers (List[PreTrainedTokenizerBase]): The tokenizers to use for processing the tokens.
             dtype (torch.dtype): The data type of the tokens.
             device (torch.device): The device on which the tokens are processed.
+            sequence_processors (Optional[List[OutlinesLogitsProcessor]]): The sequence processors to use for processing the tokens.
 
         Returns:
             HeterogeneousNextTokenChooser: The created `HeterogeneousNextTokenChooser` instance.
@@ -408,6 +419,7 @@ def from_pb(
             do_sample=[pb_.do_sample for pb_ in pb],
             seeds=[pb_.seed for pb_ in pb],
             tokenizers=tokenizers,
+            sequence_processors=sequence_processors,
             device=device,
             dtype=dtype,
         )
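
For context, here is a minimal standalone sketch of the pattern this patch adopts. The Stub* classes below are hypothetical stand-ins, not the real lorax types, and are only meant to show why batch concatenation now passes the live `sequence_processors` through instead of rebuilding processors from schemas: a rebuild would discard each in-flight sequence's constrained-decoding (FSM) state.

from typing import List, Optional


class StubSequenceProcessor:
    """Stand-in for OutlinesLogitsProcessor: holds per-sequence decoding state."""

    def __init__(self, schema: str):
        self.schema = schema
        self.fsm_state = 0  # advances as constrained tokens are generated


class StubHeterogeneousProcessor:
    """Stand-in for HeterogeneousSchemaLogitsProcessor after this change."""

    def __init__(self, sequence_processors: List[Optional[StubSequenceProcessor]]):
        # The constructor now accepts pre-built (possibly in-flight) processors.
        self.sequence_processors = sequence_processors

    @classmethod
    def from_schemas(cls, schemas: List[Optional[str]]) -> "StubHeterogeneousProcessor":
        # Prefill path: build a fresh processor per schema-bearing request.
        return cls([StubSequenceProcessor(s) if s else None for s in schemas])


# Prefill: two batches built from schemas (the from_schemas path).
batch_a = StubHeterogeneousProcessor.from_schemas(['{"type": "object"}', None])
batch_b = StubHeterogeneousProcessor.from_schemas([None, '{"type": "array"}'])

# Simulate a few decode steps advancing the first sequence's FSM.
batch_a.sequence_processors[0].fsm_state = 7

# Concatenate: thread the live processor objects through, as the updated
# FlashCausalLMBatch.concatenate does, so mid-generation state survives
# the merge instead of being reset by a rebuild from schemas.
merged = StubHeterogeneousProcessor(
    batch_a.sequence_processors + batch_b.sequence_processors
)
assert merged.sequence_processors[0].fsm_state == 7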