Commit 2cc11e3

Merge pull request #9 from gowthamshankar99/main
Add support for Sonnet 3.5
anjanvb authored Sep 2, 2024
2 parents 793704a + af1cc8b commit 2cc11e3
Showing 7 changed files with 65 additions and 77 deletions.
2 changes: 1 addition & 1 deletion cookbooks/0-rhubarb-cookbook.ipynb
@@ -154,7 +154,7 @@
"source": [
"### Default Model\n",
"---\n",
"By default Rhubarb uses Claude Sonnet model, however you can also use Haiku or Opus (when available)."
"By default Rhubarb uses Claude Sonnet model, however you can also use Haiku, Sonnet 3.5 or Opus (when available)."
]
},
{
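For context, a minimal sketch of how the newly supported model could be selected, assuming DocAnalysis accepts the modelId field referenced in src/rhubarb/analyze.py below and that the LanguageModels enum is importable from rhubarb.models (import path assumed):

import boto3

from rhubarb import DocAnalysis
from rhubarb.models import LanguageModels  # import path assumed

session = boto3.Session()

# Pick Claude 3.5 Sonnet explicitly instead of the default Sonnet model.
da = DocAnalysis(
    file_path="s3://my-bucket/employee_enrollment.pdf",  # hypothetical bucket
    boto3_session=session,
    modelId=LanguageModels.CLAUDE_SONNET_V2,  # assumes modelId can be set at construction
    pages=[1],
)
resp = da.run(message="Extract the employee name and SSN from this document.")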
16 changes: 5 additions & 11 deletions sample_deployments/cdk_lambda/app.py
@@ -1,27 +1,21 @@
#!/usr/bin/env python3
import os

import aws_cdk as cdk

from infra.lambda_document_proccessing import SampleDeploymentsStack


app = cdk.App()
SampleDeploymentsStack(app, "SampleDeploymentsStack",
SampleDeploymentsStack(
app,
"SampleDeploymentsStack",
# If you don't specify 'env', this stack will be environment-agnostic.
# Account/Region-dependent features and context lookups will not work,
# but a single synthesized template can be deployed anywhere.

# Uncomment the next line to specialize this stack for the AWS Account
# and Region that are implied by the current CLI configuration.

#env=cdk.Environment(account=os.getenv('CDK_DEFAULT_ACCOUNT'), region=os.getenv('CDK_DEFAULT_REGION')),

# env=cdk.Environment(account=os.getenv('CDK_DEFAULT_ACCOUNT'), region=os.getenv('CDK_DEFAULT_REGION')),
# Uncomment the next line if you know exactly what Account and Region you
# want to deploy the stack to. */

#env=cdk.Environment(account='123456789012', region='us-east-1'),

# env=cdk.Environment(account='123456789012', region='us-east-1'),
# For more information, see https://docs.aws.amazon.com/cdk/latest/guide/environments.html
)

67 changes: 32 additions & 35 deletions sample_deployments/cdk_lambda/infra/lambda_document_proccessing.py
@@ -1,58 +1,55 @@
from aws_cdk import (
Duration,
RemovalPolicy,
Stack,
aws_lambda,
aws_iam,
aws_s3
)
from pathlib import Path

from aws_cdk import Stack, Duration, RemovalPolicy, aws_s3, aws_iam, aws_lambda
from constructs import Construct
from pathlib import Path

class SampleDeploymentsStack(Stack):

class SampleDeploymentsStack(Stack):
def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
super().__init__(scope, construct_id, **kwargs)

bucket = aws_s3.Bucket(self, "MyLambdaBucket",
removal_policy=RemovalPolicy.DESTROY,
auto_delete_objects=True
bucket = aws_s3.Bucket(
self, "MyLambdaBucket", removal_policy=RemovalPolicy.DESTROY, auto_delete_objects=True
)


lambda_function = self.__create_lambda_function("document-processing-lambda", bucket)
lambda_function = self.__create_lambda_function("document-processing-lambda", bucket)
srole_bedrock = aws_iam.Role(
scope = self,
id = f'srole-bedrock',
assumed_by = aws_iam.ServicePrincipal('bedrock.amazonaws.com')
scope=self,
id="srole-bedrock",
assumed_by=aws_iam.ServicePrincipal("bedrock.amazonaws.com"),
)

srole_bedrock.grant_pass_role(lambda_function)

lambda_function.role.add_managed_policy(
aws_iam.ManagedPolicy.from_aws_managed_policy_name('AmazonBedrockFullAccess')
aws_iam.ManagedPolicy.from_aws_managed_policy_name("AmazonBedrockFullAccess")
)

bucket.grant_read_write(lambda_function.role)

def __create_lambda_function(self, function_name: str, bucket: aws_s3.Bucket):

lambda_role = aws_iam.Role(self, 'LambdaExecutionRole',
assumed_by=aws_iam.ServicePrincipal('lambda.amazonaws.com'),
managed_policies=[aws_iam.ManagedPolicy.from_aws_managed_policy_name('service-role/AWSLambdaBasicExecutionRole')]
lambda_role = aws_iam.Role(
self,
"LambdaExecutionRole",
assumed_by=aws_iam.ServicePrincipal("lambda.amazonaws.com"),
managed_policies=[
aws_iam.ManagedPolicy.from_aws_managed_policy_name(
"service-role/AWSLambdaBasicExecutionRole"
)
],
)

lambda_function = aws_lambda.DockerImageFunction(
scope=self,
id=function_name,
function_name=function_name,
code=aws_lambda.DockerImageCode.from_image_asset(directory=f"{Path('source/lambda').absolute()}"),
timeout=Duration.minutes(15),
memory_size=3000,
role=lambda_role,
environment={
'BUCKET_NAME': bucket.bucket_name
}
scope=self,
id=function_name,
function_name=function_name,
code=aws_lambda.DockerImageCode.from_image_asset(
directory=f"{Path('source/lambda').absolute()}"
),
timeout=Duration.minutes(15),
memory_size=3000,
role=lambda_role,
environment={"BUCKET_NAME": bucket.bucket_name},
)

return lambda_function
31 changes: 14 additions & 17 deletions sample_deployments/cdk_lambda/source/lambda/handler.py
@@ -1,36 +1,33 @@
from os import getenv
from rhubarb import DocAnalysis
from os import getenv

import boto3

class ProcessDocument():
from rhubarb import DocAnalysis


class ProcessDocument:
def generateJSON(self, document):
try:
session = boto3.Session()

da = DocAnalysis(file_path=document,
boto3_session=session,
pages=[1])
prompt="I want to extract the employee name, employee SSN, employee address, \
da = DocAnalysis(file_path=document, boto3_session=session, pages=[1])
prompt = "I want to extract the employee name, employee SSN, employee address, \
date of birth and phone number from this document."
resp = da.generate_schema(message=prompt)
response = da.run(message=prompt,
output_schema=resp['output']
)
response = da.run(message=prompt, output_schema=resp["output"])
except (
# handle bedrock or S3 error
) as e:

):
return None

return response



def lambda_handler(event, context):
BUCKET_NAME = getenv('BUCKET_NAME')
BUCKET_NAME = getenv("BUCKET_NAME")

documentURL = f"s3://{BUCKET_NAME}/employee_enrollment.pdf"

documentURL = f's3://{BUCKET_NAME}/employee_enrollment.pdf'

result = ProcessDocument().generateJSON(documentURL)

return result
return result
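Once the stack above is deployed, a rough sketch of how the handler could be exercised end to end, assuming employee_enrollment.pdf has already been uploaded to the stack's bucket:

import json

import boto3

# Invoke the container Lambda created by the CDK stack above. The handler
# ignores the event payload and reads employee_enrollment.pdf from the
# bucket passed in via its BUCKET_NAME environment variable.
client = boto3.client("lambda")
resp = client.invoke(
    FunctionName="document-processing-lambda",
    Payload=json.dumps({}).encode("utf-8"),
)
print(json.loads(resp["Payload"].read()))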
@@ -1,14 +1,15 @@
import aws_cdk as core
import aws_cdk.assertions as assertions

from sample_deployments.sample_deployments_stack import SampleDeploymentsStack


# example tests. To run these tests, uncomment this file along with the example
# resource in sample_deployments/sample_deployments_stack.py
def test_sqs_queue_created():
app = core.App()
stack = SampleDeploymentsStack(app, "sample-deployments")
template = assertions.Template.from_stack(stack)
assertions.Template.from_stack(stack)


# template.has_resource_properties("AWS::SQS::Queue", {
# "VisibilityTimeout": 300
20 changes: 9 additions & 11 deletions src/rhubarb/analyze.py
@@ -1,5 +1,5 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any, List, Optional, Generator
@@ -90,9 +90,7 @@ def validate_model(cls, values: dict) -> dict:

if 0 in pages and len(pages) > 1:
logger.error("If specific pages are provided, page number 0 is invalid.")
raise ValueError(
"If specific pages are provided, page number 0 is invalid."
)
raise ValueError("If specific pages are provided, page number 0 is invalid.")

if len(pages) > 20:
logger.error("Cannot process more than 20 pages at a time.")
@@ -101,9 +99,7 @@ def validate_model(cls, values: dict) -> dict:
blocked_schemes = ["http://", "https://", "ftp://"]
if any(file_path.startswith(scheme) for scheme in blocked_schemes):
logger.error("file_path must be a local file system path or an s3:// path")
raise ValueError(
"file_path must be a local file system path or an s3:// path"
)
raise ValueError("file_path must be a local file system path or an s3:// path")

s3_config = Config(
retries={"max_attempts": 0, "mode": "standard"}, signature_version="s3v4"
@@ -150,10 +146,11 @@ def run(
Args:
- `message` (`str`): The input message or prompt for the language model.
- `output_schema` (`Optional[dict]`, optional): The output JSON schema for the language model response. Defaults to None.
"""
"""
if (
self.modelId == LanguageModels.CLAUDE_HAIKU_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
):
a_msg = self._get_anthropic_prompt(
message=message,
@@ -185,6 +182,7 @@ def run_stream(
if (
self.modelId == LanguageModels.CLAUDE_HAIKU_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
):
a_msg = self._get_anthropic_prompt(
message=message, sys_prompt=self.system_prompt, history=history
@@ -209,6 +207,7 @@ def run_entity(self, message: Any, entities: List[Any]) -> Any:
if (
self.modelId == LanguageModels.CLAUDE_HAIKU_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
):
sys_prompt = SystemPrompts(entities=entities).NERSysPrompt
a_msg = self._get_anthropic_prompt(message=message, sys_prompt=sys_prompt)
@@ -220,9 +219,7 @@ def run_entity(self, message: Any, entities: List[Any]) -> Any:
response = model_invoke.invoke_model_json()
return response

def generate_schema(
self, message: str, assistive_rephrase: Optional[bool] = False
) -> dict:
def generate_schema(self, message: str, assistive_rephrase: Optional[bool] = False) -> dict:
"""
Invokes the specified language model with the given message to genereate a JSON
schema for a given document.
@@ -234,6 +231,7 @@ def generate_schema(
if (
self.modelId == LanguageModels.CLAUDE_HAIKU_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
):
if assistive_rephrase:
sys_prompt = SystemPrompts().SchemaGenSysPromptWithRephrase
1 change: 1 addition & 0 deletions src/rhubarb/models.py
@@ -8,6 +8,7 @@ class LanguageModels(Enum):
CLAUDE_OPUS_V1 = "anthropic.claude-3-opus-20240229-v1:0"
CLAUDE_SONNET_V1 = "anthropic.claude-3-sonnet-20240229-v1:0"
CLAUDE_HAIKU_V1 = "anthropic.claude-3-haiku-20240307-v1:0"
CLAUDE_SONNET_V2 = "anthropic.claude-3-5-sonnet-20240620-v1:0"


class EmbeddingModels(Enum):
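As a sanity check independent of Rhubarb, the new enum value maps to a model ID that can be called through Bedrock directly; a minimal sketch using the Anthropic Messages format on Bedrock, assuming the account has model access in a region where Claude 3.5 Sonnet is offered (us-east-1 used here as an assumption):

import json

import boto3

bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")

# Same model ID string as LanguageModels.CLAUDE_SONNET_V2 above.
body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 256,
    "messages": [{"role": "user", "content": [{"type": "text", "text": "Say hello."}]}],
}
resp = bedrock.invoke_model(
    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
    body=json.dumps(body),
)
print(json.loads(resp["body"].read())["content"][0]["text"])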
