From 33fa6565d98699491cac8cef8ff7d788d9441193 Mon Sep 17 00:00:00 2001
From: Yoshida Hiroshi
Date: Fri, 10 Jan 2025 12:48:40 +0900
Subject: [PATCH] =?UTF-8?q?AWS=20SDK=20for=20Go=20v2=20=E5=AF=BE=E5=BF=9C?=
 =?UTF-8?q?=E3=81=AE=E3=83=8F=E3=83=B3=E3=83=89=E3=83=A9=E3=82=92=E8=BF=BD?=
 =?UTF-8?q?=E5=8A=A0=E3=81=99=E3=82=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 amazon_transcribe_v2.go         | 177 +++++++++++++++
 amazon_transcribe_v2_handler.go | 374 ++++++++++++++++++++++++++++++++
 go.mod                          |  15 ++
 go.sum                          |  30 +++
 languages.go                    |   9 +
 5 files changed, 605 insertions(+)
 create mode 100644 amazon_transcribe_v2.go
 create mode 100644 amazon_transcribe_v2_handler.go

diff --git a/amazon_transcribe_v2.go b/amazon_transcribe_v2.go
new file mode 100644
index 0000000..865c8ca
--- /dev/null
+++ b/amazon_transcribe_v2.go
@@ -0,0 +1,177 @@
+package suzu
+
+import (
+	"context"
+	"io"
+	"net/http"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/service/transcribestreaming"
+	"github.com/aws/aws-sdk-go-v2/service/transcribestreaming/types"
+
+	zlog "github.com/rs/zerolog/log"
+)
+
+// AmazonTranscribeV2 holds the parameters of one Amazon Transcribe streaming
+// session opened through AWS SDK for Go v2.
+type AmazonTranscribeV2 struct {
+	LanguageCode                      string
+	MediaEncoding                     types.MediaEncoding
+	MediaSampleRateHertz              int64
+	EnablePartialResultsStabilization bool
+	NumberOfChannels                  int64
+	EnableChannelIdentification       bool
+	PartialResultsStability           string
+	Region                            string
+	Debug                             bool
+	Config                            Config
+}
+
+// NewAmazonTranscribeV2 builds the session parameters from the service
+// configuration and the properties of the incoming audio. The media encoding
+// is always Ogg Opus.
+func NewAmazonTranscribeV2(c Config, languageCode string, sampleRateHertz, audioChannelCount int64) *AmazonTranscribeV2 {
+	return &AmazonTranscribeV2{
+		Region:                            c.AwsRegion,
+		LanguageCode:                      languageCode,
+		MediaEncoding:                     types.MediaEncodingOggOpus,
+		MediaSampleRateHertz:              sampleRateHertz,
+		EnablePartialResultsStabilization: c.AwsEnablePartialResultsStabilization,
+		PartialResultsStability:           c.AwsPartialResultsStability,
+		NumberOfChannels:                  audioChannelCount,
+		EnableChannelIdentification:       c.AwsEnableChannelIdentification,
+		Config:                            c,
+	}
+}
+
+// NewStartStreamTranscriptionInputV2 converts at into the request sent to
+// StartStreamTranscription.
+//
+// NumberOfChannels is only set when channel identification is enabled, and
+// PartialResultsStability only when partial-results stabilization is enabled.
+func NewStartStreamTranscriptionInputV2(at *AmazonTranscribeV2) transcribestreaming.StartStreamTranscriptionInput {
+	input := transcribestreaming.StartStreamTranscriptionInput{
+		LanguageCode:                      types.LanguageCode(at.LanguageCode),
+		MediaEncoding:                     at.MediaEncoding,
+		MediaSampleRateHertz:              aws.Int32(int32(at.MediaSampleRateHertz)),
+		EnablePartialResultsStabilization: at.EnablePartialResultsStabilization,
+		EnableChannelIdentification:       at.EnableChannelIdentification,
+	}
+
+	if at.EnableChannelIdentification {
+		input.NumberOfChannels = aws.Int32(int32(at.NumberOfChannels))
+	}
+	if at.EnablePartialResultsStabilization {
+		input.PartialResultsStability = types.PartialResultsStability(at.PartialResultsStability)
+	}
+
+	return input
+}
+
+// NewAmazonTranscribeClientV2 creates a Transcribe streaming client from c.
+//
+// The configured region and a dedicated HTTP client are applied in every
+// case. When c.AwsProfile is set, that shared-config profile and
+// c.AwsCredentialFile are used as well.
+func NewAmazonTranscribeClientV2(c Config) (*transcribestreaming.Client, error) {
+	// TODO: revisit the transport settings later.
+	tr := &http.Transport{}
+	httpClient := &http.Client{Transport: tr}
+
+	ctx := context.TODO()
+
+	// TODO: make the SDK log level configurable.
+	opts := []func(*config.LoadOptions) error{
+		config.WithRegion(c.AwsRegion),
+		config.WithHTTPClient(httpClient),
+	}
+	if c.AwsProfile != "" {
+		opts = append(opts,
+			config.WithSharedConfigProfile(c.AwsProfile),
+			config.WithSharedCredentialsFiles([]string{c.AwsCredentialFile}),
+		)
+	}
+
+	cfg, err := config.LoadDefaultConfig(ctx, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	return transcribestreaming.NewFromConfig(cfg), nil
+}
+
+// Start opens a streaming transcription session and forwards the audio read
+// from r to it on a background goroutine.
+//
+// The goroutine stops when r returns an error (io.EOF included) or when
+// sending to the stream fails; it then closes the stream and r. Results are
+// consumed from the returned event stream.
+func (at *AmazonTranscribeV2) Start(ctx context.Context, r io.ReadCloser) (*transcribestreaming.StartStreamTranscriptionEventStream, error) {
+	client, err := NewAmazonTranscribeClientV2(at.Config)
+	if err != nil {
+		return nil, err
+	}
+	input := NewStartStreamTranscriptionInputV2(at)
+
+	resp, err := client.StartStreamTranscription(ctx, &input)
+	if err != nil {
+		// TODO: awserr.RequestFailure does not exist in v2; replace this with
+		// the equivalent v2 error handling.
+		// if reqErr, ok := err.(awserr.RequestFailure); ok {
+		// 	code := reqErr.StatusCode()
+		// 	message := reqErr.Message()
+
+		// 	var retry bool
+		// 	if code == http.StatusTooManyRequests {
+		// 		retry = true
+		// 	}
+
+		// 	return nil, &SuzuError{
+		// 		Code:    code,
+		// 		Message: message,
+		// 		Retry:   retry,
+		// 	}
+		// }
+		return nil, err
+	}
+
+	stream := resp.GetStream()
+
+	go func() {
+		defer r.Close()
+		defer func() {
+			if err := stream.Close(); err != nil {
+				zlog.Error().Err(err).Send()
+			}
+		}()
+
+		frame := make([]byte, FrameSize)
+		for {
+			n, readErr := r.Read(frame)
+
+			// A Reader may return data together with an error, so forward
+			// the bytes before looking at readErr.
+			if n > 0 {
+				err := stream.Send(ctx, &types.AudioStreamMemberAudioEvent{
+					Value: types.AudioEvent{
+						AudioChunk: frame[:n],
+					},
+				})
+				if err != nil {
+					zlog.Error().Err(err).Send()
+					break
+				}
+			}
+
+			if readErr != nil {
+				if readErr != io.EOF {
+					zlog.Error().Err(readErr).Send()
+				}
+				break
+			}
+		}
+	}()
+
+	return stream, nil
+}
diff --git a/amazon_transcribe_v2_handler.go b/amazon_transcribe_v2_handler.go
new file mode 100644
index 0000000..e59d155
--- /dev/null
+++ b/amazon_transcribe_v2_handler.go
@@ -0,0 +1,374 @@
+package suzu
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"io"
+	"strings"
+	"sync"
+
+	"github.com/aws/aws-sdk-go-v2/service/transcribestreaming/types"
+	zlog "github.com/rs/zerolog/log"
+)
+
+func init() {
+	NewServiceHandlerFuncs.register("awsv2", NewAmazonTranscribeV2Handler)
+}
+
+// AmazonTranscribeV2Handler is the service handler registered as "awsv2". It
+// transcribes audio with Amazon Transcribe through AWS SDK for Go v2.
+type AmazonTranscribeV2Handler struct {
+	Config Config
+
+	ChannelID    string
+	ConnectionID string
+	SampleRate   uint32
+	ChannelCount uint16
+	LanguageCode string
+	RetryCount   int
+	mu           sync.Mutex // guards RetryCount
+
+	OnResultFunc func(context.Context, io.WriteCloser, string, string, string, any) error
+}
+
+// NewAmazonTranscribeV2Handler creates the handler for one connection.
+//
+// onResultFunc may be nil, in which case Handle writes its built-in JSON
+// results. A non-nil value must have the type of the OnResultFunc field; any
+// other type is a programming error and panics.
+func NewAmazonTranscribeV2Handler(config Config, channelID, connectionID string, sampleRate uint32, channelCount uint16, languageCode string, onResultFunc any) serviceHandlerInterface {
+	h := &AmazonTranscribeV2Handler{
+		Config:       config,
+		ChannelID:    channelID,
+		ConnectionID: connectionID,
+		SampleRate:   sampleRate,
+		ChannelCount: channelCount,
+		LanguageCode: languageCode,
+		RetryCount:   0,
+	}
+
+	// A type assertion on a nil interface panics, while Handle explicitly
+	// supports a nil OnResultFunc, so only assert when a callback was given.
+	if onResultFunc != nil {
+		h.OnResultFunc = onResultFunc.(func(context.Context, io.WriteCloser, string, string, string, any) error)
+	}
+
+	return h
+}
+
+// AwsResultV2 is the JSON object written for each transcription result. The
+// pointer fields are optional and omitted from the output unless set.
+type AwsResultV2 struct {
+	ChannelID *string `json:"channel_id,omitempty"`
+	IsPartial *bool   `json:"is_partial,omitempty"`
+	ResultID  *string `json:"result_id,omitempty"`
+	TranscriptionResult
+}
+
+// NewAwsResultV2 returns a result whose Type is "aws".
+func NewAwsResultV2() AwsResultV2 {
+	return AwsResultV2{
+		TranscriptionResult: TranscriptionResult{
+			Type: "aws",
+		},
+	}
+}
+
+// WithChannelID sets the optional channel ID and returns ar for chaining.
+func (ar *AwsResultV2) WithChannelID(channelID string) *AwsResultV2 {
+	ar.ChannelID = &channelID
+	return ar
+}
+
+// WithIsPartial sets the optional partial flag and returns ar for chaining.
+func (ar *AwsResultV2) WithIsPartial(isPartial bool) *AwsResultV2 {
+	ar.IsPartial = &isPartial
+	return ar
+}
+
+// WithResultID sets the optional result ID and returns ar for chaining.
+func (ar *AwsResultV2) WithResultID(resultID string) *AwsResultV2 {
+	ar.ResultID = &resultID
+	return ar
+}
+
+// SetMessage sets the transcribed text and returns ar for chaining.
+func (ar *AwsResultV2) SetMessage(message string) *AwsResultV2 {
+	ar.Message = message
+	return ar
+}
+
+// UpdateRetryCount increments the retry counter and returns the new value.
+func (h *AmazonTranscribeV2Handler) UpdateRetryCount() int {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.RetryCount++
+	return h.RetryCount
+}
+
+// GetRetryCount returns the current retry counter.
+func (h *AmazonTranscribeV2Handler) GetRetryCount() int {
+	// Take the lock like the writers do; Handle reads the counter from its
+	// own goroutine.
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	return h.RetryCount
+}
+
+// ResetRetryCount sets the retry counter back to zero and returns it.
+func (h *AmazonTranscribeV2Handler) ResetRetryCount() int {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.RetryCount = 0
+	return h.RetryCount
+}
+
+// Handle starts a transcription session for the Opus packets arriving on
+// opusCh and returns a pipe from which the results can be read as a stream
+// of JSON objects.
+//
+// When OnResultFunc is set, every transcript event is handed to it instead
+// of being encoded here. The pipe is closed with an error when the callback,
+// the encoder or the event stream fails; stream errors after which a
+// reconnect is worthwhile are joined with ErrServerDisconnected.
+func (h *AmazonTranscribeV2Handler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
+	at := NewAmazonTranscribeV2(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount))
+
+	packetReader, err := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config, header)
+	if err != nil {
+		return nil, err
+	}
+
+	stream, err := at.Start(ctx, packetReader)
+	if err != nil {
+		return nil, err
+	}
+
+	// Reset the retry counter as soon as the request has succeeded.
+	h.ResetRetryCount()
+
+	r, w := io.Pipe()
+
+	go func() {
+		encoder := json.NewEncoder(w)
+
+	L:
+		for {
+			select {
+			case <-ctx.Done():
+				break L
+			case event := <-stream.Events():
+				e, ok := event.(*types.TranscriptResultStreamMemberTranscriptEvent)
+				if !ok {
+					// Any other event ends the loop; that includes the nil
+					// value a closed channel yields.
+					break L
+				}
+				if e.Value.Transcript == nil {
+					continue
+				}
+				results := e.Value.Transcript.Results
+
+				if h.OnResultFunc != nil {
+					if err := h.OnResultFunc(ctx, w, h.ChannelID, h.ConnectionID, h.LanguageCode, results); err != nil {
+						if err := encoder.Encode(NewSuzuErrorResponse(err)); err != nil {
+							zlog.Error().
+								Err(err).
+								Str("channel_id", h.ChannelID).
+								Str("connection_id", h.ConnectionID).
+								Send()
+						}
+						w.CloseWithError(err)
+						return
+					}
+					continue
+				}
+
+				if err := encodeResultsV2(encoder, h.Config, results); err != nil {
+					w.CloseWithError(err)
+					return
+				}
+			}
+		}
+
+		if err := stream.Err(); err != nil {
+			zlog.Error().
+				Err(err).
+				Str("channel_id", h.ChannelID).
+				Str("connection_id", h.ConnectionID).
+				Int("retry_count", h.GetRetryCount()).
+				Send()
+
+			w.CloseWithError(markReconnectableV2(err))
+			return
+		}
+		w.Close()
+	}()
+
+	return r, nil
+}
+
+// encodeResultsV2 writes one JSON object per accepted alternative of results
+// to encoder and stops at the first encoding error.
+func encodeResultsV2(encoder *json.Encoder, config Config, results []types.Result) error {
+	for _, res := range results {
+		// With FinalResultOnly, results that are still partial are dropped.
+		if config.FinalResultOnly && res.IsPartial {
+			continue
+		}
+
+		result := NewAwsResultV2()
+		if config.AwsResultIsPartial {
+			result.WithIsPartial(res.IsPartial)
+		}
+		// ChannelId and ResultId are pointers in the SDK type; skip them when
+		// absent instead of dereferencing nil.
+		if config.AwsResultChannelID && res.ChannelId != nil {
+			result.WithChannelID(*res.ChannelId)
+		}
+		if config.AwsResultID && res.ResultId != nil {
+			result.WithResultID(*res.ResultId)
+		}
+
+		for _, alt := range res.Alternatives {
+			message, ok := buildMessageV2(config, alt, res.IsPartial)
+			if !ok {
+				continue
+			}
+
+			result.SetMessage(message)
+			if err := encoder.Encode(result); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// markReconnectableV2 joins ErrServerDisconnected onto stream errors after
+// which a reconnect should be attempted and returns every other error as is.
+func markReconnectableV2(err error) error {
+	var (
+		limitExceeded   *types.LimitExceededException
+		internalFailure *types.InternalFailureException
+	)
+
+	// errors.As also matches when the exception has been wrapped.
+	if errors.As(err, &limitExceeded) || errors.As(err, &internalFailure) {
+		return errors.Join(err, ErrServerDisconnected)
+	}
+
+	// The server closed the connection.
+	if strings.Contains(err.Error(), "http2: server sent GOAWAY and closed the connection;") {
+		return errors.Join(err, ErrServerDisconnected)
+	}
+
+	return err
+}
+
+// contentFilterByTranscribedTimeV2 reports whether item passes the minimum
+// transcribed-time filter.
+func contentFilterByTranscribedTimeV2(config Config, item types.Item) bool {
+	minimumTranscribedTime := config.MinimumTranscribedTime
+
+	// The filter is disabled unless a positive minimum is configured.
+	if minimumTranscribedTime <= 0 {
+		return true
+	}
+
+	// Punctuation is never filtered.
+	if item.Type == types.ItemTypePunctuation {
+		return true
+	}
+
+	// TODO: StartTime and EndTime are no longer pointers in v2; decide how
+	// to handle items without them.
+	// // Do not filter when StartTime or EndTime is nil.
+	// if (item.StartTime == nil) || (item.EndTime == nil) {
+	// 	return true
+	// }
+
+	// Reject items that were spoken for less than the minimum.
+	return (item.EndTime - item.StartTime) >= minimumTranscribedTime
+}
+
+// contentFilterByConfidenceScoreV2 reports whether item passes the minimum
+// confidence-score filter. Partial results are never filtered.
+func contentFilterByConfidenceScoreV2(config Config, item types.Item, isPartial bool) bool {
+	minimumConfidenceScore := config.MinimumConfidenceScore
+
+	// The filter is disabled unless a positive minimum is configured.
+	if minimumConfidenceScore <= 0 {
+		return true
+	}
+
+	// Partial results are never filtered.
+	if isPartial {
+		return true
+	}
+
+	// Punctuation is never filtered.
+	if item.Type == types.ItemTypePunctuation {
+		return true
+	}
+
+	// Items without a confidence score are never filtered.
+	if item.Confidence == nil {
+		return true
+	}
+
+	// Reject items whose confidence is below the minimum.
+	return *item.Confidence >= minimumConfidenceScore
+}
+
+// buildMessageV2 assembles the text to report for alt.
+//
+// With both filters disabled the transcript is returned unchanged. Otherwise
+// the message is rebuilt from the items that pass the filters. The second
+// return value is false when there is nothing worth reporting.
+func buildMessageV2(config Config, alt types.Alternative, isPartial bool) (string, bool) {
+	minimumTranscribedTime := config.MinimumTranscribedTime
+	minimumConfidenceScore := config.MinimumConfidenceScore
+
+	// With both filters disabled, return the transcript as is.
+	if (minimumTranscribedTime <= 0) && (minimumConfidenceScore <= 0) {
+		if alt.Transcript == nil {
+			return "", false
+		}
+		return *alt.Transcript, true
+	}
+
+	var message strings.Builder
+	includePronunciation := false
+
+	for _, item := range alt.Items {
+		if !contentFilterByTranscribedTimeV2(config, item) {
+			continue
+		}
+
+		if !contentFilterByConfidenceScoreV2(config, item, isPartial) {
+			continue
+		}
+
+		if item.Content == nil {
+			continue
+		}
+
+		// Only spoken words count as content; punctuation alone does not.
+		if item.Type == types.ItemTypePronunciation {
+			includePronunciation = true
+		}
+
+		message.WriteString(*item.Content)
+	}
+
+	// Skip the alternative when only punctuation survived the filters or the
+	// message is empty.
+	if !includePronunciation || message.Len() == 0 {
+		return "", false
+	}
+
+	return message.String(), true
+}
diff --git a/go.mod b/go.mod
index 5a32476..44ef5c0 100644
--- a/go.mod
+++ b/go.mod
@@ -27,6 +27,21 @@ require (
 	cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
 	cloud.google.com/go/compute/metadata v0.6.0 // indirect
 	cloud.google.com/go/longrunning v0.6.3 // indirect
+	github.com/aws/aws-sdk-go-v2 v1.32.7
+	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect
+	github.com/aws/aws-sdk-go-v2/config v1.28.8
+	github.com/aws/aws-sdk-go-v2/credentials v1.17.49 // indirect
+	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.7 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sso v1.24.8 // indirect
+	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.7 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sts v1.33.4 // indirect
+	github.com/aws/aws-sdk-go-v2/service/transcribestreaming v1.22.6
+	github.com/aws/smithy-go v1.22.1 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
diff --git a/go.sum b/go.sum
index 3f39da6..f700236 100644
--- a/go.sum
+++ b/go.sum
@@ -28,6 +28,36 @@ github.com/aws/aws-sdk-go v1.51.30 h1:RVFkjn9P0JMwnuZCVH0TlV5k9zepHzlbc4943eZMhG
 github.com/aws/aws-sdk-go v1.51.30/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
 github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
 github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
+github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw=
+github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc=
+github.com/aws/aws-sdk-go-v2/config v1.28.8 h1:4nUeC9TsZoHm9GHlQ5tnoIklNZgISXXVGPKP5/CS0fk=
+github.com/aws/aws-sdk-go-v2/config v1.28.8/go.mod h1:2C+fhFxnx1ymomFjj5NBUc/vbjyIUR7mZ/iNRhhb7BU=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.49 h1:+7u6eC8K6LLGQwWMYKHSsHAPQl+CGACQmnzd/EPMW0k=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.49/go.mod h1:0SgZcTAEIlKoYw9g+kuYUwbtUUVjfxnR03YkCOhMbQ0=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26/go.mod h1:FR8f4turZtNy6baO0KJ5FJUmXH/cSkI9fOngs0yl6mA=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 h1:zXFLuEuMMUOvEARXFUVJdfqZ4bvvSgdGRq/ATcrQxzM=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26/go.mod h1:3o2Wpy0bogG1kyOPrgkXA8pgIfEEv0+m19O9D5+W8y8=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.7 h1:8eUsivBQzZHqe/3FE+cqwfH+0p5Jo8PFM/QYQSmeZ+M=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.7/go.mod h1:kLPQvGUmxn/fqiCrDeohwG33bq2pQpGeY62yRO6Nrh0=
+github.com/aws/aws-sdk-go-v2/service/sso v1.24.8 h1:CvuUmnXI7ebaUAhbJcDy9YQx8wHR69eZ9I7q5hszt/g=
+github.com/aws/aws-sdk-go-v2/service/sso v1.24.8/go.mod h1:XDeGv1opzwm8ubxddF0cgqkZWsyOtw4lr6dxwmb6YQg=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.7 h1:F2rBfNAL5UyswqoeWv9zs74N/NanhK16ydHW1pahX6E=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.7/go.mod h1:JfyQ0g2JG8+Krq0EuZNnRwX0mU0HrwY/tG6JNfcqh4k=
+github.com/aws/aws-sdk-go-v2/service/sts v1.33.4 h1:EzofOvWNMtG9ELt9mPOJjLYh1hz6kN4f5hNCyTtS7Hg=
+github.com/aws/aws-sdk-go-v2/service/sts v1.33.4/go.mod h1:5Gn+d+VaaRgsjewpMvGazt0WfcFO+Md4wLOuBfGR9Bc=
+github.com/aws/aws-sdk-go-v2/service/transcribestreaming v1.22.6 h1:dFMFmYerXXvx/1tgBZ1UtSm4TjMccaOH/cDfcoMXrqo=
+github.com/aws/aws-sdk-go-v2/service/transcribestreaming v1.22.6/go.mod h1:AO9mLt2KiIYMG7jF3OS01h7SVzYIr+5mY66LDE//3no=
+github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro=
+github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
diff --git a/languages.go b/languages.go
index 697d491..59d6eb2 100644
--- a/languages.go
+++ b/languages.go
@@ -3,6 +3,7 @@ package suzu
 import (
 	"fmt"
 
+	"github.com/aws/aws-sdk-go-v2/service/transcribestreaming/types"
 	"github.com/aws/aws-sdk-go/service/transcribestreamingservice"
 )
 
@@ -29,6 +30,14 @@ func GetLanguageCode(serviceType, lang string, f func(string) (string, error)) (
 			}
 		}
 		return "", fmt.Errorf("%w: %s", ErrUnsupportedLanguageCode, lang)
+	case "awsv2":
+		// The zero value suffices as a receiver for listing the known codes.
+		for _, languageCode := range types.LanguageCode("").Values() {
+			if languageCode == types.LanguageCode(lang) {
+				return lang, nil
+			}
+		}
+		return "", fmt.Errorf("%w: %s", ErrUnsupportedLanguageCode, lang)
 	case "gcp", "test", "dump":
 		return lang, nil
 	}