feat: DIA-1384: add cost estimate endpoint (#225)
pakelley authored Oct 23, 2024
1 parent db34d51 commit cbc539c
Showing 5 changed files with 256 additions and 10 deletions.
7 changes: 4 additions & 3 deletions adala/agents/base.py
@@ -30,6 +30,7 @@
)
from adala.utils.internal_data import InternalDataFrame
from adala.utils.types import BatchData

logger = logging.getLogger(__name__)


@@ -40,7 +41,7 @@ class Agent(BaseModel, ABC):
    Attributes:
        environment (Environment): The environment with which the agent interacts.
-        skills (Union[SkillSet, List[Skill]]): The skills possessed by the agent.
+        skills (SkillSet): The skills possessed by the agent.
        memory (LongTermMemory, optional): The agent's long-term memory. Defaults to None.
        runtimes (Dict[str, Runtime], optional): The runtimes available to the agent. Defaults to predefined runtimes.
        default_runtime (str): The default runtime used by the agent. Defaults to 'openai'.
@@ -57,7 +58,7 @@ class Agent(BaseModel, ABC):
"""

environment: Optional[SerializeAsAny[Union[Environment, AsyncEnvironment]]] = None
skills: SerializeAsAny[Union[Skill, SkillSet]]
skills: SerializeAsAny[SkillSet]

memory: Memory = Field(default=None)
runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
@@ -418,7 +419,7 @@ async def arefine_skill(
        skill = self.skills[skill_name]
        if not isinstance(skill, TransformSkill):
            raise ValueError(f"Skill {skill_name} is not a TransformSkill")

        # get default runtimes
        runtime = self.get_runtime()
        teacher_runtime = self.get_teacher_runtime()
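Worth noting on the `skills` change earlier in this file's diff: the annotation narrows from `Union[Skill, SkillSet]` to `SkillSet`, yet the endpoint test below still passes `skills` as a plain JSON list, so a pydantic validator presumably coerces lists into a `SkillSet`. A hypothetical construction sketch under that assumption (not part of this commit):

```python
# Hypothetical sketch: constructing an Agent after the annotation change.
# Assumes a list of skills is coerced into a SkillSet, as the JSON payload
# in tests/test_cost_estimation.py suggests.
from adala.agents import Agent
from adala.skills import ClassificationSkill

agent = Agent(
    skills=[
        ClassificationSkill(
            name="text_classifier",
            instructions="Classify the text.",
            input_template="{text}",
            output_template="{output}",
            labels=["Positive", "Negative"],
        )
    ],
)
```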
70 changes: 68 additions & 2 deletions adala/runtimes/_litellm.py
@@ -13,6 +13,7 @@
import instructor
from instructor.exceptions import InstructorRetryException, IncompleteOutputException
import traceback
from adala.runtimes.base import CostEstimate
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.internal_data import InternalDataFrame
from adala.utils.parse import (
@@ -122,7 +123,6 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:


class InstructorClientMixin:
    def _from_litellm(self, **kwargs):
        return instructor.from_litellm(litellm.completion, **kwargs)

@@ -139,7 +139,6 @@ def is_custom_openai_endpoint(self) -> bool:


class InstructorAsyncClientMixin(InstructorClientMixin):
    def _from_litellm(self, **kwargs):
        return instructor.from_litellm(litellm.acompletion, **kwargs)

@@ -527,6 +526,73 @@ async def record_to_record(
        # Extract the single row from the output DataFrame and convert it to a dictionary
        return output_df.iloc[0].to_dict()

    @staticmethod
    def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
        user_tokens = litellm.token_counter(model=model, text=string)
        # FIXME: it's surprisingly difficult to count function-call tokens, and doing so
        # adds little value, so we hard-code an estimate until something like litellm
        # supports computing it for us. Currently we'd need to scrape the instructor logs
        # for the function-call info, then use (at best) an OpenAI-specific third-party
        # library to estimate tokens from that.
        system_tokens = 56 + (6 * len(output_fields))
        return user_tokens + system_tokens

    @staticmethod
    def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
        max_tokens = litellm.get_model_info(
            model=model, custom_llm_provider="openai"
        ).get("max_tokens", None)
        if not max_tokens:
            raise ValueError(f"Model {model} has no max_tokens in litellm's model info")
        # extremely rough heuristic, from testing on some anecdotal examples
        n_outputs = len(output_fields) if output_fields else 1
        return min(max_tokens, 4 * n_outputs)

    @classmethod
    def _estimate_cost(
        cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
    ):
        prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
        completion_tokens = cls._get_completion_tokens(model, output_fields)
        prompt_cost, completion_cost = litellm.cost_per_token(
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        total_cost = prompt_cost + completion_cost

        return prompt_cost, completion_cost, total_cost

    def get_cost_estimate(
        self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
    ) -> CostEstimate:
        try:
            user_prompts = [
                prompt.format(**substitution) for substitution in substitutions
            ]
            cumulative_prompt_cost = 0
            cumulative_completion_cost = 0
            cumulative_total_cost = 0
            for user_prompt in user_prompts:
                prompt_cost, completion_cost, total_cost = self._estimate_cost(
                    user_prompt=user_prompt,
                    model=self.model,
                    output_fields=output_fields,
                )
                cumulative_prompt_cost += prompt_cost
                cumulative_completion_cost += completion_cost
                cumulative_total_cost += total_cost
            return CostEstimate(
                prompt_cost_usd=cumulative_prompt_cost,
                completion_cost_usd=cumulative_completion_cost,
                total_cost_usd=cumulative_total_cost,
            )
        except Exception as e:
            logger.error("Failed to estimate cost: %s", e)
            return CostEstimate(
                is_error=True,
                error_type=type(e).__name__,
                error_message=str(e),
            )
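Taken together, the three helpers above boil down to a handful of litellm calls. A minimal standalone sketch of the same pipeline, using only the litellm functions that appear in this diff (illustrative values; not part of the commit):

```python
import litellm

model = "gpt-4o-mini"
user_prompt = "testing, knock knock, who's there"
output_fields = ["text"]

# Prompt side: counted tokens plus the hard-coded function-call overhead.
prompt_tokens = litellm.token_counter(model=model, text=user_prompt)
prompt_tokens += 56 + 6 * len(output_fields)

# Completion side: ~4 tokens per output field, capped at the model's max_tokens.
max_tokens = litellm.get_model_info(model=model, custom_llm_provider="openai")["max_tokens"]
completion_tokens = min(max_tokens, 4 * len(output_fields))

# litellm returns (prompt_cost, completion_cost) in USD.
prompt_cost, completion_cost = litellm.cost_per_token(
    model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
print(f"estimated cost: ${prompt_cost + completion_cost:.6f}")
```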


class LiteLLMVisionRuntime(LiteLLMChatRuntime):
"""
49 changes: 44 additions & 5 deletions adala/runtimes/base.py
@@ -1,17 +1,51 @@
import logging
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional, Type
+
-from tqdm import tqdm
-from abc import ABC, abstractmethod
-from pydantic import BaseModel, model_validator, Field
-from typing import List, Dict, Optional, Tuple, Any, Callable, ClassVar, Type
-from adala.utils.internal_data import InternalDataFrame, InternalSeries
+from adala.utils.internal_data import InternalDataFrame
from adala.utils.registry import BaseModelInRegistry
from pandarallel import pandarallel
+from pydantic import BaseModel, Field, model_validator
+from tqdm import tqdm

logger = logging.getLogger(__name__)
tqdm.pandas()


class CostEstimate(BaseModel):
    prompt_cost_usd: Optional[float] = None
    completion_cost_usd: Optional[float] = None
    total_cost_usd: Optional[float] = None
    is_error: bool = False
    error_type: Optional[str] = None
    error_message: Optional[str] = None

    def __add__(self, other: "CostEstimate") -> "CostEstimate":
        # if either has an error, it takes precedence
        if self.is_error:
            return self
        if other.is_error:
            return other

        def _safe_add(lhs: Optional[float], rhs: Optional[float]) -> Optional[float]:
            if lhs is None and rhs is None:
                return None
            _lhs = lhs or 0.0
            _rhs = rhs or 0.0
            return _lhs + _rhs

        prompt_cost_usd = _safe_add(self.prompt_cost_usd, other.prompt_cost_usd)
        completion_cost_usd = _safe_add(
            self.completion_cost_usd, other.completion_cost_usd
        )
        total_cost_usd = _safe_add(self.total_cost_usd, other.total_cost_usd)
        return CostEstimate(
            prompt_cost_usd=prompt_cost_usd,
            completion_cost_usd=completion_cost_usd,
            total_cost_usd=total_cost_usd,
        )
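Since `_safe_add` treats `None` fields as absent and an error estimate takes precedence, `CostEstimate` objects can be folded with `sum()` given a `CostEstimate` start value (there is no `__radd__`, so the default integer start of `0` would fail). This is exactly how the `/estimate-cost` endpoint below aggregates per-skill estimates. A small sketch with invented numbers:

```python
from adala.runtimes.base import CostEstimate

a = CostEstimate(prompt_cost_usd=0.10, completion_cost_usd=0.05, total_cost_usd=0.15)
b = CostEstimate(prompt_cost_usd=0.20, completion_cost_usd=0.05, total_cost_usd=0.25)

total = sum([a, b], CostEstimate())  # CostEstimate start value, not int 0
print(total.total_cost_usd)  # ~0.40 (floating point)

# An error estimate propagates through the whole sum unchanged.
err = CostEstimate(is_error=True, error_type="ValueError", error_message="boom")
assert sum([a, err, b], CostEstimate()).is_error
```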


class Runtime(BaseModelInRegistry):
"""
Base class representing a generic runtime environment.
@@ -191,6 +225,11 @@ def record_to_batch(
            response_model=response_model,
        )

    def get_cost_estimate(
        self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
    ) -> CostEstimate:
        raise NotImplementedError("This runtime does not support cost estimates")
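The base-class stub makes cost estimation opt-in per runtime: `LiteLLMChatRuntime.get_cost_estimate` above overrides it, and the `/estimate-cost` endpoint below converts the `NotImplementedError` from any other runtime into an error `CostEstimate`. A sketch of what an override might look like on a hypothetical runtime (made-up flat pricing, purely illustrative):

```python
from typing import Dict, List, Optional
from adala.runtimes.base import CostEstimate, Runtime  # imports shown for a standalone file

class FlatRateRuntime(Runtime):  # hypothetical subclass, not in this commit
    def get_cost_estimate(
        self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
    ) -> CostEstimate:
        # Invented pricing: one cent per record, split evenly between prompt
        # and completion, just to show the expected return shape.
        n = len(substitutions)
        return CostEstimate(
            prompt_cost_usd=0.005 * n,
            completion_cost_usd=0.005 * n,
            total_cost_usd=0.01 * n,
        )
```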


class AsyncRuntime(Runtime):
    """Async version of runtime that uses asyncio to process batch of records."""
60 changes: 60 additions & 0 deletions server/app.py
@@ -15,6 +15,7 @@
from aiokafka.errors import UnknownTopicOrPartitionError
from fastapi import HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import litellm
from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
from redis import Redis
import time
@@ -24,6 +25,7 @@
from server.handlers.result_handlers import ResultHandler
from server.log_middleware import LogMiddleware
from adala.skills.collection.prompt_improvement import ImprovedPromptResponse
from adala.runtimes.base import CostEstimate
from server.tasks.stream_inference import streaming_parent_task
from server.utils import (
    Settings,
@@ -81,6 +83,12 @@ class BatchSubmitted(BaseModel):
    job_id: str


class CostEstimateRequest(BaseModel):
    agent: Agent
    prompt: str
    substitutions: List[Dict]


class Status(Enum):
    PENDING = "Pending"
    INPROGRESS = "InProgress"
@@ -210,6 +218,58 @@ async def submit_batch(batch: BatchData):
    return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


@app.post("/estimate-cost", response_model=Response[CostEstimate])
async def estimate_cost(
request: CostEstimateRequest,
):
"""
Estimates what it would cost to run inference on the batch of data in
`request` (using the run params from `request`)
Args:
request (CostEstimateRequest): Specification for the inference run to
make an estimate for, includes:
agent (adala.agent.Agent): The agent definition, used to get the model
and any other params necessary to estimate cost
prompt (str): The prompt template that will be used for each task
substitutions (List[Dict]): Mappings to substitute (simply using str.format)
Returns:
Response[CostEstimate]: The cost estimate, including the prompt/completion/total costs (in USD)
"""
prompt = request.prompt
substitutions = request.substitutions
agent = request.agent
runtime = agent.get_runtime()

try:
cost_estimates = []
for skill in agent.skills.skills.values():
output_fields = (
list(skill.field_schema.keys()) if skill.field_schema else None
)
cost_estimate = runtime.get_cost_estimate(
prompt=prompt, substitutions=substitutions, output_fields=output_fields
)
cost_estimates.append(cost_estimate)
total_cost_estimate = sum(
cost_estimates,
CostEstimate(
prompt_cost_usd=None, completion_cost_usd=None, total_cost_usd=None
),
)

except NotImplementedError as e:
return Response[CostEstimate](
data=CostEstimate(
is_error=True,
error_type=type(e).__name__,
error_message=str(e),
)
)
return Response[CostEstimate](data=total_cost_estimate)


@app.get("/jobs/{job_id}", response_model=Response[JobStatusResponse])
def get_status(job_id):
"""
80 changes: 80 additions & 0 deletions tests/test_cost_estimation.py
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
import pytest
from adala.runtimes._litellm import AsyncLiteLLMChatRuntime
from adala.runtimes.base import CostEstimate
from adala.agents import Agent
from adala.skills import ClassificationSkill
import numpy as np
import os
from fastapi.testclient import TestClient
from server.app import app, CostEstimateRequest

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


@pytest.mark.use_openai
def test_simple_estimate_cost():
    runtime = AsyncLiteLLMChatRuntime(model="gpt-4o-mini", api_key=OPENAI_API_KEY)

    cost_estimate = runtime.get_cost_estimate(
        prompt="testing, {text}",
        substitutions=[{"text": "knock knock, who's there"}],
        output_fields=["text"],
    )

    assert isinstance(cost_estimate, CostEstimate)
    assert isinstance(cost_estimate.prompt_cost_usd, float)
    assert isinstance(cost_estimate.completion_cost_usd, float)
    assert isinstance(cost_estimate.total_cost_usd, float)
    assert np.isclose(
        cost_estimate.total_cost_usd,
        cost_estimate.prompt_cost_usd + cost_estimate.completion_cost_usd,
    )


@pytest.mark.use_openai
def test_estimate_cost_endpoint(client):
    req = {
        "agent": {
            "skills": [
                {
                    "type": "ClassificationSkill",
                    "name": "text_classifier",
                    "instructions": "Always return the answer 'Feature Lack'.",
                    "input_template": "{text}",
                    "output_template": "{output}",
                    "labels": [
                        "Feature Lack",
                        "Price",
                        "Integration Issues",
                        "Usability Concerns",
                        "Competitor Advantage",
                    ],
                }
            ],
            "runtimes": {
                "default": {
                    "type": "AsyncLiteLLMChatRuntime",
                    "model": "gpt-4o-mini",
                    "api_key": OPENAI_API_KEY,
                }
            },
        },
        "prompt": "test {text}",
        "substitutions": [{"text": "test"}],
    }
    resp = client.post(
        "/estimate-cost",
        json=req,
    )
    resp_data = resp.json()["data"]
    cost_estimate = CostEstimate(**resp_data)

    assert isinstance(cost_estimate, CostEstimate)
    assert isinstance(cost_estimate.prompt_cost_usd, float)
    assert isinstance(cost_estimate.completion_cost_usd, float)
    assert isinstance(cost_estimate.total_cost_usd, float)
    assert np.isclose(
        cost_estimate.total_cost_usd,
        cost_estimate.prompt_cost_usd + cost_estimate.completion_cost_usd,
    )
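One last note on the tests: both are gated behind the `use_openai` pytest marker, so they presumably run only when that marker is selected and `OPENAI_API_KEY` is set. The estimates themselves come from litellm's local model metadata and pricing tables; no completion request is actually sent.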
