Fixed batch merging and filtering to handle Outlines state (#263)
tgaddair authored Feb 20, 2024
1 parent 9432ec0 commit adbc1ae
Showing 4 changed files with 57 additions and 37 deletions.
integration-tests/scripts/dynamic_adapter_loading.py (14 changes: 7 additions & 7 deletions)
@@ -52,7 +52,7 @@ def query_lorax(args):
         "details": True,
     }
     if adapter_id is not None:
-        request_params["adapter_source"] = "local"
+        # request_params["adapter_source"] = "local"
         request_params["adapter_id"] = adapter_id

     print("request_params", request_params)
@@ -113,14 +113,14 @@ def main():
     # ]

     # Mistral
-    # prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
-    # adapters = [
-    #     "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
-    # ]
+    prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"
+    adapters = [
+        "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
+    ]

     # GPT2
-    prompt = "Brand Name : First Aid Beauty ; Product Name : Ultra Repair Cream Intense Hydration ; Review Title :"
-    adapters = ["/data/adapters/9789adb7-cd03-4862-91d5-b41b6746682e_ludwig/model_weights"]
+    # prompt = "Brand Name : First Aid Beauty ; Product Name : Ultra Repair Cream Intense Hydration ; Review Title :"
+    # adapters = ["/data/adapters/9789adb7-cd03-4862-91d5-b41b6746682e_ludwig/model_weights"]

     adapters += [None]
     # adapters = [None]
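After this change the script exercises a Hub-hosted adapter, so it no longer pins `adapter_source` to "local". As a rough guide, the request it builds looks like the following. This is a minimal sketch assuming a LoRAX server on localhost:8080 with the standard /generate route; the URL and `max_new_tokens` value are illustrative, not taken from the script.

import requests

prompt = "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]"

request_params = {
    "max_new_tokens": 64,  # illustrative; the script sets its own value
    "details": True,
    # "adapter_source": "local",  # omitted now: the source is inferred from adapter_id
    "adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
}

# Assumption: default LoRAX endpoint; adjust host/port for your deployment.
response = requests.post(
    "http://localhost:8080/generate",
    json={"inputs": prompt, "parameters": request_params},
)
print(response.json())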
server/lorax_server/models/flash_causal_lm.py (9 changes: 7 additions & 2 deletions)
@@ -544,6 +544,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
         read_offsets = []

         next_token_chooser_parameters = []
+        sequence_processors = []
         stopping_criterias = []

         # Cumulative length
@@ -601,6 +602,11 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
             read_offsets.extend(batch.read_offsets)

             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+            if batch.next_token_chooser.schema_processor is not None:
+                sequence_processors.extend(batch.next_token_chooser.schema_processor.sequence_processors)
+            else:
+                # No sequence processors, so pad with Nones
+                sequence_processors.extend([None for _ in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)

         # Update
@@ -614,9 +620,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
             tokenizers=[],
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            sequence_processors=sequence_processors,
         )
-        next_token_chooser.schema_processor = HeterogeneousSchemaLogitsProcessor.concatenate(
-            [b.next_token_chooser.schema_processor for b in batches])

         adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
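Net effect of the three hunks above: instead of rebuilding the schema processor after the merge via the old `concatenate` classmethod, the loop now collects each batch's live per-sequence Outlines processors (padding with `None` where a batch enforces no schema) and passes them straight into the new chooser. Below is a standalone sketch of that collection step; the batch objects are stand-ins for `FlashCausalLMBatch`, not repository code.

from typing import Any, List, Optional

def collect_sequence_processors(batches: List[Any]) -> List[Optional[Any]]:
    """Mirror of the loop added to FlashCausalLMBatch.concatenate (sketch).

    Each batch contributes one entry per request, so the merged list stays
    index-aligned with the merged requests; live Outlines processors keep
    their FSM state instead of being rebuilt from schemas.
    """
    sequence_processors: List[Optional[Any]] = []
    for batch in batches:
        schema_processor = batch.next_token_chooser.schema_processor
        if schema_processor is not None:
            sequence_processors.extend(schema_processor.sequence_processors)
        else:
            sequence_processors.extend([None for _ in batch.requests])
    return sequence_processors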
server/lorax_server/utils/logits_process.py (49 changes: 26 additions & 23 deletions)
@@ -419,31 +419,20 @@ def filter(self, indices):


 class HeterogeneousSchemaLogitsProcessor(LogitsProcessor):
-    r"""
+    """
     [`LogitsWarper`] for JSON schema enforcement.
     This version uses Outlines to perform the constrained decoding.

     Args:
-        schemas (`List[Optional[str]]`):
-            The JSON encoded schemas to enforce. `None` means no enforcement.
-        tokenizers (`List[Optional[PreTrainedTokenizerBase]]`):
-            The tokenizers to use for each request.
+        sequence_processors (`List[Optional[OutlinesLogitsProcessor]]`):
+            The Outlines processors to use for each request.
     """

     def __init__(
         self,
-        schemas: Optional[List[Optional[str]]] = None,
-        tokenizers: Optional[List[Optional[PreTrainedTokenizerBase]]] = None,
+        sequence_processors: List[Optional["OutlinesLogitsProcessor"]],
     ):
-        if schemas is None:
-            schemas = []
-        if tokenizers is None:
-            tokenizers = []
-
-        self.sequence_processors = [
-            OutlinesLogitsProcessor(schema, tokenizer) if schema and tokenizer else None
-            for schema, tokenizer in zip(schemas, tokenizers)
-        ]
+        self.sequence_processors = sequence_processors

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
         for i, processor in enumerate(self.sequence_processors):
@@ -456,16 +445,30 @@ def filter(self, indices):
         if any([x is not None for x in self.sequence_processors]):
             return self
         return None

     @classmethod
-    def concatenate(
+    def from_schemas(
         cls,
-        processors: List["HeterogeneousSchemaLogitsProcessor"]
+        schemas: List[Optional[str]],
+        tokenizers: List[Optional[PreTrainedTokenizerBase]],
     ) -> "HeterogeneousSchemaLogitsProcessor":
-        ret = HeterogeneousSchemaLogitsProcessor()
-        for p in processors:
-            ret.sequence_processors.extend(p.sequence_processors)
-        return ret
+        """
+        Args:
+            schemas (`List[Optional[str]]`):
+                The JSON encoded schemas to enforce. `None` means no enforcement.
+            tokenizers (`List[Optional[PreTrainedTokenizerBase]]`):
+                The tokenizers to use for each request.
+        """
+        if schemas is None:
+            schemas = []
+        if tokenizers is None:
+            tokenizers = []
+
+        sequence_processors = [
+            OutlinesLogitsProcessor(schema, tokenizer) if schema and tokenizer else None
+            for schema, tokenizer in zip(schemas, tokenizers)
+        ]
+        return cls(sequence_processors)


 # Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
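The split here is: `from_schemas` is the fresh-construction path (compile an `OutlinesLogitsProcessor` per schema/tokenizer pair), while the bare constructor now rewraps processors that already exist so their decoding state survives filter and merge. A usage sketch under stated assumptions follows; the gpt2 tokenizer and toy schema are illustrative only.

import json

from transformers import AutoTokenizer

from lorax_server.utils.logits_process import HeterogeneousSchemaLogitsProcessor

# Illustrative inputs; any schema/tokenizer pair would do.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
schema = json.dumps({"type": "object", "properties": {"answer": {"type": "string"}}})

# Fresh requests: compile processors from schemas (None disables enforcement).
left = HeterogeneousSchemaLogitsProcessor.from_schemas(
    schemas=[schema, None], tokenizers=[tokenizer, tokenizer]
)
right = HeterogeneousSchemaLogitsProcessor.from_schemas(
    schemas=[schema], tokenizers=[tokenizer]
)

# Merging batches: rewrap the live processors so each Outlines FSM resumes
# from its current position instead of being recompiled from the schema.
merged = HeterogeneousSchemaLogitsProcessor(
    left.sequence_processors + right.sequence_processors
)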
server/lorax_server/utils/tokens.py (22 changes: 17 additions & 5 deletions)
@@ -251,6 +251,7 @@ def __init__(
         do_sample: List[bool],
         seeds: List[int],
         tokenizers: List[PreTrainedTokenizerBase],
+        sequence_processors: Optional[List[OutlinesLogitsProcessor]] = None,
     ):
         warpers = []

@@ -274,11 +275,19 @@ def __init__(
             else None
         )

-        self.schema_processor = (
-            HeterogeneousSchemaLogitsProcessor(schemas, tokenizers)
-            if any(schemas)
-            else None
-        )
+        if sequence_processors is not None:
+            # Reuse the state from the previous generation steps
+            self.schema_processor = (
+                HeterogeneousSchemaLogitsProcessor(sequence_processors)
+                if any(sequence_processors)
+                else None
+            )
+        else:
+            self.schema_processor = (
+                HeterogeneousSchemaLogitsProcessor.from_schemas(schemas, tokenizers)
+                if any(schemas)
+                else None
+            )

         if any([x != 1.0 for x in temperature]):
             do_sample = [
@@ -384,6 +393,7 @@ def from_pb(
         tokenizers: List[PreTrainedTokenizerBase],
         dtype: torch.dtype,
         device: torch.device,
+        sequence_processors: Optional[List[OutlinesLogitsProcessor]] = None,
     ) -> "HeterogeneousNextTokenChooser":
         """
         Creates a `HeterogeneousNextTokenChooser` instance from the given protocol buffer.
@@ -393,6 +403,7 @@
             tokenizers (List[PreTrainedTokenizerBase]): The tokenizers to use for processing the tokens.
             dtype (torch.dtype): The data type of the tokens.
             device (torch.device): The device on which the tokens are processed.
+            sequence_processors (Optional[List[OutlinesLogitsProcessor]]): The sequence processors to use for processing the tokens.

         Returns:
             HeterogeneousNextTokenChooser: The created `HeterogeneousNextTokenChooser` instance.
@@ -408,6 +419,7 @@
             do_sample=[pb_.do_sample for pb_ in pb],
             seeds=[pb_.seed for pb_ in pb],
             tokenizers=tokenizers,
+            sequence_processors=sequence_processors,
             device=device,
             dtype=dtype,
         )
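Putting the pieces together, the new keyword lets the server hand surviving processors back into a rebuilt chooser after a filter or merge rather than re-deriving them from schemas, which would reset each FSM to its start state. Below is a lifecycle sketch, not the repository's actual filter code; the helper name and `keep` indices are assumptions.

from typing import Any, List

import torch
from transformers import PreTrainedTokenizerBase

from lorax_server.utils.tokens import HeterogeneousNextTokenChooser

def rebuild_for_survivors(
    chooser: HeterogeneousNextTokenChooser,
    pb: List[Any],                          # per-request parameter protos
    tokenizers: List[PreTrainedTokenizerBase],
    keep: List[int],                        # indices of surviving requests
    dtype: torch.dtype,
    device: torch.device,
) -> HeterogeneousNextTokenChooser:
    """Sketch of the filter path enabled by this commit.

    Surviving requests hand their live Outlines processors back to from_pb,
    so constrained decoding resumes mid-schema; without the argument, the
    chooser would rebuild processors from schemas and lose that state.
    """
    if chooser.schema_processor is not None:
        surviving = [chooser.schema_processor.sequence_processors[i] for i in keep]
    else:
        surviving = [None for _ in keep]
    return HeterogeneousNextTokenChooser.from_pb(
        pb=[pb[i] for i in keep],
        tokenizers=[tokenizers[i] for i in keep],
        dtype=dtype,
        device=device,
        sequence_processors=surviving,
    )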
