From 402943709f0e3e36827e70f648cb8afe02f3d860 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 11 Sep 2024 21:01:16 +0000 Subject: [PATCH 1/2] cohere: introduce cohere processors support for chat and embed processors --- .../pages/processors/cohere_chat.adoc | 594 ++++++++++++++++++ .../pages/processors/cohere_embeddings.adoc | 151 +++++ go.mod | 1 + go.sum | 2 + internal/impl/cohere/base_processor.go | 65 ++ internal/impl/cohere/chat_processor.go | 356 +++++++++++ internal/impl/cohere/embeddings_processor.go | 162 +++++ internal/impl/cohere/json_schema_provider.go | 84 +++ internal/plugins/info.csv | 2 + public/components/all/package.go | 1 + public/components/cloud/package.go | 1 + public/components/cohere/package.go | 14 + 12 files changed, 1433 insertions(+) create mode 100644 docs/modules/components/pages/processors/cohere_chat.adoc create mode 100644 docs/modules/components/pages/processors/cohere_embeddings.adoc create mode 100644 internal/impl/cohere/base_processor.go create mode 100644 internal/impl/cohere/chat_processor.go create mode 100644 internal/impl/cohere/embeddings_processor.go create mode 100644 internal/impl/cohere/json_schema_provider.go create mode 100644 public/components/cohere/package.go diff --git a/docs/modules/components/pages/processors/cohere_chat.adoc b/docs/modules/components/pages/processors/cohere_chat.adoc new file mode 100644 index 000000000..e0675deb8 --- /dev/null +++ b/docs/modules/components/pages/processors/cohere_chat.adoc @@ -0,0 +1,594 @@ += cohere_chat +:type: processor +:status: experimental +:categories: ["AI"] + + + +//// + THIS FILE IS AUTOGENERATED! + + To make changes, edit the corresponding source file under: + + https://github.com/redpanda-data/connect/tree/main/internal/impl/. + + And: + + https://github.com/redpanda-data/connect/tree/main/cmd/tools/docs_gen/templates/plugin.adoc.tmpl +//// + +// © 2024 Redpanda Data Inc. + + +component_type_dropdown::[] + + +Generates responses to messages in a chat conversation, using the Cohere API. + +Introduced in version 4.37.0. + + +[tabs] +====== +Common:: ++ +-- + +```yml +# Common config fields, showing default values +label: "" +cohere_chat: + base_url: https://api.cohere.com + auth_token: "" # No default (required) + model: command-r-plus # No default (required) + prompt: "" # No default (optional) + system_prompt: "" # No default (optional) + max_tokens: 0 # No default (optional) + temperature: 0 # No default (optional) + response_format: text + json_schema: "" # No default (optional) +``` + +-- +Advanced:: ++ +-- + +```yml +# All config fields, showing default values +label: "" +cohere_chat: + base_url: https://api.cohere.com + auth_token: "" # No default (required) + model: command-r-plus # No default (required) + prompt: "" # No default (optional) + system_prompt: "" # No default (optional) + max_tokens: 0 # No default (optional) + temperature: 0 # No default (optional) + response_format: text + json_schema: "" # No default (optional) + schema_registry: + url: "" # No default (required) + subject: "" # No default (required) + refresh_interval: "" # No default (optional) + tls: + skip_cert_verify: false + enable_renegotiation: false + root_cas: "" + root_cas_file: "" + client_certs: [] + oauth: + enabled: false + consumer_key: "" + consumer_secret: "" + access_token: "" + access_token_secret: "" + basic_auth: + enabled: false + username: "" + password: "" + jwt: + enabled: false + private_key_file: "" + signing_method: "" + claims: {} + headers: {} + top_p: 0 # No default (optional) + frequency_penalty: 0 # No default (optional) + presence_penalty: 0 # No default (optional) + seed: 0 # No default (optional) + stop: [] # No default (optional) +``` + +-- +====== + +This processor sends the contents of user prompts to the Cohere API, which generates responses. By default, the processor submits the entire payload of each message as a string, unless you use the `prompt` configuration field to customize it. + +To learn more about chat completion, see the https://docs.cohere.com/docs/chat-api[Cohere API documentation^]. + +== Fields + +=== `base_url` + +The base URL to use for API requests. + + +*Type*: `string` + +*Default*: `"https://api.cohere.com"` + +=== `auth_token` + +The auth token for the Cohere API. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + + +=== `model` + +The name of the Cohere model to use. + + +*Type*: `string` + + +```yml +# Examples + +model: command-r-plus + +model: command-r + +model: command + +model: command-light +``` + +=== `prompt` + +The user prompt you want to generate a response for. By default, the processor submits the entire payload as a string. +This field supports xref:configuration:interpolation.adoc#bloblang-queries[interpolation functions]. + + +*Type*: `string` + + +=== `system_prompt` + +The system prompt to submit along with the user prompt. +This field supports xref:configuration:interpolation.adoc#bloblang-queries[interpolation functions]. + + +*Type*: `string` + + +=== `max_tokens` + +The maximum number of tokens that can be generated in the chat completion. + + +*Type*: `int` + + +=== `temperature` + +What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + +We generally recommend altering this or top_p but not both. + + +*Type*: `float` + + +=== `response_format` + +Specify the model's output format. If `json_schema` is specified, then additionally a `json_schema` or `schema_registry` must be configured. + + +*Type*: `string` + +*Default*: `"text"` + +Options: +`text` +, `json` +, `json_schema` +. + +=== `json_schema` + +The JSON schema to use when responding in `json_schema` format. To learn more about what JSON schema is supported see the https://docs.cohere.com/docs/structured-outputs-json[Cohere documentation^]. + + +*Type*: `string` + + +=== `schema_registry` + +The schema registry to dynamically load schemas from when responding in `json_schema` format. Schemas themselves must be in JSON format. To learn more about what JSON schema is supported see the https://docs.cohere.com/docs/structured-outputs-json[Cohere documentation^]. + + +*Type*: `object` + + +=== `schema_registry.url` + +The base URL of the schema registry service. + + +*Type*: `string` + + +=== `schema_registry.subject` + +The subject name to fetch the schema for. + + +*Type*: `string` + + +=== `schema_registry.refresh_interval` + +The refresh rate for getting the latest schema. If not specified the schema does not refresh. + + +*Type*: `string` + + +=== `schema_registry.tls` + +Custom TLS settings can be used to override system defaults. + + +*Type*: `object` + + +=== `schema_registry.tls.skip_cert_verify` + +Whether to skip server side certificate verification. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.tls.enable_renegotiation` + +Whether to allow the remote server to repeatedly request renegotiation. Enable this option if you're seeing the error message `local error: tls: no renegotiation`. + + +*Type*: `bool` + +*Default*: `false` +Requires version 3.45.0 or newer + +=== `schema_registry.tls.root_cas` + +An optional root certificate authority to use. This is a string, representing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +root_cas: |- + -----BEGIN CERTIFICATE----- + ... + -----END CERTIFICATE----- +``` + +=== `schema_registry.tls.root_cas_file` + +An optional path of a root certificate authority file to use. This is a file, often with a .pem extension, containing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +root_cas_file: ./root_cas.pem +``` + +=== `schema_registry.tls.client_certs` + +A list of client certificates to use. For each certificate either the fields `cert` and `key`, or `cert_file` and `key_file` should be specified, but not both. + + +*Type*: `array` + +*Default*: `[]` + +```yml +# Examples + +client_certs: + - cert: foo + key: bar + +client_certs: + - cert_file: ./example.pem + key_file: ./example.key +``` + +=== `schema_registry.tls.client_certs[].cert` + +A plain text certificate to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].key` + +A plain text certificate key to use. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].cert_file` + +The path of a certificate to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].key_file` + +The path of a certificate key to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].password` + +A plain text password for when the private key is password encrypted in PKCS#1 or PKCS#8 format. The obsolete `pbeWithMD5AndDES-CBC` algorithm is not supported for the PKCS#8 format. + +Because the obsolete pbeWithMD5AndDES-CBC algorithm does not authenticate the ciphertext, it is vulnerable to padding oracle attacks that can let an attacker recover the plaintext. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +password: foo + +password: ${KEY_PASSWORD} +``` + +=== `schema_registry.oauth` + +Allows you to specify open authentication via OAuth version 1. + + +*Type*: `object` + + +=== `schema_registry.oauth.enabled` + +Whether to use OAuth version 1 in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.oauth.consumer_key` + +A value used to identify the client to the service provider. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.consumer_secret` + +A secret used to establish ownership of the consumer key. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.access_token` + +A value used to gain access to the protected resources on behalf of the user. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.access_token_secret` + +A secret provided in order to establish ownership of a given access token. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.basic_auth` + +Allows you to specify basic authentication. + + +*Type*: `object` + + +=== `schema_registry.basic_auth.enabled` + +Whether to use basic authentication in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.basic_auth.username` + +A username to authenticate as. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.basic_auth.password` + +A password to authenticate with. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt` + +BETA: Allows you to specify JWT authentication. + + +*Type*: `object` + + +=== `schema_registry.jwt.enabled` + +Whether to use JWT authentication in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.jwt.private_key_file` + +A file with the PEM encoded via PKCS1 or PKCS8 as private key. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt.signing_method` + +A method used to sign the token such as RS256, RS384, RS512 or EdDSA. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt.claims` + +A value used to identify the claims that issued the JWT. + + +*Type*: `object` + +*Default*: `{}` + +=== `schema_registry.jwt.headers` + +Add optional key/value headers to the JWT. + + +*Type*: `object` + +*Default*: `{}` + +=== `top_p` + +An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + +We generally recommend altering this or temperature but not both. + + +*Type*: `float` + + +=== `frequency_penalty` + +Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + + +*Type*: `float` + + +=== `presence_penalty` + +Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + + +*Type*: `float` + + +=== `seed` + +If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed. + + +*Type*: `int` + + +=== `stop` + +Up to 4 sequences where the API will stop generating further tokens. + + +*Type*: `array` + + + diff --git a/docs/modules/components/pages/processors/cohere_embeddings.adoc b/docs/modules/components/pages/processors/cohere_embeddings.adoc new file mode 100644 index 000000000..247e33bf9 --- /dev/null +++ b/docs/modules/components/pages/processors/cohere_embeddings.adoc @@ -0,0 +1,151 @@ += cohere_embeddings +:type: processor +:status: experimental +:categories: ["AI"] + + + +//// + THIS FILE IS AUTOGENERATED! + + To make changes, edit the corresponding source file under: + + https://github.com/redpanda-data/connect/tree/main/internal/impl/. + + And: + + https://github.com/redpanda-data/connect/tree/main/cmd/tools/docs_gen/templates/plugin.adoc.tmpl +//// + +// © 2024 Redpanda Data Inc. + + +component_type_dropdown::[] + + +Generates vector embeddings to represent input text, using the Cohere API. + +Introduced in version 4.37.0. + +```yml +# Config fields, showing default values +label: "" +cohere_embeddings: + base_url: https://api.cohere.com + auth_token: "" # No default (required) + model: embed-english-v3.0 # No default (required) + text_mapping: "" # No default (optional) + dimensions: search_document +``` + +This processor sends text strings to the Cohere API, which generates vector embeddings. By default, the processor submits the entire payload of each message as a string, unless you use the `text_mapping` configuration field to customize it. + +To learn more about vector embeddings, see the https://docs.cohere.com/docs/embeddings[Cohere API documentation^]. + +== Examples + +[tabs] +====== +Store embedding vectors in Qdrant:: ++ +-- + +Compute embeddings for some generated data and store it within xrefs:component:outputs/qdrant.adoc[Qdrant] + +```yamlinput: + generate: + interval: 1s + mapping: | + root = {"text": fake("paragraph")} +pipeline: + processors: + - cohere_embeddings: + model: embed-english-v3 + auth_token: "${COHERE_AUTH_TOKEN}" + text_mapping: "root = this.text" +output: + qdrant: + grpc_host: localhost:6334 + collection_name: "example_collection" + id: "root = uuid_v4()" + vector_mapping: "root = this"``` + +-- +====== + +== Fields + +=== `base_url` + +The base URL to use for API requests. + + +*Type*: `string` + +*Default*: `"https://api.cohere.com"` + +=== `auth_token` + +The auth token for the Cohere API. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + + +=== `model` + +The name of the Cohere model to use. + + +*Type*: `string` + + +```yml +# Examples + +model: embed-english-v3.0 + +model: embed-english-light-v3.0 + +model: embed-multilingual-v3.0 + +model: embed-multilingual-light-v3.0 +``` + +=== `text_mapping` + +The text you want to generate a vector embedding for. By default, the processor submits the entire payload as a string. + + +*Type*: `string` + + +=== `dimensions` + +Specifies the type of input passed to the model. + + +*Type*: `string` + +*Default*: `"search_document"` + +|=== +| Option | Summary + +| `classification` +| Used for embeddings passed through a text classifier. +| `clustering` +| Used for the embeddings run through a clustering algorithm. +| `search_document` +| Used for embeddings stored in a vector database for search use-cases. +| `search_query` +| Used for embeddings of search queries run against a vector DB to find relevant documents. + +|=== + + diff --git a/go.mod b/go.mod index fd1ef8482..a58468b51 100644 --- a/go.mod +++ b/go.mod @@ -208,6 +208,7 @@ require ( github.com/bufbuild/protocompile v0.10.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cockroachdb/apd/v3 v3.2.1 // indirect + github.com/cohere-ai/cohere-go/v2 v2.11.0 github.com/containerd/containerd v1.7.18 // indirect github.com/containerd/continuity v0.4.3 // indirect github.com/containerd/errdefs v0.1.0 // indirect diff --git a/go.sum b/go.sum index 5f10ecb05..d8a2bd902 100644 --- a/go.sum +++ b/go.sum @@ -338,6 +338,8 @@ github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= +github.com/cohere-ai/cohere-go/v2 v2.11.0 h1:9Tn+v3dnKGKH1RHrwpi81/aLbz9OottTu5+uuW/qgz0= +github.com/cohere-ai/cohere-go/v2 v2.11.0/go.mod h1:MuiJkCxlR18BDV2qQPbz2Yb/OCVphT1y6nD2zYaKeR0= github.com/colinmarc/hdfs v1.1.3 h1:662salalXLFmp+ctD+x0aG+xOg62lnVnOJHksXYpFBw= github.com/colinmarc/hdfs v1.1.3/go.mod h1:0DumPviB681UcSuJErAbDIOx6SIaJWj463TymfZG02I= github.com/colinmarc/hdfs/v2 v2.1.1/go.mod h1:M3x+k8UKKmxtFu++uAZ0OtDU8jR3jnaZIAc6yK4Ue0c= diff --git a/internal/impl/cohere/base_processor.go b/internal/impl/cohere/base_processor.go new file mode 100644 index 000000000..46419a6d0 --- /dev/null +++ b/internal/impl/cohere/base_processor.go @@ -0,0 +1,65 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package cohere + +import ( + "context" + + cohere "github.com/cohere-ai/cohere-go/v2/client" + "github.com/redpanda-data/benthos/v4/public/service" +) + +const ( + cpFieldBaseURL = "base_url" + cpFieldAuthToken = "auth_token" + cpFieldModel = "model" +) + +func baseConfigFieldsWithModels(modelExamples ...any) []*service.ConfigField { + return []*service.ConfigField{ + service.NewStringField(cpFieldBaseURL). + Description("The base URL to use for API requests."). + Default("https://api.cohere.com"), + service.NewStringField(cpFieldAuthToken). + Secret(). + Description("The auth token for the Cohere API."), + service.NewStringField(cpFieldModel). + Description("The name of the Cohere model to use."). + Examples(modelExamples...), + } +} + +type baseProcessor struct { + client *cohere.Client + model string +} + +func (b *baseProcessor) Close(ctx context.Context) error { + return nil +} + +func newBaseProcessor(conf *service.ParsedConfig) (*baseProcessor, error) { + bu, err := conf.FieldString(cpFieldBaseURL) + if err != nil { + return nil, err + } + k, err := conf.FieldString(cpFieldAuthToken) + if err != nil { + return nil, err + } + c := cohere.NewClient( + cohere.WithBaseURL(bu), + cohere.WithToken(k), + ) + m, err := conf.FieldString(cpFieldModel) + if err != nil { + return nil, err + } + return &baseProcessor{c, m}, nil +} diff --git a/internal/impl/cohere/chat_processor.go b/internal/impl/cohere/chat_processor.go new file mode 100644 index 000000000..dd0a13cb1 --- /dev/null +++ b/internal/impl/cohere/chat_processor.go @@ -0,0 +1,356 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package cohere + +import ( + "context" + "fmt" + "math" + "slices" + "time" + + cohere "github.com/cohere-ai/cohere-go/v2" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" +) + +const ( + ccpFieldUserPrompt = "prompt" + ccpFieldSystemPrompt = "system_prompt" + ccpFieldMaxTokens = "max_tokens" + ccpFieldTemp = "temperature" + ccpFieldTopP = "top_p" + ccpFieldSeed = "seed" + ccpFieldStop = "stop" + ccpFieldPresencePenalty = "presence_penalty" + ccpFieldFrequencyPenalty = "frequency_penalty" + ccpFieldResponseFormat = "response_format" + // JSON schema fields + ccpFieldJSONSchema = "json_schema" + // Schema registry fields + ccpFieldSchemaRegistry = "schema_registry" + ccpFieldSchemaRegistrySubject = "subject" + ccpFieldSchemaRegistryRefreshInterval = "refresh_interval" + ccpFieldSchemaRegistryURL = "url" + ccpFieldSchemaRegistryTLS = "tls" +) + +func init() { + err := service.RegisterProcessor( + "cohere_chat", + chatProcessorConfig(), + makeChatProcessor, + ) + if err != nil { + panic(err) + } +} + +func chatProcessorConfig() *service.ConfigSpec { + return service.NewConfigSpec(). + Categories("AI"). + Summary("Generates responses to messages in a chat conversation, using the Cohere API."). + Description(` +This processor sends the contents of user prompts to the Cohere API, which generates responses. By default, the processor submits the entire payload of each message as a string, unless you use the `+"`"+ccpFieldUserPrompt+"`"+` configuration field to customize it. + +To learn more about chat completion, see the https://docs.cohere.com/docs/chat-api[Cohere API documentation^].`). + Version("4.37.0"). + Fields( + baseConfigFieldsWithModels( + "command-r-plus", + "command-r", + "command", + "command-light", + )..., + ). + Fields( + service.NewInterpolatedStringField(ccpFieldUserPrompt). + Description("The user prompt you want to generate a response for. By default, the processor submits the entire payload as a string."). + Optional(), + service.NewInterpolatedStringField(ccpFieldSystemPrompt). + Description("The system prompt to submit along with the user prompt."). + Optional(), + service.NewIntField(ccpFieldMaxTokens). + Optional(). + Description("The maximum number of tokens that can be generated in the chat completion."), + service.NewFloatField(ccpFieldTemp). + Optional(). + Description(`What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + +We generally recommend altering this or top_p but not both.`). + LintRule(`root = if this > 2 || this < 0 { [ "field must be between 0 and 2" ] }`), + service.NewStringEnumField(ccpFieldResponseFormat, "text", "json", "json_schema"). + Default("text"). + Description("Specify the model's output format. If `json_schema` is specified, then additionally a `json_schema` or `schema_registry` must be configured."), + service.NewStringField(ccpFieldJSONSchema). + Optional(). + Description("The JSON schema to use when responding in `json_schema` format. To learn more about what JSON schema is supported see the https://docs.cohere.com/docs/structured-outputs-json[Cohere documentation^]."), + service.NewObjectField( + ccpFieldSchemaRegistry, + slices.Concat( + []*service.ConfigField{ + service.NewURLField(ccpFieldSchemaRegistryURL).Description("The base URL of the schema registry service."), + service.NewStringField(ccpFieldSchemaRegistrySubject). + Description("The subject name to fetch the schema for."), + service.NewDurationField(ccpFieldSchemaRegistryRefreshInterval). + Optional(). + Description("The refresh rate for getting the latest schema. If not specified the schema does not refresh."), + service.NewTLSField(ccpFieldSchemaRegistryTLS), + }, + service.NewHTTPRequestAuthSignerFields(), + )..., + ). + Description("The schema registry to dynamically load schemas from when responding in `json_schema` format. Schemas themselves must be in JSON format. To learn more about what JSON schema is supported see the https://docs.cohere.com/docs/structured-outputs-json[Cohere documentation^]."). + Optional(). + Advanced(), + service.NewFloatField(ccpFieldTopP). + Optional(). + Advanced(). + Description(`An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + +We generally recommend altering this or temperature but not both.`). + LintRule(`root = if this > 1 || this < 0 { [ "field must be between 0 and 1" ] }`), + service.NewFloatField(ccpFieldFrequencyPenalty). + Optional(). + Advanced(). + Description("Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."). + LintRule(`root = if this > 2 || this < -2 { [ "field must be less than 2 and greater than -2" ] }`), + service.NewFloatField(ccpFieldPresencePenalty). + Optional(). + Advanced(). + Description("Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics."). + LintRule(`root = if this > 2 || this < -2 { [ "field must be less than 2 and greater than -2" ] }`), + service.NewIntField(ccpFieldSeed). + Advanced(). + Optional(). + Description("If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed."), + service.NewStringListField(ccpFieldStop). + Optional(). + Advanced(). + Description("Up to 4 sequences where the API will stop generating further tokens."), + ).LintRule(` + root = match { + this.exists("` + ccpFieldJSONSchema + `") && this.exists("` + ccpFieldSchemaRegistry + `") => ["cannot set both ` + "`" + ccpFieldJSONSchema + "`" + ` and ` + "`" + ccpFieldSchemaRegistry + "`" + `"] + this.response_format == "json_schema" && !this.exists("` + ccpFieldJSONSchema + `") && !this.exists("` + ccpFieldSchemaRegistry + `") => ["schema must be specified using either ` + "`" + ccpFieldJSONSchema + "`" + ` or ` + "`" + ccpFieldSchemaRegistry + "`" + `"] + } + `) +} + +func makeChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { + b, err := newBaseProcessor(conf) + if err != nil { + return nil, err + } + var up *service.InterpolatedString + if conf.Contains(ccpFieldUserPrompt) { + up, err = conf.FieldInterpolatedString(ccpFieldUserPrompt) + if err != nil { + return nil, err + } + } + var sp *service.InterpolatedString + if conf.Contains(ccpFieldSystemPrompt) { + sp, err = conf.FieldInterpolatedString(ccpFieldSystemPrompt) + if err != nil { + return nil, err + } + } + var maxTokens *int + if conf.Contains(ccpFieldMaxTokens) { + mt, err := conf.FieldInt(ccpFieldMaxTokens) + if err != nil { + return nil, err + } + maxTokens = &mt + } + var temp *float64 + if conf.Contains(ccpFieldTemp) { + ft, err := conf.FieldFloat(ccpFieldTemp) + if err != nil { + return nil, err + } + temp = &ft + } + var topP *float64 + if conf.Contains(ccpFieldTopP) { + v, err := conf.FieldFloat(ccpFieldTopP) + if err != nil { + return nil, err + } + topP = &v + } + var frequencyPenalty *float64 + if conf.Contains(ccpFieldFrequencyPenalty) { + v, err := conf.FieldFloat(ccpFieldFrequencyPenalty) + if err != nil { + return nil, err + } + frequencyPenalty = &v + } + var presencePenalty *float64 + if conf.Contains(ccpFieldPresencePenalty) { + v, err := conf.FieldFloat(ccpFieldPresencePenalty) + if err != nil { + return nil, err + } + presencePenalty = &v + } + var seed *int + if conf.Contains(ccpFieldSeed) { + intSeed, err := conf.FieldInt(ccpFieldSeed) + if err != nil { + return nil, err + } + seed = &intSeed + } + var stop []string + if conf.Contains(ccpFieldStop) { + stop, err = conf.FieldStringList(ccpFieldStop) + if err != nil { + return nil, err + } + } + v, err := conf.FieldString(ccpFieldResponseFormat) + if err != nil { + return nil, err + } + var responseFormat cohere.ResponseFormat + var schemaProvider jsonSchemaProvider + switch v { + case "json": + fallthrough + case "json_object": + responseFormat.Type = "json_object" + responseFormat.JsonObject = &cohere.JsonResponseFormat{} + case "json_schema": + responseFormat.Type = "json_object" + responseFormat.JsonObject = &cohere.JsonResponseFormat{} + if conf.Contains(ccpFieldJSONSchema) { + schemaProvider, err = newFixedSchemaProvider(conf) + if err != nil { + return nil, err + } + } else if conf.Contains(ccpFieldSchemaRegistry) { + schemaProvider, err = newDynamicSchemaProvider(conf.Namespace(ccpFieldSchemaRegistry), mgr) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("using %s %q, but did not specify %s or %s", ccpFieldResponseFormat, v, ccpFieldJSONSchema, ccpFieldSchemaRegistry) + } + case "text": + responseFormat.Type = "text" + responseFormat.Text = &cohere.TextResponseFormat{} + default: + return nil, fmt.Errorf("unknown %s: %q", ccpFieldResponseFormat, v) + } + return &chatProcessor{b, up, sp, maxTokens, temp, topP, frequencyPenalty, presencePenalty, seed, stop, responseFormat, schemaProvider}, nil +} + +func newFixedSchemaProvider(conf *service.ParsedConfig) (jsonSchemaProvider, error) { + schema, err := conf.FieldString(ccpFieldJSONSchema) + if err != nil { + return nil, err + } + return newFixedSchema(schema) +} + +func newDynamicSchemaProvider(conf *service.ParsedConfig, mgr *service.Resources) (jsonSchemaProvider, error) { + url, err := conf.FieldString(ccpFieldSchemaRegistryURL) + if err != nil { + return nil, err + } + reqSigner, err := conf.HTTPRequestAuthSignerFromParsed() + if err != nil { + return nil, err + } + tlsConfig, err := conf.FieldTLS(ccpFieldSchemaRegistryTLS) + if err != nil { + return nil, err + } + client, err := sr.NewClient(url, reqSigner, tlsConfig, mgr) + if err != nil { + return nil, fmt.Errorf("unable to create schema registry client: %w", err) + } + subject, err := conf.FieldString(ccpFieldSchemaRegistrySubject) + if err != nil { + return nil, err + } + var refreshInterval time.Duration = math.MaxInt64 + if conf.Contains(ccpFieldSchemaRegistryRefreshInterval) { + refreshInterval, err = conf.FieldDuration(ccpFieldSchemaRegistryRefreshInterval) + if err != nil { + return nil, err + } + } + return newDynamicSchema(client, subject, refreshInterval), nil +} + +type chatProcessor struct { + *baseProcessor + + userPrompt *service.InterpolatedString + systemPrompt *service.InterpolatedString + maxTokens *int + temperature *float64 + topP *float64 + frequencyPenalty *float64 + presencePenalty *float64 + seed *int + stop []string + responseFormat cohere.ResponseFormat + schemaProvider jsonSchemaProvider +} + +func (p *chatProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { + var body cohere.ChatRequest + body.Model = &p.model + body.MaxTokens = p.maxTokens + body.Temperature = p.temperature + body.P = p.topP + body.Seed = p.seed + body.FrequencyPenalty = p.frequencyPenalty + body.PresencePenalty = p.presencePenalty + body.ResponseFormat = &p.responseFormat + if p.schemaProvider != nil { + s, err := p.schemaProvider.GetJSONSchema(ctx) + if err != nil { + return nil, err + } + body.ResponseFormat.JsonObject.Schema = s + } + body.StopSequences = p.stop + if p.systemPrompt != nil { + s, err := p.systemPrompt.TryString(msg) + if err != nil { + return nil, fmt.Errorf("%s interpolation error: %w", ccpFieldSystemPrompt, err) + } + body.Preamble = &s + } + if p.userPrompt != nil { + s, err := p.userPrompt.TryString(msg) + if err != nil { + return nil, fmt.Errorf("%s interpolation error: %w", ccpFieldUserPrompt, err) + } + body.Message = s + } else { + b, err := msg.AsBytes() + if err != nil { + return nil, err + } + body.Message = string(b) + } + resp, err := p.client.Chat(ctx, &body) + if err != nil { + return nil, err + } + msg = msg.Copy() + msg.SetBytes([]byte(resp.Text)) + return service.MessageBatch{msg}, nil +} diff --git a/internal/impl/cohere/embeddings_processor.go b/internal/impl/cohere/embeddings_processor.go new file mode 100644 index 000000000..ae959154c --- /dev/null +++ b/internal/impl/cohere/embeddings_processor.go @@ -0,0 +1,162 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package cohere + +import ( + "context" + "errors" + "fmt" + + cohere "github.com/cohere-ai/cohere-go/v2" + "github.com/redpanda-data/benthos/v4/public/bloblang" + "github.com/redpanda-data/benthos/v4/public/service" +) + +const ( + oepFieldTextMapping = "text_mapping" + oepFieldInputType = "dimensions" +) + +func init() { + err := service.RegisterProcessor( + "cohere_embeddings", + embeddingProcessorConfig(), + makeEmbeddingsProcessor, + ) + if err != nil { + panic(err) + } +} + +func embeddingProcessorConfig() *service.ConfigSpec { + return service.NewConfigSpec(). + Categories("AI"). + Summary("Generates vector embeddings to represent input text, using the Cohere API."). + Description(` +This processor sends text strings to the Cohere API, which generates vector embeddings. By default, the processor submits the entire payload of each message as a string, unless you use the `+"`"+oepFieldTextMapping+"`"+` configuration field to customize it. + +To learn more about vector embeddings, see the https://docs.cohere.com/docs/embeddings[Cohere API documentation^].`). + Version("4.37.0"). + Fields( + baseConfigFieldsWithModels( + "embed-english-v3.0", + "embed-english-light-v3.0", + "embed-multilingual-v3.0", + "embed-multilingual-light-v3.0", + )..., + ). + Fields( + service.NewBloblangField(oepFieldTextMapping). + Description("The text you want to generate a vector embedding for. By default, the processor submits the entire payload as a string."). + Optional(), + service.NewStringAnnotatedEnumField(oepFieldInputType, map[string]string{ + "search_document": "Used for embeddings stored in a vector database for search use-cases.", + "search_query": "Used for embeddings of search queries run against a vector DB to find relevant documents.", + "classification": "Used for embeddings passed through a text classifier.", + "clustering": "Used for the embeddings run through a clustering algorithm.", + }). + Description("Specifies the type of input passed to the model."). + Default("search_document"), + ). + Example( + "Store embedding vectors in Qdrant", + "Compute embeddings for some generated data and store it within xrefs:component:outputs/qdrant.adoc[Qdrant]", + `input: + generate: + interval: 1s + mapping: | + root = {"text": fake("paragraph")} +pipeline: + processors: + - cohere_embeddings: + model: embed-english-v3 + auth_token: "${COHERE_AUTH_TOKEN}" + text_mapping: "root = this.text" +output: + qdrant: + grpc_host: localhost:6334 + collection_name: "example_collection" + id: "root = uuid_v4()" + vector_mapping: "root = this"`) +} + +func makeEmbeddingsProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { + b, err := newBaseProcessor(conf) + if err != nil { + return nil, err + } + var t *bloblang.Executor + if conf.Contains(oepFieldTextMapping) { + t, err = conf.FieldBloblang(oepFieldTextMapping) + if err != nil { + return nil, err + } + } + var et cohere.EmbedInputType + if conf.Contains(oepFieldInputType) { + v, err := conf.FieldString(oepFieldInputType) + if err != nil { + return nil, err + } + t, err := cohere.NewEmbedInputTypeFromString(v) + if err != nil { + return nil, err + } + et = t + } + return &embeddingsProcessor{b, t, et}, nil +} + +type embeddingsProcessor struct { + *baseProcessor + + text *bloblang.Executor + inputType cohere.EmbedInputType +} + +func (p *embeddingsProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { + var body cohere.EmbedRequest + body.Model = &p.model + body.InputType = &p.inputType + if p.text != nil { + s, err := msg.BloblangQuery(p.text) + if err != nil { + return nil, fmt.Errorf("%s execution error: %w", oepFieldTextMapping, err) + } + r, err := s.AsBytes() + if err != nil { + return nil, fmt.Errorf("%s extraction error: %w", oepFieldTextMapping, err) + } + body.Texts = append(body.Texts, string(r)) + } else { + b, err := msg.AsBytes() + if err != nil { + return nil, err + } + body.Texts = append(body.Texts, string(b)) + } + resp, err := p.client.Embed(ctx, &body) + if err != nil { + return nil, err + } + if resp.EmbeddingsFloats == nil { + return nil, errors.New("expected embeddings output") + } + if len(resp.EmbeddingsFloats.Embeddings) != 1 { + return nil, fmt.Errorf("expected a single embeddings response, got: %d", len(resp.EmbeddingsFloats.Embeddings)) + } + embd := resp.EmbeddingsFloats.Embeddings[0] + data := make([]any, len(embd)) + for i, f := range embd { + data[i] = f + } + msg = msg.Copy() + msg.SetStructuredMut(data) + return service.MessageBatch{msg}, nil +} diff --git a/internal/impl/cohere/json_schema_provider.go b/internal/impl/cohere/json_schema_provider.go new file mode 100644 index 000000000..b8a9e6b02 --- /dev/null +++ b/internal/impl/cohere/json_schema_provider.go @@ -0,0 +1,84 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package cohere + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" +) + +type jsonSchema = map[string]any + +type jsonSchemaProvider interface { + GetJSONSchema(context.Context) (jsonSchema, error) +} + +type fixedSchemaProvider struct { + jsonSchema +} + +func (s *fixedSchemaProvider) GetJSONSchema(context.Context) (jsonSchema, error) { + return s.jsonSchema, nil +} + +func newFixedSchema(raw string) (jsonSchemaProvider, error) { + p := &fixedSchemaProvider{} + if err := json.Unmarshal([]byte(raw), &p.jsonSchema); err != nil { + return nil, fmt.Errorf("invalid JSON schema: %w", err) + } + return p, nil +} + +type dynamicSchemaProvider struct { + cached jsonSchema + nextRefreshTime time.Time + refreshInterval time.Duration + mu sync.Mutex + + client *sr.Client + subject string +} + +func (p *dynamicSchemaProvider) GetJSONSchema(ctx context.Context) (jsonSchema, error) { + if time.Now().Before(p.nextRefreshTime) { + return p.cached, nil + } + p.mu.Lock() + defer p.mu.Unlock() + // Double check since we now have the lock that we didn't race with other requests + if time.Now().Before(p.nextRefreshTime) { + return p.cached, nil + } + info, err := p.client.GetSchemaBySubjectAndVersion(ctx, p.subject, nil) + if err != nil { + return nil, fmt.Errorf("unable to load latest schema for subject %q: %w", p.subject, err) + } + var schema jsonSchema + if err := json.Unmarshal([]byte(info.Schema), &schema); err != nil { + return nil, fmt.Errorf("unable to parse json schema from schema with ID=%d", info.ID) + } + p.cached = schema + p.nextRefreshTime = time.Now().Add(p.refreshInterval) + return p.cached, nil +} + +func newDynamicSchema(client *sr.Client, subject string, refreshInterval time.Duration) jsonSchemaProvider { + return &dynamicSchemaProvider{ + cached: nil, + nextRefreshTime: time.UnixMilli(0), + refreshInterval: refreshInterval, + client: client, + subject: subject, + } +} diff --git a/internal/plugins/info.csv b/internal/plugins/info.csv index 9fe3a7145..6560443bf 100644 --- a/internal/plugins/info.csv +++ b/internal/plugins/info.csv @@ -47,6 +47,8 @@ cassandra ,output ,cassandra ,0.0.0 ,community catch ,processor ,catch ,0.0.0 ,certified ,n ,y ,y chunker ,scanner ,chunker ,0.0.0 ,certified ,n ,y ,y cockroachdb_changefeed ,input ,cockroachdb_changefeed ,0.0.0 ,community ,n ,n ,n +cohere_chat ,processor ,cohere_chat ,4.37.0 ,enterprise ,n ,y ,y +cohere_embeddings ,processor ,cohere_embeddings ,4.37.0 ,enterprise ,n ,y ,y command ,processor ,command ,4.21.0 ,certified ,n ,n ,n compress ,processor ,compress ,0.0.0 ,certified ,n ,y ,y couchbase ,cache ,Couchbase ,4.12.0 ,community ,n ,n ,n diff --git a/public/components/all/package.go b/public/components/all/package.go index 1f35a9b8f..d950cc3e5 100644 --- a/public/components/all/package.go +++ b/public/components/all/package.go @@ -18,6 +18,7 @@ import ( // Import all enterprise components. _ "github.com/redpanda-data/connect/v4/public/components/aws/enterprise" + _ "github.com/redpanda-data/connect/v4/public/components/cohere" _ "github.com/redpanda-data/connect/v4/public/components/gcp/enterprise" _ "github.com/redpanda-data/connect/v4/public/components/kafka/enterprise" _ "github.com/redpanda-data/connect/v4/public/components/ollama" diff --git a/public/components/cloud/package.go b/public/components/cloud/package.go index 57b629cfd..485a43461 100644 --- a/public/components/cloud/package.go +++ b/public/components/cloud/package.go @@ -18,6 +18,7 @@ import ( _ "github.com/redpanda-data/connect/v4/public/components/aws/enterprise" _ "github.com/redpanda-data/connect/v4/public/components/azure" _ "github.com/redpanda-data/connect/v4/public/components/changelog" + _ "github.com/redpanda-data/connect/v4/public/components/cohere" _ "github.com/redpanda-data/connect/v4/public/components/confluent" _ "github.com/redpanda-data/connect/v4/public/components/crypto" _ "github.com/redpanda-data/connect/v4/public/components/dgraph" diff --git a/public/components/cohere/package.go b/public/components/cohere/package.go new file mode 100644 index 000000000..5ec330136 --- /dev/null +++ b/public/components/cohere/package.go @@ -0,0 +1,14 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + +package cohere + +import ( + // Bring in the internal plugin definitions. + _ "github.com/redpanda-data/connect/v4/internal/impl/cohere" +) From ede4b838604d7398935f5349be16ddc43fcbb0d9 Mon Sep 17 00:00:00 2001 From: Ashley Jeffs Date: Fri, 13 Sep 2024 09:49:43 +0100 Subject: [PATCH 2/2] Update CHANGELOG --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b51b67ce..862c7ef39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,14 @@ Changelog All notable changes to this project will be documented in this file. +## 4.37.0 - TBD + +### Added + +- New experimental `gcp_vertex_ai_embeddings` processor. (@rockwotj) +- New experimental `aws_bedrock_embeddings` processor. (@rockwotj) +- New experimental `cohere_chat` and `cohere_embeddings` processors. (@rockwotj) + ## 4.36.0 - 2024-09-11 ### Added