diff --git a/docetl/api.py b/docetl/api.py
index 5ec744aa..ad831195 100644
--- a/docetl/api.py
+++ b/docetl/api.py
@@ -134,6 +134,47 @@ def custom_parser(text: str) -> List[str]:
     This example shows a complete pipeline configuration with datasets,
     operations, steps, and output settings.
     """
+    DEFAULT_RATE_LIMITS = {
+        # OpenAI models
+        "gpt-4o": 1000,
+        "gpt-4o-mini": 200,
+        "gpt-3.5-turbo": 500,
+
+        # Anthropic models
+        "claude-3-5-sonnet": 1000,
+        "claude-3-opus": 500,
+        "claude-3-sonnet": 400,
+        "claude-3-haiku": 200,
+    }
+
+    def get_rate_limits(self, model: str) -> dict:
+        """Get rate limits for a specific model.
+
+        Args:
+            model: The model identifier (e.g., 'gpt-4o', 'claude-3-sonnet')
+
+        Returns:
+            dict: Rate limit information including requests_per_minute
+        """
+        # User-supplied limits take precedence over built-in defaults.
+        if self.rate_limits and model in self.rate_limits:
+            return {
+                "requests_per_minute": self.rate_limits[model],
+                "source": "custom"
+            }
+
+        if model in self.DEFAULT_RATE_LIMITS:
+            return {
+                "requests_per_minute": self.DEFAULT_RATE_LIMITS[model],
+                "source": "default"
+            }
+
+        # Conservative fallback for models without a known limit.
+        return {
+            "requests_per_minute": 200,
+            "source": "fallback"
+        }
+
 
     def __init__(
         self,
diff --git a/docetl/operations/hf_outlines.py b/docetl/operations/hf_outlines.py
new file mode 100644
index 00000000..5eb49ba7
--- /dev/null
+++ b/docetl/operations/hf_outlines.py
@@ -0,0 +1,59 @@
+from typing import Any, Dict, List, Tuple
+from pydantic import create_model
+from docetl.operations.base import BaseOperation
+from outlines import generate, models
+import json
+
+class HuggingFaceMapOperation(BaseOperation):
+    class schema(BaseOperation.schema):
+        name: str
+        type: str = "hf_map"
+        model_path: str
+        output_schema: Dict[str, Any]
+        prompt_template: str
+        max_tokens: int = 4096
+
+    def __init__(self, config: Dict[str, Any], runner=None, *args, **kwargs):
+        super().__init__(
+            config=config,
+            default_model=config.get('default_model', config['model_path']),
+            max_threads=config.get('max_threads', 1),
+            runner=runner
+        )
+
+        self.model = models.transformers(
+            self.config["model_path"]
+        )
+
+        field_definitions = {
+            k: ({"str": str, "int": int, "float": float, "bool": bool, "list": list, "dict": dict}.get(v, str) if isinstance(v, str) else v, ...)
+            for k, v in self.config["output_schema"].items()
+        }
+        output_model = create_model('OutputModel', **field_definitions)
+
+        self.processor = generate.json(
+            self.model,
+            output_model
+        )
+
+    def syntax_check(self) -> None:
+        """Validate the operation configuration."""
+        self.schema(**self.config)
+
+    def process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a single item through the model."""
+        try:
+            result = self.processor(self.config["prompt_template"] + "\n" + str(item))
+            result_dict = result.model_dump()
+            final_dict = {**item, **result_dict}
+            return final_dict
+        except Exception as e:
+            self.console.print(f"Error processing item: {e}")
+            return item
+
+    @classmethod
+    def execute(cls, config: Dict[str, Any], input_data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], float]:
+        """Execute the operation on the input data."""
+        instance = cls(config)
+        results = [instance.process_item(item) for item in input_data]
+        return results, 0.0
diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py
index 88dfbbfd..a6368fe1 100644
--- a/docetl/operations/resolve.py
+++ b/docetl/operations/resolve.py
@@ -427,8 +427,11 @@ def merge_clusters(item1: int, item2: int) -> None:
 
         # Compute an auto-batch size based on the number of comparisons
         def auto_batch() -> int:
-            # Maximum batch size limit for 4o-mini model
-            M = 500
+            # Get model-specific rate limit (TODO: confirm runner accessor)
+            model = self.config.get("comparison_model", self.default_model)
+            rate_limit = self.runner.api.get_rate_limits(model)
+            # Use the rate limit as our maximum batch size
+            M = rate_limit["requests_per_minute"]
 
             n = len(input_data)
             m = len(blocked_pairs)
