Merge pull request #928 from roboflow/feature/tests_for_llama
Tests for LLama Vision 3.2
PawelPeczek-Roboflow authored Jan 7, 2025
2 parents 6880a45 + b858114 commit fb52955
Showing 12 changed files with 1,159 additions and 46 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/test_package_install_inference.yml
@@ -1,8 +1,6 @@
 name: Test package install - inference
 
 on:
-  pull_request:
-    branches: [main]
   push:
     branches: [main]
   workflow_dispatch:
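Taken together with the five matching hunks below, this change removes the `pull_request` trigger from each package-install workflow, so the checks now run only on pushes to `main` or on manual dispatch. The resulting trigger block, reconstructed from the diff context, would be:

```yaml
on:
  push:
    branches: [main]
  workflow_dispatch:
```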
2 changes: 0 additions & 2 deletions .github/workflows/test_package_install_inference_cli.yml
@@ -1,8 +1,6 @@
 name: Test package install - inference-cli
 
 on:
-  pull_request:
-    branches: [main]
   push:
     branches: [main]
   workflow_dispatch:
2 changes: 0 additions & 2 deletions .github/workflows/test_package_install_inference_gpu.yml
@@ -1,8 +1,6 @@
 name: Test package install - inference-gpu
 
 on:
-  pull_request:
-    branches: [main]
   push:
     branches: [main]
   workflow_dispatch:
@@ -1,8 +1,6 @@
 name: Test package install - inference-gpu[extras]
 
 on:
-  pull_request:
-    branches: [main]
   push:
     branches: [main]
   workflow_dispatch:
2 changes: 0 additions & 2 deletions .github/workflows/test_package_install_inference_sdk.yml
@@ -1,8 +1,6 @@
 name: Test package install - inference-sdk
 
 on:
-  pull_request:
-    branches: [main]
   push:
     branches: [main]
   workflow_dispatch:
@@ -1,8 +1,6 @@
 name: Test package install - inference[extras]
 
 on:
-  pull_request:
-    branches: [main]
   push:
     branches: [main]
   workflow_dispatch:
@@ -73,6 +73,13 @@
 {RELEVANT_TASKS_DOCS_DESCRIPTION}
+
+!!! warning "Issues with structured prompting"
+
+    The model tends to be quite unpredictable when structured output (in our case a JSON document) is expected.
+    This may impact tasks like `structured-answering`, `classification` or `multi-label-classification`.
+    The cause seems to be the model's rather sensitive built-in "filters" for inappropriate content.
+
 #### 🛠️ API providers and model variants
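Given that warning, downstream code generally cannot assume the raw response of the structured tasks parses as JSON. A minimal defensive-parsing sketch (a hypothetical helper, not part of this PR):

```python
import json
import re
from typing import Optional


def try_parse_json_answer(raw_response: str) -> Optional[dict]:
    """Best-effort extraction of a JSON object from an LLM response."""
    # Models often wrap JSON in markdown fences or surround it with prose,
    # so strip fences first and fall back to the outermost {...} block.
    cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw_response.strip())
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", cleaned, re.DOTALL)
        if match is None:
            return None
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            return None
```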
@@ -219,19 +226,15 @@ class BlockManifest(WorkflowBlockManifest):
         examples=["11B (Free) - OpenRouter", "$inputs.llama_model"],
     )
     max_tokens: int = Field(
-        default=300,
+        default=500,
         description="Maximum number of tokens the model can generate in it's response.",
         gt=1,
     )
-    temperature: Optional[Union[float, Selector(kind=[FLOAT_KIND])]] = Field(
-        default=1,
+    temperature: Union[float, Selector(kind=[FLOAT_KIND])] = Field(
+        default=0.1,
         description="Temperature to sample from the model - value in range 0.0-2.0, the higher - the more "
         'random / "creative" the generations are.',
     )
-    top_p: Optional[Union[float, Selector(kind=[FLOAT_KIND])]] = Field(
-        default=1.0,
-        description="Top-p to sample from the model - value in range 0.0-1.0, the higher - the more diverse and creative the generations are",
-    )
     max_concurrent_requests: Optional[int] = Field(
         default=None,
         description="Number of concurrent requests that can be executed by block when batch of input images provided. "
@@ -258,23 +261,16 @@ def validate(self) -> "BlockManifest":
         )
         return self
 
-    @classmethod
     @field_validator("temperature")
+    @classmethod
     def validate_temperature(cls, value: Union[str, float]) -> Union[str, float]:
         if isinstance(value, str):
             return value
         if value < 0.0 or value > 2.0:
             raise ValueError(
                 "'temperature' parameter required to be in range [0.0, 2.0]"
             )
-
-    @classmethod
-    @field_validator("top_p")
-    def validate_temperature(cls, value: Union[str, float]) -> Union[str, float]:
-        if isinstance(value, str):
-            return value
-        if value < 0.0 or value > 1.0:
-            raise ValueError("'top_p' parameter required to be in range [0.0, 2.0]")
+        return value
 
     @classmethod
     def get_parameters_accepting_batches(cls) -> List[str]:
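Besides removing the `top_p` validator, this hunk quietly fixes two real bugs: pydantic v2 expects `@classmethod` to sit below `@field_validator`, and both old validators were named `validate_temperature`, so the second definition shadowed the first in the class body (and the surviving one never returned the validated value, leaving valid floats as `None`). A minimal pydantic v2 sketch of the corrected pattern, with illustrative names not taken from this PR:

```python
from typing import Union

from pydantic import BaseModel, field_validator


class SamplingConfig(BaseModel):
    # Floats are validated; strings (e.g. workflow selector placeholders)
    # pass through untouched, mirroring the block manifest above.
    temperature: Union[str, float] = 0.1

    @field_validator("temperature")  # @field_validator goes on top...
    @classmethod                     # ...with @classmethod beneath it
    def validate_temperature(cls, value: Union[str, float]) -> Union[str, float]:
        if isinstance(value, str):
            return value
        if not 0.0 <= value <= 2.0:
            raise ValueError("'temperature' must be in range [0.0, 2.0]")
        return value  # must return the value, or the field becomes None


print(SamplingConfig(temperature=0.7).temperature)  # 0.7
print(SamplingConfig(temperature="$inputs.temperature").temperature)  # passes through
```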
@@ -325,7 +321,6 @@ def run(
         model_version: ModelVersion,
         max_tokens: int,
         temperature: float,
-        top_p: Optional[float],
         max_concurrent_requests: Optional[int],
     ) -> BlockResult:
         inference_images = [i.to_inference_format() for i in images]
@@ -339,7 +334,6 @@
             llama_model_version=model_version,
             max_tokens=max_tokens,
             temperature=temperature,
-            top_p=top_p,
             max_concurrent_requests=max_concurrent_requests,
         )
         return [
@@ -357,7 +351,6 @@ def run_llama_vision_32_llm_prompting(
     llama_model_version: ModelVersion,
     max_tokens: int,
     temperature: float,
-    top_p: Optional[float],
     max_concurrent_requests: Optional[int],
 ) -> List[str]:
     if task_type not in PROMPT_BUILDERS:
@@ -386,7 +379,6 @@ def run_llama_vision_32_llm_prompting(
         model_version_id=model_version_id,
         max_tokens=max_tokens,
         temperature=temperature,
-        top_p=top_p,
         max_concurrent_requests=max_concurrent_requests,
     )
 
@@ -397,7 +389,6 @@ def execute_llama_vision_32_requests(
     model_version_id: str,
     max_tokens: int,
     temperature: float,
-    top_p: Optional[float],
     max_concurrent_requests: Optional[int],
 ) -> List[str]:
     client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=llama_api_key)
@@ -409,7 +400,6 @@
             llama_model_version=model_version_id,
             max_tokens=max_tokens,
             temperature=temperature,
-            top_p=top_p,
         )
         for prompt in llama_prompts
     ]
@@ -429,14 +419,12 @@ def execute_llama_vision_32_request(
     llama_model_version: str,
     max_tokens: int,
     temperature: float,
-    top_p: Optional[float],
 ) -> str:
     response = client.chat.completions.create(
         model=llama_model_version,
         messages=prompt,
         max_tokens=max_tokens,
         temperature=temperature,
-        top_p=top_p,
     )
     if response.choices is None:
         error_detail = getattr(response, "error", {}).get("message", "N/A")
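With `top_p` gone, what remains is a plain OpenAI-SDK request pointed at OpenRouter. A self-contained sketch of that call pattern, where the model slug, environment variable, and image file are illustrative assumptions rather than values from this diff:

```python
import base64
import os

from openai import OpenAI

# Same client construction the block uses, pointed at OpenRouter.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],  # assumed env var
)

with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = client.chat.completions.create(
    model="meta-llama/llama-3.2-11b-vision-instruct",  # illustrative slug
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                },
            ],
        }
    ],
    max_tokens=500,
    temperature=0.1,
)

# Mirror the block's defensive check: OpenRouter reports some errors in-band
# with empty `choices` rather than raising an exception.
if response.choices is None:
    raise RuntimeError(getattr(response, "error", {}).get("message", "N/A"))
print(response.choices[0].message.content)
```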
3 changes: 1 addition & 2 deletions inference/core/workflows/prototypes/block.py
@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Type, Union
 
-from openai import BaseModel
-from pydantic import ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from inference.core.workflows.errors import BlockInterfaceError
 from inference.core.workflows.execution_engine.entities.base import OutputDefinition
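The one-line import fix above is worth a note: the `openai` package re-exports its own pydantic `BaseModel` subclass, so the old line silently pulled in a lookalike class rather than pydantic's `BaseModel`, most likely an IDE auto-import slip. A quick way to see the difference, assuming a recent 1.x `openai` SDK where this top-level export exists:

```python
import openai
import pydantic

# openai.BaseModel is openai's own subclass tuned for its API types,
# not an alias of pydantic.BaseModel - workflow blocks want the latter.
print(openai.BaseModel is pydantic.BaseModel)            # False
print(issubclass(openai.BaseModel, pydantic.BaseModel))  # True
```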
18 changes: 14 additions & 4 deletions tests/inference/unit_tests/core/test_roboflow_api.py
@@ -1724,7 +1724,9 @@ def test_get_workflow_specification_when_connection_error_occurs_but_file_is_cac
     get_mock.return_value = MagicMock(
         status_code=200,
         json=MagicMock(
-            return_value={"workflow": {"config": json.dumps({"specification": {"some": "some"}})}}
+            return_value={
+                "workflow": {"config": json.dumps({"specification": {"some": "some"}})}
+            }
         ),
     )
     _ = get_workflow_specification(
@@ -1744,7 +1746,10 @@ def test_get_workflow_specification_when_connection_error_occurs_but_file_is_cac
     )
 
     # then
-    assert result == {"some": "some", "id": None}, "Expected workflow specification to be retrieved from file"
+    assert result == {
+        "some": "some",
+        "id": None,
+    }, "Expected workflow specification to be retrieved from file"
 
 
 @mock.patch.object(roboflow_api.requests, "get")
@@ -1760,7 +1765,9 @@ def test_get_workflow_specification_when_consecutive_request_hits_ephemeral_cach
     get_mock.return_value = MagicMock(
         status_code=200,
         json=MagicMock(
-            return_value={"workflow": {"config": json.dumps({"specification": {"some": "some"}})}}
+            return_value={
+                "workflow": {"config": json.dumps({"specification": {"some": "some"}})}
+            }
         ),
     )
     ephemeral_cache = MemoryCache()
@@ -1780,7 +1787,10 @@ def test_get_workflow_specification_when_consecutive_request_hits_ephemeral_cach
     )
 
     # then
-    assert result == {"some": "some", "id": None}, "Expected workflow specification to be retrieved from file"
+    assert result == {
+        "some": "some",
+        "id": None,
+    }, "Expected workflow specification to be retrieved from file"
     assert get_mock.call_count == 1, "Expected remote API to be only called once"
