Skip to content

Commit

Permalink
fix and add rtest
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh committed Dec 17, 2024
1 parent 6dfb215 commit 93dbd4a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
28 changes: 27 additions & 1 deletion integration-tests/test_classifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,33 @@ def test_distilbert_ner():
with run_lorax_container(config):
response = requests.post(
"http://localhost:8080/classify",
json={"inputs": "Johnny supports the Golden State Warriors. He lives in London."},
json={
"inputs": "Johnny supports the Golden State Warriors. He lives in London."
},
)
response.raise_for_status()
print("RESPONSE FROM CLASSIFICATION:", response.json())
assert len(response.json()) > 0


def test_bert_ner():
config = {
"name": "bert-ner",
"model_id": "dslim/bert-base-NER",
"docker_args": {
"max_input_length": 512,
"max_batch_prefill_tokens": 512,
"max_batch_total_tokens": 512,
"max_total_tokens": 512,
"backend": "flashinfer",
},
}
with run_lorax_container(config):
response = requests.post(
"http://localhost:8080/classify",
json={
"inputs": "Johnny supports the Golden State Warriors. He lives in London."
},
)
response.raise_for_status()
print("RESPONSE FROM CLASSIFICATION:", response.json())
Expand Down
2 changes: 2 additions & 0 deletions server/lorax_server/models/flash_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def _forward_context(
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
# TODO: This is a hack to get the prefill state to work
window_left=0,
)

def forward(self, batch: FlashEmbeddingClassificationBatch):
Expand Down

0 comments on commit 93dbd4a

Please sign in to comment.