From 93dbd4a9ec93e63581b6e1ccfc06c86b771ffa60 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 17 Dec 2024 15:50:10 -0500
Subject: [PATCH] fix and add rtest

---
 integration-tests/test_classifications.py | 28 ++++++++++++++++++++++-
 server/lorax_server/models/flash_bert.py  |  2 ++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/integration-tests/test_classifications.py b/integration-tests/test_classifications.py
index c597c0d0..4152be6b 100644
--- a/integration-tests/test_classifications.py
+++ b/integration-tests/test_classifications.py
@@ -16,7 +16,33 @@ def test_distilbert_ner():
     with run_lorax_container(config):
         response = requests.post(
             "http://localhost:8080/classify",
-            json={"inputs": "Johnny supports the Golden State Warriors. He lives in London."},
+            json={
+                "inputs": "Johnny supports the Golden State Warriors. He lives in London."
+            },
         )
         response.raise_for_status()
         print("RESPONSE FROM CLASSIFICATION:", response.json())
         assert len(response.json()) > 0
+
+
+def test_bert_ner():
+    config = {
+        "name": "bert-ner",
+        "model_id": "dslim/bert-base-NER",
+        "docker_args": {
+            "max_input_length": 512,
+            "max_batch_prefill_tokens": 512,
+            "max_batch_total_tokens": 512,
+            "max_total_tokens": 512,
+            "backend": "flashinfer",
+        },
+    }
+    with run_lorax_container(config):
+        response = requests.post(
+            "http://localhost:8080/classify",
+            json={
+                "inputs": "Johnny supports the Golden State Warriors. He lives in London."
+            },
+        )
+        response.raise_for_status()
+        print("RESPONSE FROM CLASSIFICATION:", response.json())
+        assert len(response.json()) > 0
diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py
index e1aa6e30..a1fc67ba 100644
--- a/server/lorax_server/models/flash_bert.py
+++ b/server/lorax_server/models/flash_bert.py
@@ -179,6 +179,8 @@ def _forward_context(
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             head_size=self.head_size,
+            # TODO: This is a hack to get the prefill state to work
+            window_left=0,
         )
 
     def forward(self, batch: FlashEmbeddingClassificationBatch):
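
For reference, a minimal sketch of the request that the new test_bert_ner case sends, runnable by hand under the assumption that a LoRAX container is already serving dslim/bert-base-NER on localhost:8080 with the flashinfer backend (the test itself arranges this through the run_lorax_container helper). The classify() wrapper and the timeout value are illustrative additions, not part of the patch:

    import requests

    def classify(text: str):
        # POST raw text to the /classify endpoint exercised by both NER tests.
        response = requests.post(
            "http://localhost:8080/classify",
            json={"inputs": text},
            timeout=60,
        )
        response.raise_for_status()
        # The tests only assert that a non-empty result comes back, so no
        # particular entity schema is assumed here.
        return response.json()

    if __name__ == "__main__":
        entities = classify("Johnny supports the Golden State Warriors. He lives in London.")
        print("RESPONSE FROM CLASSIFICATION:", entities)
        assert len(entities) > 0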