Skip to content

Commit

Permalink
Add pii redaction
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Dec 10, 2023
1 parent 605af05 commit 0a765b4
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 37 deletions.
12 changes: 8 additions & 4 deletions examples/amazon_comprehend_pii_moderation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ func main() {

client := comprehend.NewFromConfig(cfg)

moderationChain := moderation.NewAmazonComprehendPII(client)
moderationChain := moderation.NewAmazonComprehendPII(client, func(o *moderation.AmazonComprehendPIIOptions) {
o.Redact = true
})

result, err := golc.SimpleCall(context.Background(), moderationChain, "My Name is Alfred E. Neuman")
input := "My Name is Alfred E. Neuman"

result, err := golc.SimpleCall(context.Background(), moderationChain, input)
if err != nil {
log.Fatal(err) // pii content found
log.Fatal(err)
}

fmt.Println(result)
fmt.Println(input, " -> ", result)
}
108 changes: 83 additions & 25 deletions moderation/amazon_comprehend_pii.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package moderation
import (
"context"
"errors"
"sort"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/comprehend"
Expand All @@ -17,8 +18,13 @@ import (
type AmazonComprehendPIIClient interface {
// ContainsPiiEntities reports which PII entity labels apply to the input text,
// each with a confidence score. It is used when redaction is disabled.
ContainsPiiEntities(ctx context.Context, params *comprehend.ContainsPiiEntitiesInput, optFns ...func(*comprehend.Options)) (*comprehend.ContainsPiiEntitiesOutput, error)
// DetectPiiEntities locates the PII entities in the input text and returns,
// for each one, its type, confidence score, and begin/end offsets. It is used
// when redaction is enabled.
DetectPiiEntities(ctx context.Context, params *comprehend.DetectPiiEntitiesInput, optFns ...func(*comprehend.Options)) (*comprehend.DetectPiiEntitiesOutput, error)
}

// RedactFunc is a function type that defines how a single PII entity found in
// the text is redacted.
//
// It receives the complete text, the configured mask character, the type of the
// detected entity, and the begin/end offsets of that entity as reported by
// Amazon Comprehend. It must return the complete text with the entity redacted;
// the returned value replaces the text for any subsequent redaction.
type RedactFunc func(ctx context.Context, text string, maskMarker rune, entityType string, offsetBegin, offsetEnd int32) string

// AmazonComprehendPIIOptions contains options for the Amazon Comprehend PII moderation.
type AmazonComprehendPIIOptions struct {
// CallbackOptions embeds CallbackOptions to include the verbosity setting and callbacks.
Expand All @@ -33,6 +39,12 @@ type AmazonComprehendPIIOptions struct {
Labels []string
// Threshold is the threshold for determining if PII content is found.
Threshold float32
// Redact enables redaction of detected PII entities.
Redact bool
// MaskMarker is the mask character passed to RedactFunc when redaction is enabled.
MaskMarker rune
// RedactFunc defines how to redact PII entities found in the text.
RedactFunc RedactFunc
}

// AmazonComprehendPII is a struct representing the Amazon Comprehend PII moderation functionality.
Expand All @@ -51,6 +63,14 @@ func NewAmazonComprehendPII(client AmazonComprehendPIIClient, optFns ...func(o *
OutputKey: "output",
LanguageCode: "en",
Threshold: 0.8,
Redact: false,
MaskMarker: '*',
RedactFunc: func(ctx context.Context, text string, maskMarker rune, entityType string, offsetBegin, offsetEnd int32) string {
maskLength := offsetEnd - offsetBegin
maskedPart := strings.Repeat(string(maskMarker), int(maskLength))

return text[:offsetBegin] + maskedPart + text[offsetEnd:]
},
}

for _, fn := range optFns {
Expand Down Expand Up @@ -85,33 +105,11 @@ func (c *AmazonComprehendPII) Call(ctx context.Context, inputs schema.ChainValue
return nil, cbErr
}

output, err := c.client.ContainsPiiEntities(ctx, &comprehend.ContainsPiiEntitiesInput{
Text: aws.String(text),
LanguageCode: types.LanguageCode(c.opts.LanguageCode),
})
if err != nil {
return nil, err
if !c.opts.Redact {
return c.containsPII(ctx, text)
}

if len(c.opts.Labels) == 0 {
for _, label := range output.Labels {
if aws.ToFloat32(label.Score) >= c.opts.Threshold {
return nil, errors.New("pii content found")
}
}
} else {
for _, label := range output.Labels {
if util.Contains(c.opts.Labels, string(label.Name)) {
if aws.ToFloat32(label.Score) >= c.opts.Threshold {
return nil, errors.New("pii content found")
}
}
}
}

return schema.ChainValues{
c.opts.OutputKey: text,
}, nil
return c.detectPII(ctx, text)
}

// Memory returns the memory associated with the chain.
Expand Down Expand Up @@ -143,3 +141,63 @@ func (c *AmazonComprehendPII) InputKeys() []string {
func (c *AmazonComprehendPII) OutputKeys() []string {
return []string{c.opts.OutputKey}
}

// containsPII asks Amazon Comprehend which PII labels apply to text and fails
// with "pii content found" as soon as one relevant label reaches the configured
// threshold. A label is relevant when no Labels filter is configured, or when
// its name is listed in the filter. If nothing qualifies, the unmodified text
// is returned under the configured output key.
func (c *AmazonComprehendPII) containsPII(ctx context.Context, text string) (schema.ChainValues, error) {
	request := &comprehend.ContainsPiiEntitiesInput{
		Text:         aws.String(text),
		LanguageCode: types.LanguageCode(c.opts.LanguageCode),
	}

	resp, err := c.client.ContainsPiiEntities(ctx, request)
	if err != nil {
		return nil, err
	}

	// An empty filter means every label returned by the service is considered.
	filterByLabel := len(c.opts.Labels) > 0

	for _, found := range resp.Labels {
		if filterByLabel && !util.Contains(c.opts.Labels, string(found.Name)) {
			continue
		}

		if aws.ToFloat32(found.Score) >= c.opts.Threshold {
			return nil, errors.New("pii content found")
		}
	}

	result := schema.ChainValues{c.opts.OutputKey: text}

	return result, nil
}

// detectPII asks Amazon Comprehend to locate PII entities in text, redacts every
// relevant entity through the configured RedactFunc, and returns the redacted
// text under the configured output key.
//
// An entity is relevant when its score reaches the configured threshold and,
// if a Labels filter is configured, its type is listed in the filter.
//
// Redactions are applied from the highest begin offset to the lowest. Each
// RedactFunc call receives the text produced by the previous call, so working
// backwards keeps the offsets of the not-yet-processed entities valid even when
// a redaction changes the length of the text (for example a custom RedactFunc
// that substitutes a placeholder, or a multi-byte MaskMarker).
//
// NOTE(review): the offsets are forwarded to RedactFunc exactly as reported by
// the service; confirm how they map to Go string indices for non-ASCII input.
func (c *AmazonComprehendPII) detectPII(ctx context.Context, text string) (schema.ChainValues, error) {
	output, err := c.client.DetectPiiEntities(ctx, &comprehend.DetectPiiEntitiesInput{
		Text:         aws.String(text),
		LanguageCode: types.LanguageCode(c.opts.LanguageCode),
	})
	if err != nil {
		return nil, err
	}

	// An empty filter means every entity type returned by the service is considered.
	filterByLabel := len(c.opts.Labels) > 0

	// Collect into a fresh slice so sorting does not reorder the service response.
	matches := make([]types.PiiEntity, 0, len(output.Entities))

	for _, entity := range output.Entities {
		if filterByLabel && !util.Contains(c.opts.Labels, string(entity.Type)) {
			continue
		}

		if aws.ToFloat32(entity.Score) >= c.opts.Threshold {
			matches = append(matches, entity)
		}
	}

	// Process right-to-left so earlier offsets are unaffected by later edits.
	sort.SliceStable(matches, func(i, j int) bool {
		return aws.ToInt32(matches[i].BeginOffset) > aws.ToInt32(matches[j].BeginOffset)
	})

	for _, entity := range matches {
		text = c.opts.RedactFunc(ctx, text, c.opts.MaskMarker, string(entity.Type), aws.ToInt32(entity.BeginOffset), aws.ToInt32(entity.EndOffset))
	}

	return schema.ChainValues{
		c.opts.OutputKey: text,
	}, nil
}
42 changes: 34 additions & 8 deletions moderation/amazon_comprehend_pii_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package moderation

import (
"context"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -16,17 +17,30 @@ func TestAmazonComprehendPII(t *testing.T) {
testCases := []struct {
name string
inputText string
redact bool
expectedError string
expectedText string
}{
{
name: "Moderation Passed",
inputText: "nopii",
inputText: "harmless content",
redact: false,
expectedError: "",
expectedText: "harmless content",
},
{
name: "Moderation Failed",
inputText: "pii",
redact: false,
expectedError: "pii content found",
expectedText: "",
},
{
name: "Redacted",
inputText: "hello pii",
redact: true,
expectedError: "",
expectedText: "hello ***",
},
}

Expand All @@ -36,18 +50,25 @@ func TestAmazonComprehendPII(t *testing.T) {
ctx := context.Background()

score := float32(0.1)
if tc.inputText == "pii" {
if strings.Contains(tc.inputText, "pii") {
score = 0.9
}

fakeClient := &fakeAmazonComprehendPIIClient{
response: &comprehend.ContainsPiiEntitiesOutput{
containsResponse: &comprehend.ContainsPiiEntitiesOutput{
Labels: []types.EntityLabel{
{Name: types.PiiEntityTypeName, Score: aws.Float32(score)},
},
},
detectResponse: &comprehend.DetectPiiEntitiesOutput{
Entities: []types.PiiEntity{
{Type: types.PiiEntityTypeName, Score: aws.Float32(score), BeginOffset: aws.Int32(6), EndOffset: aws.Int32(9)},
},
},
}
chain := NewAmazonComprehendPII(fakeClient)
chain := NewAmazonComprehendPII(fakeClient, func(o *AmazonComprehendPIIOptions) {
o.Redact = tc.redact
})

// Test
inputs := schema.ChainValues{
Expand All @@ -59,7 +80,7 @@ func TestAmazonComprehendPII(t *testing.T) {
if tc.expectedError == "" {
assert.NoError(t, err)
assert.NotNil(t, outputs)
assert.Equal(t, tc.inputText, outputs["output"])
assert.Equal(t, tc.expectedText, outputs["output"])
} else {
assert.Nil(t, outputs)
assert.Error(t, err)
Expand All @@ -70,10 +91,15 @@ func TestAmazonComprehendPII(t *testing.T) {
}

type fakeAmazonComprehendPIIClient struct {
response *comprehend.ContainsPiiEntitiesOutput
err error
containsResponse *comprehend.ContainsPiiEntitiesOutput
detectResponse *comprehend.DetectPiiEntitiesOutput
err error
}

func (c *fakeAmazonComprehendPIIClient) ContainsPiiEntities(ctx context.Context, params *comprehend.ContainsPiiEntitiesInput, optFns ...func(*comprehend.Options)) (*comprehend.ContainsPiiEntitiesOutput, error) {
return c.response, c.err
return c.containsResponse, c.err
}

// DetectPiiEntities ignores its arguments and returns the canned detect
// response together with the error configured on the fake client.
func (c *fakeAmazonComprehendPIIClient) DetectPiiEntities(ctx context.Context, params *comprehend.DetectPiiEntitiesInput, optFns ...func(*comprehend.Options)) (*comprehend.DetectPiiEntitiesOutput, error) {
	canned, cannedErr := c.detectResponse, c.err

	return canned, cannedErr
}

0 comments on commit 0a765b4

Please sign in to comment.