diff --git a/cookbooks/0-rhubarb-cookbook.ipynb b/cookbooks/0-rhubarb-cookbook.ipynb
index 2573caf..7cd719b 100644
--- a/cookbooks/0-rhubarb-cookbook.ipynb
+++ b/cookbooks/0-rhubarb-cookbook.ipynb
@@ -154,7 +154,7 @@
    "source": [
     "### Default Model\n",
     "---\n",
-    "By default Rhubarb uses Claude Sonnet model, however you can also use Haiku or Opus (when available)."
+    "By default, Rhubarb uses the Claude Sonnet model; however, you can also use Haiku, Sonnet 3.5, or Opus (when available)."
    ]
   },
   {
diff --git a/sample_deployments/cdk_lambda/app.py b/sample_deployments/cdk_lambda/app.py
index 5d067fd..178a73b 100644
--- a/sample_deployments/cdk_lambda/app.py
+++ b/sample_deployments/cdk_lambda/app.py
@@ -1,27 +1,21 @@
 #!/usr/bin/env python3
-import os
 import aws_cdk as cdk
-
 from infra.lambda_document_proccessing import SampleDeploymentsStack
-
 app = cdk.App()
-SampleDeploymentsStack(app, "SampleDeploymentsStack",
+SampleDeploymentsStack(
+    app,
+    "SampleDeploymentsStack",
     # If you don't specify 'env', this stack will be environment-agnostic.
     # Account/Region-dependent features and context lookups will not work,
     # but a single synthesized template can be deployed anywhere.
-
     # Uncomment the next line to specialize this stack for the AWS Account
     # and Region that are implied by the current CLI configuration.
-
-    #env=cdk.Environment(account=os.getenv('CDK_DEFAULT_ACCOUNT'), region=os.getenv('CDK_DEFAULT_REGION')),
-
+    # env=cdk.Environment(account=os.getenv('CDK_DEFAULT_ACCOUNT'), region=os.getenv('CDK_DEFAULT_REGION')),
     # Uncomment the next line if you know exactly what Account and Region you
     # want to deploy the stack to. */
-
-    #env=cdk.Environment(account='123456789012', region='us-east-1'),
-
+    # env=cdk.Environment(account='123456789012', region='us-east-1'),
     # For more information, see https://docs.aws.amazon.com/cdk/latest/guide/environments.html
 )
diff --git a/sample_deployments/cdk_lambda/infra/lambda_document_proccessing.py b/sample_deployments/cdk_lambda/infra/lambda_document_proccessing.py
index b6521fb..741b5be 100644
--- a/sample_deployments/cdk_lambda/infra/lambda_document_proccessing.py
+++ b/sample_deployments/cdk_lambda/infra/lambda_document_proccessing.py
@@ -1,58 +1,55 @@
-from aws_cdk import (
-    Duration,
-    RemovalPolicy,
-    Stack,
-    aws_lambda,
-    aws_iam,
-    aws_s3
-)
+from pathlib import Path
+
+from aws_cdk import Stack, Duration, RemovalPolicy, aws_s3, aws_iam, aws_lambda
 from constructs import Construct
-from pathlib import Path
 
-class SampleDeploymentsStack(Stack):
+
+class SampleDeploymentsStack(Stack):
     def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
         super().__init__(scope, construct_id, **kwargs)
-        bucket = aws_s3.Bucket(self, "MyLambdaBucket",
-            removal_policy=RemovalPolicy.DESTROY,
-            auto_delete_objects=True
+        bucket = aws_s3.Bucket(
+            self, "MyLambdaBucket", removal_policy=RemovalPolicy.DESTROY, auto_delete_objects=True
         )
-
-        lambda_function = self.__create_lambda_function("document-processing-lambda", bucket)
+        lambda_function = self.__create_lambda_function("document-processing-lambda", bucket)
         srole_bedrock = aws_iam.Role(
-            scope = self,
-            id = f'srole-bedrock',
-            assumed_by = aws_iam.ServicePrincipal('bedrock.amazonaws.com')
+            scope=self,
+            id="srole-bedrock",
+            assumed_by=aws_iam.ServicePrincipal("bedrock.amazonaws.com"),
         )
         srole_bedrock.grant_pass_role(lambda_function)
-
+
         lambda_function.role.add_managed_policy(
-            aws_iam.ManagedPolicy.from_aws_managed_policy_name('AmazonBedrockFullAccess')
+            aws_iam.ManagedPolicy.from_aws_managed_policy_name("AmazonBedrockFullAccess")
         )
         bucket.grant_read_write(lambda_function.role)
 
     def __create_lambda_function(self, function_name: str, bucket: aws_s3.Bucket):
-
-        lambda_role = aws_iam.Role(self, 'LambdaExecutionRole',
-            assumed_by=aws_iam.ServicePrincipal('lambda.amazonaws.com'),
-            managed_policies=[aws_iam.ManagedPolicy.from_aws_managed_policy_name('service-role/AWSLambdaBasicExecutionRole')]
+        lambda_role = aws_iam.Role(
+            self,
+            "LambdaExecutionRole",
+            assumed_by=aws_iam.ServicePrincipal("lambda.amazonaws.com"),
+            managed_policies=[
+                aws_iam.ManagedPolicy.from_aws_managed_policy_name(
+                    "service-role/AWSLambdaBasicExecutionRole"
+                )
+            ],
         )
-
+
         lambda_function = aws_lambda.DockerImageFunction(
-            scope=self,
-            id=function_name,
-            function_name=function_name,
-            code=aws_lambda.DockerImageCode.from_image_asset(directory=f"{Path('source/lambda').absolute()}"),
-            timeout=Duration.minutes(15),
-            memory_size=3000,
-            role=lambda_role,
-            environment={
-                'BUCKET_NAME': bucket.bucket_name
-            }
+            scope=self,
+            id=function_name,
+            function_name=function_name,
+            code=aws_lambda.DockerImageCode.from_image_asset(
+                directory=f"{Path('source/lambda').absolute()}"
+            ),
+            timeout=Duration.minutes(15),
+            memory_size=3000,
+            role=lambda_role,
+            environment={"BUCKET_NAME": bucket.bucket_name},
         )
         return lambda_function
diff --git a/sample_deployments/cdk_lambda/source/lambda/handler.py b/sample_deployments/cdk_lambda/source/lambda/handler.py
index 69e0005..e128e90 100644
--- a/sample_deployments/cdk_lambda/source/lambda/handler.py
+++ b/sample_deployments/cdk_lambda/source/lambda/handler.py
@@ -1,36 +1,33 @@
-from os import getenv
-from rhubarb import DocAnalysis
+from os import getenv
+
 import boto3
-class ProcessDocument():
+from rhubarb import DocAnalysis
+
+
+class ProcessDocument:
     def generateJSON(self, document):
         try:
             session = boto3.Session()
-            da = DocAnalysis(file_path=document,
-                boto3_session=session,
-                pages=[1])
-            prompt="I want to extract the employee name, employee SSN, employee address, \
+            da = DocAnalysis(file_path=document, boto3_session=session, pages=[1])
+            prompt = "I want to extract the employee name, employee SSN, employee address, \
                 date of birth and phone number from this document."
             resp = da.generate_schema(message=prompt)
-            response = da.run(message=prompt,
-                output_schema=resp['output']
-            )
-        except (
-            # handle bedrock or S3 error
-        ) as e:
-
+            response = da.run(message=prompt, output_schema=resp["output"])
+        except Exception:
+            # handle Bedrock or S3 errors by failing soft; the caller receives None
             return None
         return response
-
+
+
 def lambda_handler(event, context):
-    BUCKET_NAME = getenv('BUCKET_NAME')
+    BUCKET_NAME = getenv("BUCKET_NAME")
+
+    documentURL = f"s3://{BUCKET_NAME}/employee_enrollment.pdf"
 
-    documentURL = f's3://{BUCKET_NAME}/employee_enrollment.pdf'
-
     result = ProcessDocument().generateJSON(documentURL)
-    return result
\ No newline at end of file
+
+    return result
diff --git a/sample_deployments/cdk_lambda/tests/unit/test_sample_deployments_stack.py b/sample_deployments/cdk_lambda/tests/unit/test_sample_deployments_stack.py
index e8c260b..38e0043 100644
--- a/sample_deployments/cdk_lambda/tests/unit/test_sample_deployments_stack.py
+++ b/sample_deployments/cdk_lambda/tests/unit/test_sample_deployments_stack.py
@@ -1,14 +1,15 @@
 import aws_cdk as core
 import aws_cdk.assertions as assertions
-
 from sample_deployments.sample_deployments_stack import SampleDeploymentsStack
+
 # example tests. To run these tests, uncomment this file along with the example
 # resource in sample_deployments/sample_deployments_stack.py
 def test_sqs_queue_created():
     app = core.App()
     stack = SampleDeploymentsStack(app, "sample-deployments")
-    template = assertions.Template.from_stack(stack)
+    assertions.Template.from_stack(stack)
+
 #     template.has_resource_properties("AWS::SQS::Queue", {
 #         "VisibilityTimeout": 300
diff --git a/src/rhubarb/analyze.py b/src/rhubarb/analyze.py
index e7fe76a..8faa538 100644
--- a/src/rhubarb/analyze.py
+++ b/src/rhubarb/analyze.py
@@ -1,5 +1,5 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0 
+# SPDX-License-Identifier: Apache-2.0
 
 import logging
 from typing import Any, List, Optional, Generator
@@ -90,9 +90,7 @@ def validate_model(cls, values: dict) -> dict:
 
         if 0 in pages and len(pages) > 1:
             logger.error("If specific pages are provided, page number 0 is invalid.")
-            raise ValueError(
-                "If specific pages are provided, page number 0 is invalid."
-            )
+            raise ValueError("If specific pages are provided, page number 0 is invalid.")
 
         if len(pages) > 20:
             logger.error("Cannot process more than 20 pages at a time.")
@@ -101,9 +99,7 @@ def validate_model(cls, values: dict) -> dict:
         blocked_schemes = ["http://", "https://", "ftp://"]
         if any(file_path.startswith(scheme) for scheme in blocked_schemes):
             logger.error("file_path must be a local file system path or an s3:// path")
-            raise ValueError(
-                "file_path must be a local file system path or an s3:// path"
-            )
+            raise ValueError("file_path must be a local file system path or an s3:// path")
 
         s3_config = Config(
             retries={"max_attempts": 0, "mode": "standard"}, signature_version="s3v4"
         )
@@ -150,10 +146,11 @@ def run(
         Args:
         - `message` (`str`): The input message or prompt for the language model.
         - `output_schema` (`Optional[dict]`, optional): The output JSON schema for the language model response. Defaults to None.
-        """ 
+        """
         if (
             self.modelId == LanguageModels.CLAUDE_HAIKU_V1
             or self.modelId == LanguageModels.CLAUDE_SONNET_V1
+            or self.modelId == LanguageModels.CLAUDE_SONNET_V2
         ):
             a_msg = self._get_anthropic_prompt(
                 message=message,
@@ -185,6 +182,7 @@ def run_stream(
         if (
             self.modelId == LanguageModels.CLAUDE_HAIKU_V1
             or self.modelId == LanguageModels.CLAUDE_SONNET_V1
+            or self.modelId == LanguageModels.CLAUDE_SONNET_V2
         ):
             a_msg = self._get_anthropic_prompt(
                 message=message, sys_prompt=self.system_prompt, history=history
             )
@@ -209,6 +207,7 @@ def run_entity(self, message: Any, entities: List[Any]) -> Any:
         if (
             self.modelId == LanguageModels.CLAUDE_HAIKU_V1
             or self.modelId == LanguageModels.CLAUDE_SONNET_V1
+            or self.modelId == LanguageModels.CLAUDE_SONNET_V2
        ):
             sys_prompt = SystemPrompts(entities=entities).NERSysPrompt
             a_msg = self._get_anthropic_prompt(message=message, sys_prompt=sys_prompt)
@@ -220,9 +219,7 @@ def run_entity(self, message: Any, entities: List[Any]) -> Any:
             response = model_invoke.invoke_model_json()
         return response
 
-    def generate_schema(
-        self, message: str, assistive_rephrase: Optional[bool] = False
-    ) -> dict:
+    def generate_schema(self, message: str, assistive_rephrase: Optional[bool] = False) -> dict:
         """
         Invokes the specified language model with the given message to generate a JSON schema for a given document.
@@ -234,6 +231,7 @@ def generate_schema(
         if (
             self.modelId == LanguageModels.CLAUDE_HAIKU_V1
             or self.modelId == LanguageModels.CLAUDE_SONNET_V1
+            or self.modelId == LanguageModels.CLAUDE_SONNET_V2
         ):
             if assistive_rephrase:
                 sys_prompt = SystemPrompts().SchemaGenSysPromptWithRephrase
diff --git a/src/rhubarb/models.py b/src/rhubarb/models.py
index bad41f4..211fdeb 100644
--- a/src/rhubarb/models.py
+++ b/src/rhubarb/models.py
@@ -8,6 +8,7 @@ class LanguageModels(Enum):
     CLAUDE_OPUS_V1 = "anthropic.claude-3-opus-20240229-v1:0"
     CLAUDE_SONNET_V1 = "anthropic.claude-3-sonnet-20240229-v1:0"
     CLAUDE_HAIKU_V1 = "anthropic.claude-3-haiku-20240307-v1:0"
+    CLAUDE_SONNET_V2 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
 
 
 class EmbeddingModels(Enum):
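
For reference, a minimal usage sketch of the newly supported Claude 3.5 Sonnet model: the modelId keyword mirrors the field checked in src/rhubarb/analyze.py, while the bucket, document key, and prompt below are placeholders, not values from this change.

import boto3

from rhubarb import DocAnalysis
from rhubarb.models import LanguageModels

session = boto3.Session()

# Placeholder s3:// location; modelId is assumed to accept the enum member added in models.py.
da = DocAnalysis(
    file_path="s3://my-bucket/employee_enrollment.pdf",
    boto3_session=session,
    modelId=LanguageModels.CLAUDE_SONNET_V2,  # Claude 3.5 Sonnet
    pages=[1],
)
response = da.run(message="Extract the employee name and date of birth from this document.")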