From 43bfb50e67104ff2fe67eb7f4add5738a1ebdb69 Mon Sep 17 00:00:00 2001 From: robot-ci-heartex <87703623+robot-ci-heartex@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:51:57 +0100 Subject: [PATCH] fix: DIA-1508: Fix sdk to support Prompts MultiSkill (#233) Co-authored-by: nik Co-authored-by: hakan458 Co-authored-by: niklub Co-authored-by: niklub --- adala/skills/collection/entity_extraction.py | 103 +-- adala/skills/collection/label_studio.py | 23 +- adala/skills/collection/prompt_improvement.py | 9 +- poetry.lock | 6 +- pyproject.toml | 2 +- .../test_label_studio_skill_with_ner.yaml | 666 ++++++++++++++++++ tests/test_label_studio_skill.py | 85 +++ 7 files changed, 844 insertions(+), 50 deletions(-) create mode 100644 tests/cassettes/test_label_studio_skill/test_label_studio_skill_with_ner.yaml diff --git a/adala/skills/collection/entity_extraction.py b/adala/skills/collection/entity_extraction.py index 8fe6e682..202716dc 100644 --- a/adala/skills/collection/entity_extraction.py +++ b/adala/skills/collection/entity_extraction.py @@ -13,6 +13,67 @@ logger = logging.getLogger(__name__) +def extract_indices( + df, + input_field_name, + output_field_name, + quote_string_field_name='quote_string', + labels_field_name='label' + ): + """ + Give the input dataframe with "text" column and "entities" column of the format + ``` + [{"quote_string": "entity_1"}, {"quote_string": "entity_2"}, ...] + ``` + extract the indices of the entities in the input text and put indices in the "entities" column: + ``` + [{"quote_string": "entity_1", "start": 0, "end": 5}, {"quote_string": "entity_2", "start": 10, "end": 15}, ...] + ``` + """ + for i, row in df.iterrows(): + if row.get("_adala_error"): + logger.warning(f"Error in row {i}: {row['_adala_message']}") + continue + text = row[input_field_name] + entities = row[output_field_name] + to_remove = [] + found_entities_ends = {} + for entity in entities: + # TODO: current naive implementation uses exact string matching which can seem to be a baseline + # we can improve this further by handling ambiguity, for example: + # - requesting surrounding context from LLM + # - perform fuzzy matching over strings if model still hallucinates when copying the text + ent_str = entity[quote_string_field_name] + # to avoid overlapping entities, start from the end of the last entity with the same prefix + matching_end_indices = [ + found_entities_ends[found_ent] + for found_ent in found_entities_ends + if found_ent.startswith(ent_str) + ] + if matching_end_indices: + # start searching from the end of the last entity with the same prefix + start_search_idx = max(matching_end_indices) + else: + # start searching from the beginning + start_search_idx = 0 + + start_idx = text.lower().find( + entity[quote_string_field_name].lower(), + start_search_idx, + ) + if start_idx == -1: + # we need to remove the entity if it is not found in the text + to_remove.append(entity) + else: + end_index = start_idx + len(entity[quote_string_field_name]) + entity["start"] = start_idx + entity["end"] = end_index + found_entities_ends[ent_str] = end_index + for entity in to_remove: + entities.remove(entity) + return df + + def validate_schema(schema: Dict[str, Any]): expected_schema = { "type": "object", @@ -266,47 +327,7 @@ def extract_indices(self, df): """ input_field_name = self._get_input_field_name() output_field_name = self._get_output_field_name() - for i, row in df.iterrows(): - if row.get("_adala_error"): - logger.warning(f"Error in row {i}: {row['_adala_message']}") - continue - text = row[input_field_name] - entities = row[output_field_name] - to_remove = [] - found_entities_ends = {} - for entity in entities: - # TODO: current naive implementation uses exact string matching which can seem to be a baseline - # we can improve this further by handling ambiguity, for example: - # - requesting surrounding context from LLM - # - perform fuzzy matching over strings if model still hallucinates when copying the text - ent_str = entity[self._quote_string_field_name] - # to avoid overlapping entities, start from the end of the last entity with the same prefix - matching_end_indices = [ - found_entities_ends[found_ent] - for found_ent in found_entities_ends - if found_ent.startswith(ent_str) - ] - if matching_end_indices: - # start searching from the end of the last entity with the same prefix - start_search_idx = max(matching_end_indices) - else: - # start searching from the beginning - start_search_idx = 0 - - start_idx = text.lower().find( - entity[self._quote_string_field_name].lower(), - start_search_idx, - ) - if start_idx == -1: - # we need to remove the entity if it is not found in the text - to_remove.append(entity) - else: - end_index = start_idx + len(entity[self._quote_string_field_name]) - entity["start"] = start_idx - entity["end"] = end_index - found_entities_ends[ent_str] = end_index - for entity in to_remove: - entities.remove(entity) + df = extract_indices(df, input_field_name, output_field_name, self._quote_string_field_name, self._labels_field_name) return df def apply( diff --git a/adala/skills/collection/label_studio.py b/adala/skills/collection/label_studio.py index 46409748..20518511 100644 --- a/adala/skills/collection/label_studio.py +++ b/adala/skills/collection/label_studio.py @@ -1,5 +1,6 @@ import logging -from typing import Dict, Any, Type +import pandas as pd +from typing import Optional, Type from functools import cached_property from adala.skills._base import TransformSkill from pydantic import BaseModel, Field, model_validator @@ -8,8 +9,10 @@ from adala.utils.internal_data import InternalDataFrame from label_studio_sdk.label_interface import LabelInterface +from label_studio_sdk.label_interface.control_tags import ControlTag from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import json_schema_to_pydantic +from .entity_extraction import extract_indices logger = logging.getLogger(__name__) @@ -24,7 +27,14 @@ class LabelStudioSkill(TransformSkill): # ------------------------------ label_config: str = "" - # TODO: implement postprocessing like in EntityExtractionSkill or to verify Taxonomy + # TODO: implement postprocessing to verify Taxonomy + + def has_ner_tag(self) -> Optional[ControlTag]: + # check if the input config has NER tag ( + ), and return its `from_name` and `to_name` + interface = LabelInterface(self.label_config) + for tag in interface.controls: + if tag.tag == 'Labels': + return tag @model_validator(mode='after') def validate_response_model(self): @@ -62,10 +72,17 @@ async def aapply( ) -> InternalDataFrame: with json_schema_to_pydantic(self.field_schema) as ResponseModel: - return await runtime.batch_to_batch( + output = await runtime.batch_to_batch( input, input_template=self.input_template, output_template="", instructions_template=self.instructions, response_model=ResponseModel, ) + ner_tag = self.has_ner_tag() + if ner_tag: + input_field_name = ner_tag.objects[0].value.lstrip('$') + output_field_name = ner_tag.name + quote_string_field_name = 'text' + output = extract_indices(pd.concat([input, output], axis=1), input_field_name, output_field_name, quote_string_field_name) + return output diff --git a/adala/skills/collection/prompt_improvement.py b/adala/skills/collection/prompt_improvement.py index e8fd3408..14744ad9 100644 --- a/adala/skills/collection/prompt_improvement.py +++ b/adala/skills/collection/prompt_improvement.py @@ -6,6 +6,7 @@ from adala.skills import AnalysisSkill from adala.utils.parse import parse_template from adala.utils.types import ErrorResponseModel +from adala.skills.collection.label_studio import LabelStudioSkill logger = logging.getLogger(__name__) @@ -75,8 +76,12 @@ class PromptImprovementSkill(AnalysisSkill): @model_validator(mode="after") def validate_prompts(self): + input_variables = "\n".join(self.input_variables) - + if isinstance(self.skill_to_improve, LabelStudioSkill): + model_json_schema = self.skill_to_improve.field_schema + else: + model_json_schema = self.skill_to_improve.response_model.model_json_schema() # rewrite the instructions with the actual values self.instructions = f"""\ You are a prompt engineer tasked with generating or enhancing a prompt for a Language Learning Model (LLM). Your goal is to create an effective prompt based on the given context, input data and requirements. @@ -96,7 +101,7 @@ def validate_prompts(self): ## Target response schema ```json -{json.dumps(self.skill_to_improve.response_model.model_json_schema(), indent=2)} +{json.dumps(model_json_schema, indent=2)} ``` Now, examine the current prompt (if provided): diff --git a/poetry.lock b/poetry.lock index afe218ca..356cde2f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3513,7 +3513,7 @@ description = "" optional = false python-versions = "^3.8" files = [ - {file = "e94072130d90f7b89701b9234ce175425f36f23e.zip", hash = "sha256:436834d571972327d9a24c96db19e0de8204c028f7cc01e5e36cece2a550ef91"}, + {file = "238b0f48dcac00d78a6e97862ea82011527307ce.zip", hash = "sha256:e54d85f7f9f4bd363df3485c070eab7c0e416da93859380d4e5125f50ffcb63f"}, ] [package.dependencies] @@ -3536,7 +3536,7 @@ xmljson = "0.2.1" [package.source] type = "url" -url = "https://github.com/HumanSignal/label-studio-sdk/archive/e94072130d90f7b89701b9234ce175425f36f23e.zip" +url = "https://github.com/HumanSignal/label-studio-sdk/archive/238b0f48dcac00d78a6e97862ea82011527307ce.zip" [[package]] name = "litellm" @@ -8080,4 +8080,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "a89c50ecdcd1b6da14dd1483dcd7b16e7e36355ef779eb07606e01041c1c67ee" +content-hash = "e180c394c89f439a13dbb216f373f6055360f98d65d5dbc91457792211ba33e8" diff --git a/pyproject.toml b/pyproject.toml index 9e7225c2..59619425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ celery = {version = "^5.3.6", extras = ["redis"]} kombu = ">=5.4.0rc2" # Pin version to fix https://github.com/celery/celery/issues/8030. TODO: remove when this fix will be included in celery uvicorn = "*" pydantic-settings = "^2.2.1" -label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/e94072130d90f7b89701b9234ce175425f36f23e.zip"} +label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/238b0f48dcac00d78a6e97862ea82011527307ce.zip"} kafka-python = "^2.0.2" # https://github.com/geerlingguy/ansible-role-docker/issues/462#issuecomment-2144121102 requests = "2.31.0" diff --git a/tests/cassettes/test_label_studio_skill/test_label_studio_skill_with_ner.yaml b/tests/cassettes/test_label_studio_skill/test_label_studio_skill_with_ner.yaml new file mode 100644 index 00000000..ce243e41 --- /dev/null +++ b/tests/cassettes/test_label_studio_skill/test_label_studio_skill_with_ner.yaml @@ -0,0 +1,666 @@ +interactions: +- request: + body: '{"messages": [{"role": "user", "content": "Hey, how''s it going?"}], "model": + "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": 0.0}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '142' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.9 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA2xRQW7bMBC86xVbXnKRCllWa9eXoCcnQYAeekiBohBoci2xprgEuYJrBAb6jX6v + LykoK7aL5EKAM5zhzO5zBiCMFisQqpOsem+Lz4/rOnzZc/1tp+3TQ/UVvarWHZuHZbcWeVLQ5icq + flG9V9R7i2zInWgVUDIm19mi+vRxXpYf5iPRk0abZK3noqaiN84UVVnVRbkoZstJ3ZFRGMUKvmcA + AM/jmXI6jb/ECsr8BekxRtmiWJ0fAYhANiFCxmgiS8civ5CKHKMbo9/f9KDJuBb2aG0O3Em3gwMN + 7+CO9iA3NHC63sJTJ/nv7z8RyCUgQG+cBiYtD7fX5gG3Q5SpoBusnfDjOa2l1gfaxIk/41vjTOya + gDKSS8kikxcje8wAfoxTGf4rKnyg3nPDtEOXDGf1yU5cdnFFLieSiaW94PMqf8Ot0cjS2Hg1VaGk + 6lBflGV2Ve31n29ZnOoZ175yySYnEQ+RsW+2xrUYfDCnPW19Uy8rVVVysVEiO2b/AAAA//8DAAeP + 2ya1AgAA + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8d6c57d8c81b5bee-LIS + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 22 Oct 2024 20:47:34 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=NYYRlwB5fds4iBPc4Q7glBYq5O41iwyboL6BtnPhi2I-1729630054-1.0.1.1-i7Y12A0shO9XNCvr9.h_3hhfW_PNBbY42GhGyW6.qQvrtZJ36Uf5cCrLtcTZMcAUbqNT.RM7d3J0v76W4L8mjA; + path=/; expires=Tue, 22-Oct-24 21:17:34 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=JXZHD6AMMSSJdeK52vhqi669crBeQ2UYSLdigLtH1Xs-1729630054009-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '393' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999793' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_23dac5aaf363e1f16a7bad078382a738 + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "user", "content": "Extract entities from the input + text:\nApple Inc. is an American multinational technology company that specializes + in consumer electronics, computer software, and online services."}], "model": + "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": 0.0, "tool_choice": + {"type": "function", "function": {"name": "MyModel"}}, "tools": [{"type": "function", + "function": {"name": "MyModel", "description": "Correctly extracted `MyModel` + with all the required parameters with correct types", "parameters": {"$defs": + {"Entity": {"properties": {"start": {"minimum": 0, "title": "Start", "type": + "integer"}, "end": {"minimum": 0, "title": "End", "type": "integer"}, "labels": + {"items": {"enum": ["Organization", "Product", "Version"], "type": "string"}, + "title": "Labels", "type": "array"}, "text": {"anyOf": [{"type": "string"}, + {"type": "null"}], "default": null, "title": "Text"}}, "required": ["start", + "end", "labels"], "title": "Entity", "type": "object"}}, "properties": {"entities": + {"description": "Labels and span indices for input", "items": {"$ref": "#/$defs/Entity"}, + "title": "Entities", "type": "array"}}, "required": ["entities"], "type": "object"}}}]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1208' + content-type: + - application/json + cookie: + - __cf_bm=NYYRlwB5fds4iBPc4Q7glBYq5O41iwyboL6BtnPhi2I-1729630054-1.0.1.1-i7Y12A0shO9XNCvr9.h_3hhfW_PNBbY42GhGyW6.qQvrtZJ36Uf5cCrLtcTZMcAUbqNT.RM7d3J0v76W4L8mjA; + _cfuvid=JXZHD6AMMSSJdeK52vhqi669crBeQ2UYSLdigLtH1Xs-1729630054009-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.9 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA5RTTW+bQBC98ytWc4YIY2wn3JKqrqrWSlUpPTRYaL2MYZ39EruocS3/92rBMcTx + pRzQat6892Z2Zg8BIcBLyAiwmjomjYjuv39J7c8Fe2J2+1CW0qa/Pu13S/pokmULoWfozQ6Ze2Pd + MC2NQMe16mHWIHXoVSeL5G4+jeNZ2gFSlyg8rTIuSnUkueJREidpFC+iye2JXWvO0EJGngNCCDl0 + f1+nKvEVMhKHbxGJ1tIKITsnEQKNFj4C1FpuHVUOwgFkWjlUvnTVCjECnNaiYFSIwbj/DqPzcFlU + iKKyv1930+WPevdCxbcq/lMtH1by6fPIr5fem66gbavY+ZJG+DmeXZgRAorKjrvar7q7Cy8TaFO1 + EpXzZcMhB1SOO442h+z5kIN1tHE5ZHHooTKHbOKPgm5QdDk5PDYVVfwv9SXksA5zcPjqOTncGyOQ + fFXsJodjOJKbzs966ew/9CQ2nFFFZCscV10KFcQhq5UWutoTv0pU7XM4ro/wrtdjcO28Ho2wwW1r + qTjN9hQ/npdF6Mo0emMvZg9brritiwap7WYA1mnTe3ufzgHad3sGptHSuMLpF1RecDKf9XowvIUB + TdMT6LSjYogn8V14Ra8o0VHeLeJ59xllNZYDNQ5GzX00vSbRN8hV9UElOCmB3VuHsthyVWFjGt49 + FNiaIr1NWJLQxYZBcAz+AQAA//8DAMnKSm02BAAA + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8d6c57de78385bee-LIS + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 22 Oct 2024 20:47:35 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '796' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999754' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_8e0c4a010ac9b3100836e174b4b25017 + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "user", "content": "Extract entities from the input + text:\nThe iPhone 14 is the latest smartphone from Apple Inc."}], "model": "gpt-4o-mini", + "max_tokens": 200, "seed": 47, "temperature": 0.0, "tool_choice": {"type": "function", + "function": {"name": "MyModel"}}, "tools": [{"type": "function", "function": + {"name": "MyModel", "description": "Correctly extracted `MyModel` with all the + required parameters with correct types", "parameters": {"$defs": {"Entity": + {"properties": {"start": {"minimum": 0, "title": "Start", "type": "integer"}, + "end": {"minimum": 0, "title": "End", "type": "integer"}, "labels": {"items": + {"enum": ["Organization", "Product", "Version"], "type": "string"}, "title": + "Labels", "type": "array"}, "text": {"anyOf": [{"type": "string"}, {"type": + "null"}], "default": null, "title": "Text"}}, "required": ["start", "end", "labels"], + "title": "Entity", "type": "object"}}, "properties": {"entities": {"description": + "Labels and span indices for input", "items": {"$ref": "#/$defs/Entity"}, "title": + "Entities", "type": "array"}}, "required": ["entities"], "type": "object"}}}]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1122' + content-type: + - application/json + cookie: + - __cf_bm=NYYRlwB5fds4iBPc4Q7glBYq5O41iwyboL6BtnPhi2I-1729630054-1.0.1.1-i7Y12A0shO9XNCvr9.h_3hhfW_PNBbY42GhGyW6.qQvrtZJ36Uf5cCrLtcTZMcAUbqNT.RM7d3J0v76W4L8mjA; + _cfuvid=JXZHD6AMMSSJdeK52vhqi669crBeQ2UYSLdigLtH1Xs-1729630054009-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.9 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA2xTXY/aMBB8z6+w9jmpQghw5O16X2pVCqISrXpBkXGWxK1j+2KjwiH+e+XAkRxH + HixrZ2dmd73Ze4QAzyEhwEpqWaVFcPvtKbaLQfl9tnucvDzI+9X04fNi+1s/9UsE3zHU6g8y+8b6 + xFSlBVqu5BFmNVKLTrU3isbDfhgOBg1QqRyFoxXaBrEKKi55EIVRHISjoHdzYpeKMzSQkGePEEL2 + zenqlDluISGh/xap0BhaICTnJEKgVsJFgBrDjaXSgt+CTEmL0pUuN0J0AKuUyBgVojU+fvvOvR0W + FSJ7nM5+3m1/vFS/1N2/cN6/n0/mw6+46PgdpXe6KWi9kew8pA5+jicXZoSApFXDnewmzez8ywRa + F5sKpXVlwz4FlJZbjiaF5HmfgrG0tikkse+gPIWk566CrlA0OSnMapVvmE1h6adgcevSU+CzUkkk + vTiFg98R6g/PSvHgQmlaF1TyV+qauZC71Vog+SJZCoflAd51cfCu3Zedx6lxvTFUnF7tFD+c10Co + QtdqZS5eFdZcclNmNVLTTBeMVfro7XwaB9i82yDQtaq0zaz6i9IJ9gbDox60W96icf8EWmWp6LDG + Y/+KXpajpbxZsfNWM8pKzFtq6HWa+2h6TeLYIJfFBxXvpARmZyxW2ZrLAmtd8+YXgLXO4puIRREd + rRh4B+8/AAAA//8DALdtNXUQBAAA + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8d6c57e52b055bee-LIS + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 22 Oct 2024 20:47:37 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '1921' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999776' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_1263d19a86b1fa81ae41a69f13266bd2 + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "user", "content": "Extract entities from the input + text:\nThe MacBook Pro is a line of Macintosh portable computers introduced + in January 2006 by Apple Inc."}], "model": "gpt-4o-mini", "max_tokens": 200, + "seed": 47, "temperature": 0.0, "tool_choice": {"type": "function", "function": + {"name": "MyModel"}}, "tools": [{"type": "function", "function": {"name": "MyModel", + "description": "Correctly extracted `MyModel` with all the required parameters + with correct types", "parameters": {"$defs": {"Entity": {"properties": {"start": + {"minimum": 0, "title": "Start", "type": "integer"}, "end": {"minimum": 0, "title": + "End", "type": "integer"}, "labels": {"items": {"enum": ["Organization", "Product", + "Version"], "type": "string"}, "title": "Labels", "type": "array"}, "text": + {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": + "Text"}}, "required": ["start", "end", "labels"], "title": "Entity", "type": + "object"}}, "properties": {"entities": {"description": "Labels and span indices + for input", "items": {"$ref": "#/$defs/Entity"}, "title": "Entities", "type": + "array"}}, "required": ["entities"], "type": "object"}}}]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1166' + content-type: + - application/json + cookie: + - __cf_bm=NYYRlwB5fds4iBPc4Q7glBYq5O41iwyboL6BtnPhi2I-1729630054-1.0.1.1-i7Y12A0shO9XNCvr9.h_3hhfW_PNBbY42GhGyW6.qQvrtZJ36Uf5cCrLtcTZMcAUbqNT.RM7d3J0v76W4L8mjA; + _cfuvid=JXZHD6AMMSSJdeK52vhqi669crBeQ2UYSLdigLtH1Xs-1729630054009-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.9 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA4yT30/bMBDH3/NXWPecojQtTckbm1i1iQKDaS8ERa5zST0cO7OvQKj6v09JSxIK + D8tDZN337nM/fN56jIHMIGYg1pxEWanR+eVi+nSxuKpvzU87//rjYvat1r/rYnY7qZ/AbyLM6g8K + eos6EaasFJI0ei8Li5ywoY6j8Gw2CYLTqBVKk6FqwoqKRlMzKqWWozAIp6MgGo3nh+i1kQIdxOze + Y4yxbftv6tQZvkDMAv/NUqJzvECIOyfGwBrVWIA7Jx1xTeD3ojCaUDel641SA4GMUangSvWJ9992 + cO6HxZVKL39dYf5a3l1f/V1MIjp7LjGaP+fBIN8eXVdtQflGi25IA72zx0fJGAPNyzZ2WS/b2fnH + DtwWmxI1NWXDNgHUJEmiSyC+3ybgiFtKIA78RsoSiMcTPwHFV6hanwRurMk2ghJ48BMgfGncE1hy + 8cWYR3ZjTQI7f4AKo441mR2xrm3BtXzlTTtHwPOqUsi+a3FyxBvPO14Y/m9tUpNx6wR2Dzt4N5Kd + 99n5YXDTFvON4+qwAgf7rtspZYrKmpU7WhHIpZZunVrkrr0qcGSqfe4mT5sBNu/WESpryopSMo+o + G+B4Nt3zoH8yvTobH0QyxFVvD8NT/xNemiFx2e5r90QEF2vM+tDAGzT3MelniH2DUhcfKN6BBK52 + hGWaS12graxs3xPkVTqdhyIMebQS4O28fwAAAP//AwBgGyygXQQAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8d6c57f24fc35bee-LIS + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 22 Oct 2024 20:47:38 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '775' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999765' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_ae1c6d49effed66daf68d9369692edad + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "user", "content": "Extract entities from the input + text:\nThe Apple Watch is a line of smartwatches produced by Apple Inc."}], + "model": "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": 0.0, "tool_choice": + {"type": "function", "function": {"name": "MyModel"}}, "tools": [{"type": "function", + "function": {"name": "MyModel", "description": "Correctly extracted `MyModel` + with all the required parameters with correct types", "parameters": {"$defs": + {"Entity": {"properties": {"start": {"minimum": 0, "title": "Start", "type": + "integer"}, "end": {"minimum": 0, "title": "End", "type": "integer"}, "labels": + {"items": {"enum": ["Organization", "Product", "Version"], "type": "string"}, + "title": "Labels", "type": "array"}, "text": {"anyOf": [{"type": "string"}, + {"type": "null"}], "default": null, "title": "Text"}}, "required": ["start", + "end", "labels"], "title": "Entity", "type": "object"}}, "properties": {"entities": + {"description": "Labels and span indices for input", "items": {"$ref": "#/$defs/Entity"}, + "title": "Entities", "type": "array"}}, "required": ["entities"], "type": "object"}}}]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1132' + content-type: + - application/json + cookie: + - __cf_bm=NYYRlwB5fds4iBPc4Q7glBYq5O41iwyboL6BtnPhi2I-1729630054-1.0.1.1-i7Y12A0shO9XNCvr9.h_3hhfW_PNBbY42GhGyW6.qQvrtZJ36Uf5cCrLtcTZMcAUbqNT.RM7d3J0v76W4L8mjA; + _cfuvid=JXZHD6AMMSSJdeK52vhqi669crBeQ2UYSLdigLtH1Xs-1729630054009-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.9 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA2xT22rjMBB991eIebaXxPHm4rdAl27ZhA0stNA6GEWeOMrKkrDG2WZD/r3YzsVN + 4wch5sycMzM6PniMgcwgZiA2nERhVTCdPUb/FvsB/ymySW6fBo/0vP0xmj9kD7MZ+HWFWW1R0Lnq + mzCFVUjS6BYWJXLCmrU/CifDQa/3fdwAhclQ1WW5pSAyQSG1DMJeGAW9UdAfn6o3Rgp0ELM3jzHG + Ds1Z96kzfIeY9fxzpEDneI4QX5IYg9KoOgLcOemIawL/CgqjCXXduq6U6gBkjEoFV+oq3H6Hzv26 + LK5UOsSpUIvqFfPnX4PX3Xa2oz/vs91LR6+l3tumoXWlxWVJHfwSj2/EGAPNi6Z2vp83u/NvE3iZ + VwVqqtuGQwKoSZJEl0D8dkjAES8pgTjyayhLIO4P/QQUX6FqchJYlCarBCWw9BMgfK/TE5haq5C9 + cBKbBI5+h2owvnBF4xuu32XOtfzP63HuEj5pkcBxeYRPcxy9e/dl53lKXFeOq9O7neLHixGUyW1p + Vu7mXWEttXSbtETumv2CI2Nb7VqnUYDqk4fAlqawlJL5i7om7Lfmbcxz9vkVjfonkAxx1amaTPw7 + fGmGxGVjsouvBRcbzK6lPa8z3FfRexTtgFLnX1i8ExO4vSMs0rXUOZa2lM1PAGubRuNQhCEfrQR4 + R+8DAAD//wMARbyGlRIEAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8d6c57f84ffc5bee-LIS + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 22 Oct 2024 20:47:38 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '527' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999772' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_955848200432f32d6fb6799b7365bace + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "user", "content": "Extract entities from the input + text:\nThe iPad is a line of tablet computers designed, developed, and marketed + by Apple Inc."}], "model": "gpt-4o-mini", "max_tokens": 200, "seed": 47, "temperature": + 0.0, "tool_choice": {"type": "function", "function": {"name": "MyModel"}}, "tools": + [{"type": "function", "function": {"name": "MyModel", "description": "Correctly + extracted `MyModel` with all the required parameters with correct types", "parameters": + {"$defs": {"Entity": {"properties": {"start": {"minimum": 0, "title": "Start", + "type": "integer"}, "end": {"minimum": 0, "title": "End", "type": "integer"}, + "labels": {"items": {"enum": ["Organization", "Product", "Version"], "type": + "string"}, "title": "Labels", "type": "array"}, "text": {"anyOf": [{"type": + "string"}, {"type": "null"}], "default": null, "title": "Text"}}, "required": + ["start", "end", "labels"], "title": "Entity", "type": "object"}}, "properties": + {"entities": {"description": "Labels and span indices for input", "items": {"$ref": + "#/$defs/Entity"}, "title": "Entities", "type": "array"}}, "required": ["entities"], + "type": "object"}}}]}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1154' + content-type: + - application/json + cookie: + - __cf_bm=NYYRlwB5fds4iBPc4Q7glBYq5O41iwyboL6BtnPhi2I-1729630054-1.0.1.1-i7Y12A0shO9XNCvr9.h_3hhfW_PNBbY42GhGyW6.qQvrtZJ36Uf5cCrLtcTZMcAUbqNT.RM7d3J0v76W4L8mjA; + _cfuvid=JXZHD6AMMSSJdeK52vhqi669crBeQ2UYSLdigLtH1Xs-1729630054009-0.0.1.1-604800000 + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.47.1 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.47.1 + x-stainless-raw-response: + - 'true' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.11.9 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA2xSTW+jMBC98yusOZMVISS03HJp1dVm011VWzV1hIxxiLvG9mIjJY347ytDCjQN + B2TNm3lvPt7JQwh4DgkCuieWllpMlj/uo8O6vn97WR6foix+2UQPf+JgHn3Pfj+D7ypU9sao/aj6 + RlWpBbNcyQ6mFSOWOdZpHN4uZkEwv22BUuVMuLJC20mkJiWXfBIGYTQJ4sn05ly9V5wyAwl69RBC + 6NT+XZ8yZwdIUOB/REpmDCkYJH0SQlAp4SJAjOHGEmnBH0CqpGXStS5rIUaAVUqklAgxCHffafQe + lkWESMv4OYqXv+Tx592ivJuu1pvi/d+T3oz0Ouqjbhva1ZL2SxrhfTy5EEMIJCnb2tVx1e7Ov0wg + VVGXTFrXNpwwMGm55cxgSF5PGIwllcWQhLHvsBxDMgt9DIJkTLRJGB4rldfUYtj6GCw7uHwM/JHk + GBp/RLJY9CTx/IJkXRVE8nfiBrlgWmotGHqQFEOzbeDTBI137b0dHaZiu9oQcb7YOd70FhCq0JXK + zMVFYcclN/u0YsS0mwVjle60nU6rAPUn94CuVKltatVfJh3hdBF2fDA4fECj6Rm0yhIxxMNg5l/h + S3NmCW/t1TuaErpn+VAaeKPhvopeo+gG5LL4wuKdmcAcjWVluuOyYJWueGt/2Ok0uglpGJI4o+A1 + 3n8AAAD//wMAGkwFhgwEAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8d6c57fcdf625bee-LIS + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 22 Oct 2024 20:47:39 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - heartex + openai-processing-ms: + - '559' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999768' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_bb013ee7333ce30ef444c2c30b6761ed + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_label_studio_skill.py b/tests/test_label_studio_skill.py index 2842f3c1..b608b45a 100644 --- a/tests/test_label_studio_skill.py +++ b/tests/test_label_studio_skill.py @@ -66,3 +66,88 @@ async def test_label_studio_skill(): "The issue clearly indicates a problem with the login functionality of the platform, which is a critical feature. Users are unable to access their accounts, suggesting a potential bug that needs to be addressed.", "The issue is requesting the addition of support for a new file type (.docx), which indicates a desire for new functionality in the system. This aligns with the definition of a feature request, as it seeks to enhance the capabilities of the application." ] + + +@pytest.mark.asyncio +@pytest.mark.vcr +async def test_label_studio_skill_with_ner(): + # documents that contain entities + df = pd.DataFrame( + [ + { + "text": "Apple Inc. is an American multinational technology company that specializes in consumer electronics, computer software, and online services." + }, + {"text": "The iPhone 14 is the latest smartphone from Apple Inc."}, + { + "text": "The MacBook Pro is a line of Macintosh portable computers introduced in January 2006 by Apple Inc." + }, + { + "text": "The Apple Watch is a line of smartwatches produced by Apple Inc." + }, + { + "text": "The iPad is a line of tablet computers designed, developed, and marketed by Apple Inc." + }, + ] + ) + + agent_payload = { + "runtimes": { + "default": { + "type": "AsyncLiteLLMChatRuntime", + "model": "gpt-4o-mini", + "api_key": os.getenv("OPENAI_API_KEY"), + "max_tokens": 200, + "temperature": 0, + "batch_size": 100, + "timeout": 10, + "verbose": False, + } + }, + "skills": [ + { + "type": "LabelStudioSkill", + "name": "AnnotationResult", + "input_template": 'Extract entities from the input text:\n{text}', + "label_config": """ + + + + + + """ + } + ], + } + + agent = Agent(**agent_payload) + predictions = await agent.arun(df) + + expected_predictions = [ + [ + {'start': 0, 'end': 10, 'labels': ['Organization'], 'text': 'Apple Inc.'}, + {'start': 17, 'end': 58, 'labels': ['Organization'], 'text': 'American multinational technology company'} + ], + [ + {'start': 4, 'end': 13, 'labels': ['Product'], 'text': 'iPhone 14'}, + {'start': 44, 'end': 53, 'labels': ['Organization'], 'text': 'Apple Inc'} + ], + [ + {'start': 4, 'end': 15, 'labels': ['Product'], 'text': 'MacBook Pro'}, + {'start': 88, 'end': 98, 'labels': ['Organization'], 'text': 'Apple Inc.'}, + {'start': 29, 'end': 38, 'labels': ['Product'], 'text': 'Macintosh'} + ], + [ + {'start': 4, 'end': 15, 'labels': ['Product'], 'text': 'Apple Watch'}, + {'start': 54, 'end': 63, 'labels': ['Organization'], 'text': 'Apple Inc'} + ], + [ + {'start': 4, 'end': 8, 'labels': ['Product'], 'text': 'iPad'}, + {'start': 76, 'end': 85, 'labels': ['Organization'], 'text': 'Apple Inc'} + ] + ] + + assert predictions.entities.tolist() == expected_predictions + \ No newline at end of file