diff --git a/tests/test_hf_outlines.py b/tests/test_hf_outlines.py
new file mode 100644
index 00000000..ac247655
--- /dev/null
+++ b/tests/test_hf_outlines.py
@@ -0,0 +1,131 @@
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+from docetl.operations.hf_outlines import HuggingFaceMapOperation
+
+@pytest.fixture
+def mock_runner():
+    return Mock()
+
+@pytest.fixture
+def sample_config():
+    return {
+        "name": "test_hf_operation",
+        "type": "hf_map",
+        "model_path": "meta-llama/Llama-3.2-1B-Instruct",
+        "output_schema": {
+            "first_name": "str",
+            "last_name": "str"
+        },
+        "prompt_template": "Extract customer information from this text",
+        "max_tokens": 4096
+    }
+
+@pytest.fixture
+def research_config():
+    return {
+        "name": "research_analyzer",
+        "type": "hf_map",
+        "model_path": "meta-llama/Llama-3.2-1B-Instruct",
+        "output_schema": {
+            "title": "str",
+            "authors": "list",
+            "methodology": "str",
+            "findings": "list",
+            "limitations": "list",
+            "future_work": "list"
+        },
+        "prompt_template": "Analyze the following research paper abstract.\nExtract key components and summarize findings.",
+        "max_tokens": 4096
+    }
+
+@pytest.fixture
+def mock_research_output():
+    class MockOutput:
+        def model_dump(self):
+            return {
+                "title": "Deep Learning in Natural Language Processing",
+                "authors": ["John Smith", "Jane Doe"],
+                "methodology": "Comparative analysis of transformer architectures",
+                "findings": [
+                    "Improved accuracy by 15%",
+                    "Reduced training time by 30%"
+                ],
+                "limitations": [
+                    "Limited dataset size",
+                    "Computational constraints"
+                ],
+                "future_work": [
+                    "Extend to multilingual models",
+                    "Optimize for edge devices"
+                ]
+            }
+    return MockOutput()
+
+def test_process_item(sample_config, mock_runner):
+    mock_model = MagicMock()
+
+    class MockOutput:
+        def model_dump(self):
+            return {
+                "first_name": "John",
+                "last_name": "Doe"
+            }
+
+    mock_processor = Mock(return_value=MockOutput())
+
+    with patch('outlines.models.transformers', return_value=mock_model), \
+         patch('outlines.generate.json', return_value=mock_processor):
+
+        operation = HuggingFaceMapOperation(sample_config, runner=mock_runner)
+        test_item = {"message": "test message"}
+        result = operation.process_item(test_item)
+
+        assert isinstance(result, dict)
+        assert "first_name" in result
+        assert "last_name" in result
+        assert "message" in result
+
+def test_research_paper_analysis(research_config, mock_research_output, mock_runner):
+    mock_model = MagicMock()
+    mock_processor = Mock(return_value=mock_research_output)
+
+    with patch('outlines.models.transformers', return_value=mock_model), \
+         patch('outlines.generate.json', return_value=mock_processor):
+
+        operation = HuggingFaceMapOperation(research_config, runner=mock_runner)
+        test_item = {
+            "abstract": """
+            This paper presents a comprehensive analysis of deep learning approaches
+            in natural language processing. We compare various transformer architectures
+            and their performance on standard NLP tasks.
+            """
+        }
+        result = operation.process_item(test_item)
+
+        assert isinstance(result, dict)
+        assert "title" in result
+        assert isinstance(result["title"], str)
+        assert "authors" in result
+        assert isinstance(result["authors"], list)
+        assert "methodology" in result
+        assert isinstance(result["methodology"], str)
+        assert "findings" in result
+        assert isinstance(result["findings"], list)
+        assert len(result["findings"]) > 0
+        assert "limitations" in result
+        assert isinstance(result["limitations"], list)
+        assert "future_work" in result
+        assert isinstance(result["future_work"], list)
+        assert "abstract" in result
+
+def test_execute(sample_config, mock_runner):
+    mock_model = MagicMock()
+    mock_processor = Mock(return_value=Mock(model_dump=Mock(return_value={"first_name": "John", "last_name": "Doe"})))
+
+    with patch('outlines.models.transformers', return_value=mock_model), \
+         patch('outlines.generate.json', return_value=mock_processor):
+
+        input_data = [{"message": "test message"}]
+        results, timing = HuggingFaceMapOperation.execute(sample_config, input_data)
+        assert len(results) == 1
+        assert isinstance(timing, float)