From 0559fc9f5e46a38846dafba6d7f6b21c2f3df5f5 Mon Sep 17 00:00:00 2001 From: Kevin Sylvestre Date: Fri, 9 Aug 2024 10:55:28 -0700 Subject: [PATCH] Properly implement serialize / deserialize --- Gemfile | 1 + Gemfile.lock | 38 +++++-- README.md | 17 +--- lib/omniai/chat.rb | 44 ++++----- lib/omniai/chat/choice.rb | 55 +++++++++++ lib/omniai/chat/content.rb | 12 ++- lib/omniai/chat/function.rb | 53 ++++++++++ lib/omniai/chat/message.rb | 15 ++- lib/omniai/chat/payload.rb | 82 ++++++++++++++++ lib/omniai/chat/prompt.rb | 5 + lib/omniai/chat/response.rb | 55 +++++++++++ lib/omniai/chat/response/choice.rb | 35 ------- lib/omniai/chat/response/chunk.rb | 15 --- lib/omniai/chat/response/completion.rb | 15 --- lib/omniai/chat/response/delta.rb | 11 --- lib/omniai/chat/response/delta_choice.rb | 25 ----- lib/omniai/chat/response/function.rb | 25 ----- lib/omniai/chat/response/message.rb | 11 --- lib/omniai/chat/response/message_choice.rb | 25 ----- lib/omniai/chat/response/part.rb | 38 ------- lib/omniai/chat/response/payload.rb | 72 -------------- lib/omniai/chat/response/resource.rb | 22 ----- lib/omniai/chat/response/stream.rb | 27 ----- lib/omniai/chat/response/tool_call.rb | 30 ------ lib/omniai/chat/response/usage.rb | 35 ------- lib/omniai/chat/stream.rb | 33 +++++++ lib/omniai/chat/tool_call.rb | 54 ++++++++++ lib/omniai/chat/tool_call_batch.rb | 15 +++ lib/omniai/chat/tool_message.rb | 59 +++++++++++ lib/omniai/chat/usage.rb | 60 ++++++++++++ lib/omniai/version.rb | 2 +- spec/factories/chat/choice.rb | 10 ++ spec/factories/chat/content.rb | 7 ++ spec/factories/chat/file.rb | 10 ++ spec/factories/chat/function.rb | 10 ++ spec/factories/chat/media.rb | 9 ++ spec/factories/chat/message.rb | 10 ++ spec/factories/chat/prompt.rb | 9 ++ spec/factories/chat/text.rb | 9 ++ spec/factories/chat/tool_call.rb | 9 ++ spec/factories/chat/url.rb | 10 ++ spec/factories/chat/usage.rb | 11 +++ spec/factories/embed/usage.rb | 10 ++ spec/omniai/chat/choice_spec.rb | 70 +++++++++++++ spec/omniai/chat/content_spec.rb | 2 +- spec/omniai/chat/file_spec.rb | 2 +- spec/omniai/chat/function_spec.rb | 72 ++++++++++++++ spec/omniai/chat/media_spec.rb | 2 +- spec/omniai/chat/message_spec.rb | 2 +- spec/omniai/chat/payload_spec.rb | 72 ++++++++++++++ spec/omniai/chat/prompt_spec.rb | 2 +- spec/omniai/chat/response/choice_spec.rb | 19 ---- spec/omniai/chat/response/chunk_spec.rb | 39 -------- spec/omniai/chat/response/completion_spec.rb | 48 --------- .../omniai/chat/response/delta_choice_spec.rb | 22 ----- spec/omniai/chat/response/delta_spec.rb | 19 ---- spec/omniai/chat/response/function_spec.rb | 16 --- .../chat/response/message_choice_spec.rb | 22 ----- spec/omniai/chat/response/message_spec.rb | 19 ---- spec/omniai/chat/response/part_spec.rb | 38 ------- spec/omniai/chat/response/payload_spec.rb | 62 ------------ spec/omniai/chat/response/resource_spec.rb | 17 ---- spec/omniai/chat/response/tool_call_spec.rb | 23 ----- spec/omniai/chat/response/usage_spec.rb | 38 ------- spec/omniai/chat/text_spec.rb | 2 +- spec/omniai/chat/tool_call_spec.rb | 98 +++++++++++++++++++ spec/omniai/chat/url_spec.rb | 2 +- spec/omniai/chat/usage_spec.rb | 80 +++++++++++++++ spec/omniai/chat_spec.rb | 7 +- spec/omniai/embed/usage_spec.rb | 2 +- spec/spec_helper.rb | 2 + spec/support/factory_bot.rb | 9 ++ 72 files changed, 1077 insertions(+), 831 deletions(-) create mode 100644 lib/omniai/chat/choice.rb create mode 100644 lib/omniai/chat/function.rb create mode 100644 lib/omniai/chat/payload.rb create mode 100644 
lib/omniai/chat/response.rb delete mode 100644 lib/omniai/chat/response/choice.rb delete mode 100644 lib/omniai/chat/response/chunk.rb delete mode 100644 lib/omniai/chat/response/completion.rb delete mode 100644 lib/omniai/chat/response/delta.rb delete mode 100644 lib/omniai/chat/response/delta_choice.rb delete mode 100644 lib/omniai/chat/response/function.rb delete mode 100644 lib/omniai/chat/response/message.rb delete mode 100644 lib/omniai/chat/response/message_choice.rb delete mode 100644 lib/omniai/chat/response/part.rb delete mode 100644 lib/omniai/chat/response/payload.rb delete mode 100644 lib/omniai/chat/response/resource.rb delete mode 100644 lib/omniai/chat/response/stream.rb delete mode 100644 lib/omniai/chat/response/tool_call.rb delete mode 100644 lib/omniai/chat/response/usage.rb create mode 100644 lib/omniai/chat/stream.rb create mode 100644 lib/omniai/chat/tool_call.rb create mode 100644 lib/omniai/chat/tool_call_batch.rb create mode 100644 lib/omniai/chat/tool_message.rb create mode 100644 lib/omniai/chat/usage.rb create mode 100644 spec/factories/chat/choice.rb create mode 100644 spec/factories/chat/content.rb create mode 100644 spec/factories/chat/file.rb create mode 100644 spec/factories/chat/function.rb create mode 100644 spec/factories/chat/media.rb create mode 100644 spec/factories/chat/message.rb create mode 100644 spec/factories/chat/prompt.rb create mode 100644 spec/factories/chat/text.rb create mode 100644 spec/factories/chat/tool_call.rb create mode 100644 spec/factories/chat/url.rb create mode 100644 spec/factories/chat/usage.rb create mode 100644 spec/factories/embed/usage.rb create mode 100644 spec/omniai/chat/choice_spec.rb create mode 100644 spec/omniai/chat/function_spec.rb create mode 100644 spec/omniai/chat/payload_spec.rb delete mode 100644 spec/omniai/chat/response/choice_spec.rb delete mode 100644 spec/omniai/chat/response/chunk_spec.rb delete mode 100644 spec/omniai/chat/response/completion_spec.rb delete mode 100644 spec/omniai/chat/response/delta_choice_spec.rb delete mode 100644 spec/omniai/chat/response/delta_spec.rb delete mode 100644 spec/omniai/chat/response/function_spec.rb delete mode 100644 spec/omniai/chat/response/message_choice_spec.rb delete mode 100644 spec/omniai/chat/response/message_spec.rb delete mode 100644 spec/omniai/chat/response/part_spec.rb delete mode 100644 spec/omniai/chat/response/payload_spec.rb delete mode 100644 spec/omniai/chat/response/resource_spec.rb delete mode 100644 spec/omniai/chat/response/tool_call_spec.rb delete mode 100644 spec/omniai/chat/response/usage_spec.rb create mode 100644 spec/omniai/chat/tool_call_spec.rb create mode 100644 spec/omniai/chat/usage_spec.rb create mode 100644 spec/support/factory_bot.rb diff --git a/Gemfile b/Gemfile index 7b4ebc4..6cdb126 100644 --- a/Gemfile +++ b/Gemfile @@ -4,6 +4,7 @@ source 'https://rubygems.org' gemspec +gem 'factory_bot' gem 'logger' gem 'rake' gem 'rspec' diff --git a/Gemfile.lock b/Gemfile.lock index 96c4381..253fa38 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . 
specs: - omniai (1.7.0) + omniai (1.8.0) event_stream_parser http zeitwerk @@ -9,18 +9,33 @@ PATH GEM remote: https://rubygems.org/ specs: + activesupport (7.1.3.4) + base64 + bigdecimal + concurrent-ruby (~> 1.0, >= 1.0.2) + connection_pool (>= 2.2.5) + drb + i18n (>= 1.6, < 2) + minitest (>= 5.1) + mutex_m + tzinfo (~> 2.0) addressable (2.8.7) public_suffix (>= 2.0.2, < 7.0) ast (2.4.2) base64 (0.2.0) bigdecimal (3.1.8) + concurrent-ruby (1.3.3) + connection_pool (2.4.1) crack (1.0.0) bigdecimal rexml diff-lcs (1.5.1) - docile (1.4.0) + docile (1.4.1) domain_name (0.6.20240107) + drb (2.2.1) event_stream_parser (1.0.0) + factory_bot (6.4.6) + activesupport (>= 5.0.0) ffi (1.17.0) ffi (1.17.0-aarch64-linux-gnu) ffi (1.17.0-aarch64-linux-musl) @@ -35,7 +50,7 @@ GEM ffi-compiler (1.3.2) ffi (>= 1.15.5) rake - hashdiff (1.1.0) + hashdiff (1.1.1) http (5.2.0) addressable (~> 2.8) base64 (~> 0.1) @@ -45,17 +60,21 @@ GEM http-cookie (1.0.6) domain_name (~> 0.5) http-form_data (2.3.0) + i18n (1.14.5) + concurrent-ruby (~> 1.0) json (2.7.2) language_server-protocol (3.17.0.3) llhttp-ffi (0.5.0) ffi-compiler (~> 1.0) rake (~> 13.0) logger (1.6.0) - parallel (1.25.1) - parser (3.3.4.0) + minitest (5.24.1) + mutex_m (0.2.0) + parallel (1.26.1) + parser (3.3.4.2) ast (~> 2.4.1) racc - public_suffix (6.0.0) + public_suffix (6.0.1) racc (1.8.1) rainbow (3.1.1) rake (13.2.1) @@ -88,11 +107,11 @@ GEM rubocop-ast (>= 1.31.1, < 2.0) ruby-progressbar (~> 1.7) unicode-display_width (>= 2.4.0, < 3.0) - rubocop-ast (1.31.3) + rubocop-ast (1.32.0) parser (>= 3.3.1.0) rubocop-rake (0.6.0) rubocop (~> 1.0) - rubocop-rspec (3.0.3) + rubocop-rspec (3.0.4) rubocop (~> 1.61) ruby-progressbar (1.13.0) simplecov (0.22.0) @@ -102,6 +121,8 @@ GEM simplecov-html (0.12.3) simplecov_json_formatter (0.1.4) strscan (3.1.0) + tzinfo (2.0.6) + concurrent-ruby (~> 1.0) unicode-display_width (2.5.0) webmock (3.23.1) addressable (>= 2.8.0) @@ -124,6 +145,7 @@ PLATFORMS x86_64-linux-musl DEPENDENCIES + factory_bot logger omniai! rake diff --git a/README.md b/README.md index bebe5fd..728dfda 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,8 @@ Clients that support chat (e.g. Anthropic w/ "Claude", Google w/ "Gemini", Mistr Generating a completion is as simple as sending in the text: ```ruby -completion = client.chat('Tell me a joke.') -completion.choice.message.content # 'Why don't scientists trust atoms? They make up everything!' +response = client.chat('Tell me a joke.') +response.content # 'Why don't scientists trust atoms? They make up everything!' ``` #### Completions using a Complex Prompt @@ -136,7 +136,7 @@ completion.choice.message.content # 'Why don't scientists trust atoms? They make More complex completions are generated using a block w/ various system / user messages: ```ruby -completion = client.chat do |prompt| +response = client.chat do |prompt| prompt.system 'You are a helpful assistant with an expertise in animals.' prompt.user do |message| message.text 'What animals are in the attached photos?' @@ -145,7 +145,7 @@ completion = client.chat do |prompt| message.file('./hamster.jpeg', "image/jpeg") end end -completion.choice.message.content # 'They are photos of a cat, a cat, and a hamster.' +response.content # 'They are photos of a cat, a cat, and a hamster.' ``` #### Completions using Streaming via Proc @@ -154,7 +154,7 @@ A real-time stream of messages can be generated by passing in a proc: ```ruby stream = proc do |chunk| - print(chunk.choice.delta.content) # '...' + print(chunk.content) # '...' 
end client.chat('Tell me a joke.', stream:) ``` @@ -315,10 +315,3 @@ Type 'exit' or 'quit' to abort. 0.0 ... ``` - -0.0 -... - -``` - -``` diff --git a/lib/omniai/chat.rb b/lib/omniai/chat.rb index 4ab8b45..1caf035 100644 --- a/lib/omniai/chat.rb +++ b/lib/omniai/chat.rb @@ -87,6 +87,13 @@ def process! protected + # Override to provide a context for serializers / deserializers for a provider. + # + # @return [Context, nil] + def context + nil + end + # Used to spawn another chat with the same configuration using different messages. # # @param prompt [OmniAI::Chat::Prompt] @@ -114,7 +121,7 @@ def path end # @param response [HTTP::Response] - # @return [OmniAI::Chat::Response::Completion] + # @return [OmniAI::Chat::Response] def parse!(response:) if @stream stream!(response:) @@ -124,27 +131,28 @@ def parse!(response:) end # @param response [HTTP::Response] - # @return [OmniAI::Chat::Response::Completion] + # @return [OmniAI::Chat::Response] def complete!(response:) - completion = self.class::Response::Completion.new(data: response.parse) + completion = Response.new(data: response.parse, context:) if @tools && completion.tool_call_list.any? - spawn!([ - *@prompt.serialize, - *completion.choices.map(&:message).map(&:data), - *(completion.tool_call_list.map { |tool_call| execute_tool_call(tool_call) }), - ]) + spawn!( + @prompt.dup.tap do |prompt| + prompt.messages += completion.messages + prompt.messages += completion.tool_call_list.map { |tool_call| execute_tool_call(tool_call) } + end + ) else completion end end # @param response [HTTP::Response] - # @return [OmniAI::Chat::Response::Stream] + # @return [OmniAI::Chat::Stream] def stream!(response:) raise Error, "#{self.class.name}#stream! unstreamable" unless @stream - self.class::Response::Stream.new(response:).stream! do |chunk| + Stream.new(body: response.body, context:).stream! do |chunk| case @stream when IO, StringIO if chunk.content? @@ -167,24 +175,14 @@ def request! end # @param tool_call [OmniAI::Chat::ToolCall] + # @return [ToolMessage] def execute_tool_call(tool_call) function = tool_call.function tool = @tools.find { |entry| function.name == entry.name } || raise(ToolCallLookupError, tool_call) - result = tool.call(function.arguments) + content = tool.call(function.arguments) - prepare_tool_call_message(tool_call:, content: result) - end - - # @param tool_call [OmniAI::Chat::ToolCall] - # @param content [String] - def prepare_tool_call_message(tool_call:, content:) - { - role: Role::TOOL, - name: tool_call.function.name, - tool_call_id: tool_call.id, - content:, - } + ToolMessage.new(tool_call:, content:) end end end diff --git a/lib/omniai/chat/choice.rb b/lib/omniai/chat/choice.rb new file mode 100644 index 0000000..e817719 --- /dev/null +++ b/lib/omniai/chat/choice.rb @@ -0,0 +1,55 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # A choice that includes an index / message.
+ class Choice + # @return [Integer] + attr_accessor :index + + # @return [Message] + attr_accessor :message + + # @param message [Message] + def initialize(message:, index: 0) + @message = message + @index = index + end + + # @return [String] + def inspect + "#<#{self.class.name} index=#{@index} message=#{@message.inspect}>" + end + + # @param data [Hash] + # @param context [OmniAI::Context] optional + # + # @return [Choice] + def self.deserialize(data, context: nil) + deserialize = context&.deserializers&.[](:choice) + return deserialize.call(data, context:) if deserialize + + index = data['index'] + message = Message.deserialize(data['message'] || data['delta'], context:) + new(message:, index:) + end + + # @param context [OmniAI::Context] optional + # @return [Hash] + def serialize(context: nil) + serialize = context&.serializers&.[](:choice) + return serialize.call(self, context:) if serialize + + { + index:, + message: message.serialize(context:), + } + end + + # @return [Array, String] + def content + message.content + end + end + end +end diff --git a/lib/omniai/chat/content.rb b/lib/omniai/chat/content.rb index b24b145..373cf04 100644 --- a/lib/omniai/chat/content.rb +++ b/lib/omniai/chat/content.rb @@ -16,20 +16,24 @@ def self.summarize(content) # # @return [String] def serialize(context: nil) - raise NotImplementedError, ' # {self.class}#serialize undefined' + raise NotImplementedError, "#{self.class}#serialize undefined" end - # @param data [hash] + # @param data [Hash, Array, String] # @param context [Context] optional # # @return [Content] def self.deserialize(data, context: nil) - raise ArgumentError, "untyped data=#{data.inspect}" unless data.key?('type') + return data.map { |data| deserialize(data, context:) } if data.is_a?(Array) + + deserialize = context&.deserializers&.[](:content) + return deserialize.call(data, context:) if deserialize + + return data if data.is_a?(String) case data['type'] when 'text' then Text.deserialize(data, context:) when /(.*)_url/ then URL.deserialize(data, context:) - else raise ArgumentError, "unknown type=#{data['type'].inspect}" end end end diff --git a/lib/omniai/chat/function.rb b/lib/omniai/chat/function.rb new file mode 100644 index 0000000..c9090cb --- /dev/null +++ b/lib/omniai/chat/function.rb @@ -0,0 +1,53 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # A function that includes a name / arguments. 
+ class Function + # @return [String] + attr_accessor :name + + # @return [Hash] + attr_accessor :arguments + + # @param name [String] + # @param arguments [Hash] + def initialize(name:, arguments:) + @name = name + @arguments = arguments + end + + # @return [String] + def inspect + "#<#{self.class.name} name=#{name.inspect} arguments=#{arguments.inspect}>" + end + + # @param data [Hash] + # @param context [Context] optional + # + # @return [Function] + def self.deserialize(data, context: nil) + deserialize = context&.deserializers&.[](:function) + return deserialize.call(data, context:) if deserialize + + name = data['name'] + arguments = JSON.parse(data['arguments']) if data['arguments'] + + new(name:, arguments:) + end + + # @param context [Context] optional + # + # @return [Hash] + def serialize(context: nil) + serializer = context&.serializers&.[](:function) + return serializer.call(self, context:) if serializer + + { + name: @name, + arguments: (JSON.generate(@arguments) if @arguments), + } + end + end + end +end diff --git a/lib/omniai/chat/message.rb b/lib/omniai/chat/message.rb index ab061ea..e5881d5 100644 --- a/lib/omniai/chat/message.rb +++ b/lib/omniai/chat/message.rb @@ -10,6 +10,7 @@ class Chat # message.url 'https://example.com/cat.jpg', type: "image/jpeg" # message.url 'https://example.com/dog.jpg', type: "image/jpeg" # message.file File.open('hamster.jpg'), type: "image/jpeg" + # message. # end # end class Message @@ -21,6 +22,7 @@ class Message # @param content [String, nil] # @param role [String] + # @param tool_call_collection [ToolCallBatch, nil] def initialize(content: nil, role: Role::USER) @content = content || [] @role = role @@ -51,10 +53,10 @@ def self.deserialize(data, context: nil) deserialize = context&.deserializers&.[](:message) return deserialize.call(data, context:) if deserialize - new( - content: data['content'].map { |content| Content.deserialize(content, context:) }, - role: data['role'] - ) + role = data['role'] + content = Content.deserialize(data['content'], context:) + + new(content:, role:) end # Usage: @@ -89,6 +91,11 @@ def user? role?(Role::USER) end + # @return [Boolean] + def content? + !@content.nil? + end + # Usage: # # message.text('What are these photos of?') diff --git a/lib/omniai/chat/payload.rb b/lib/omniai/chat/payload.rb new file mode 100644 index 0000000..6326e78 --- /dev/null +++ b/lib/omniai/chat/payload.rb @@ -0,0 +1,82 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # A chunk or completion. 
+ class Payload + # @return [Array] + attr_accessor :choices + + # @return [Usage, nil] + attr_accessor :usage + + # @param choices [Array] + # @param usage [Usage, nil] + def initialize(choices:, usage: nil) + @choices = choices + @usage = usage + end + + # @return [String] + def inspect + "#<#{self.class.name} choices=#{choices.inspect} usage=#{usage.inspect}>" + end + + # @param data [Hash] + # @param context [OmniAI::Context] optional + def self.deserialize(data, context: nil) + deserialize = context&.deserializers&.[](:payload) + return deserialize.call(data, context:) if deserialize + + choices = data['choices'].map { |choice_data| Choice.deserialize(choice_data, context:) } + usage = Usage.deserialize(data['usage'], context:) if data['usage'] + + new(choices:, usage:) + end + + # @param context [OmniAI::Context] optional + # @return [Hash] + def serialize(context:) + serialize = context&.serializers&.[](:payload) + return serialize.call(self, context:) if serialize + + { + choices: choices.map { |choice| choice.serialize(context:) }, + usage: usage&.serialize(context:), + } + end + + # @param index [Integer] + # @return [Choice] + def choice(index: 0) + @choices[index] + end + + # @param index [Integer] + # @return [Message] + def message(index: 0) + choice(index:).message + end + + # @return [Array] + def messages + @choices.map(&:message) + end + + # @return [String, nil] + def content(index: 0) + message(index:).content + end + + # @return [Boolean] + def content?(index: 0) + message(index:).content? + end + + # @return [Array] + def tool_call_list + choice.tool_call_list + end + end + end +end diff --git a/lib/omniai/chat/prompt.rb b/lib/omniai/chat/prompt.rb index 7d32752..606de60 100644 --- a/lib/omniai/chat/prompt.rb +++ b/lib/omniai/chat/prompt.rb @@ -57,6 +57,11 @@ def initialize(messages: []) @messages = messages end + # @return [Prompt] + def dup + self.class.new(messages: @messages.dup) + end + # @return [String] def inspect "#<#{self.class.name} messages=#{@messages.inspect}>" diff --git a/lib/omniai/chat/response.rb b/lib/omniai/chat/response.rb new file mode 100644 index 0000000..038e5ec --- /dev/null +++ b/lib/omniai/chat/response.rb @@ -0,0 +1,55 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # Used when processing everything at once. + class Response + # @return [Hash] + attr_accessor :data + + # @param data [Hash] + # @param context [Context, nil] + def initialize(data:, context: nil) + @data = data + @context = context + end + + # @return [Payload] + def completion + @completion ||= Payload.deserialize(@data, context: @context) + end + + # @return [Usage, nil] + def usage + completion.usage + end + + # @return [Array] + def choices + completion.choices + end + + # @return [Array] + def messages + completion.messages + end + + # @param index [Integer] + # @return [Choice] + def choice(index: 0) + completion.choice(index:) + end + + # @param index [Integer] + # @return [Message] + def message(index: 0) + completion.message(index:) + end + + # @return [String] + def content + choice.content + end + end + end +end diff --git a/lib/omniai/chat/response/choice.rb b/lib/omniai/chat/response/choice.rb deleted file mode 100644 index 62b6395..0000000 --- a/lib/omniai/chat/response/choice.rb +++ /dev/null @@ -1,35 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # For use with MessageChoice or DeltaChoice. 
- class Choice < Resource - # @return [Integer] - def index - @data['index'] - end - - # @return [Part] - def part - raise NotImplementedError, "#{self.class.name}#part undefined" - end - - # @return [ToolCallList] - def tool_call_list - part.tool_call_list - end - - # @return [String, nil] - def content - part.content - end - - # @return [Boolean] - def content? - !content.nil? - end - end - end - end -end diff --git a/lib/omniai/chat/response/chunk.rb b/lib/omniai/chat/response/chunk.rb deleted file mode 100644 index 8dc6e1b..0000000 --- a/lib/omniai/chat/response/chunk.rb +++ /dev/null @@ -1,15 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A chunk returned by the API. - class Chunk < Payload - # @return [Array] - def choices - @choices ||= @data['choices'].map { |data| DeltaChoice.new(data:) } - end - end - end - end -end diff --git a/lib/omniai/chat/response/completion.rb b/lib/omniai/chat/response/completion.rb deleted file mode 100644 index e83b848..0000000 --- a/lib/omniai/chat/response/completion.rb +++ /dev/null @@ -1,15 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A completion returned by the API. - class Completion < Payload - # @return [Array] - def choices - @choices ||= @data['choices'].map { |data| MessageChoice.new(data:) } - end - end - end - end -end diff --git a/lib/omniai/chat/response/delta.rb b/lib/omniai/chat/response/delta.rb deleted file mode 100644 index f22757a..0000000 --- a/lib/omniai/chat/response/delta.rb +++ /dev/null @@ -1,11 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A delta returned by the API. - class Delta < Part - end - end - end -end diff --git a/lib/omniai/chat/response/delta_choice.rb b/lib/omniai/chat/response/delta_choice.rb deleted file mode 100644 index 2e93d9e..0000000 --- a/lib/omniai/chat/response/delta_choice.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A delta choice returned by the API. - class DeltaChoice < Choice - # @return [String] - def inspect - "#<#{self.class.name} index=#{index} delta=#{delta.inspect}>" - end - - # @return [Delta] - def delta - @delta ||= Delta.new(data: @data['delta']) - end - - # @return [Delta] - def part - delta - end - end - end - end -end diff --git a/lib/omniai/chat/response/function.rb b/lib/omniai/chat/response/function.rb deleted file mode 100644 index 3357724..0000000 --- a/lib/omniai/chat/response/function.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A function returned by the API. - class Function < Resource - # @return [String] - def inspect - "#<#{self.class.name} name=#{name.inspect} arguments=#{arguments.inspect}>" - end - - # @return [String] - def name - @data['name'] - end - - # @return [Hash, nil] - def arguments - JSON.parse(@data['arguments']) if @data['arguments'] - end - end - end - end -end diff --git a/lib/omniai/chat/response/message.rb b/lib/omniai/chat/response/message.rb deleted file mode 100644 index 0d07bc2..0000000 --- a/lib/omniai/chat/response/message.rb +++ /dev/null @@ -1,11 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A message returned by the API. 
- class Message < Part - end - end - end -end diff --git a/lib/omniai/chat/response/message_choice.rb b/lib/omniai/chat/response/message_choice.rb deleted file mode 100644 index f0e2d22..0000000 --- a/lib/omniai/chat/response/message_choice.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A choice returned by the API. - class MessageChoice < Choice - # @return [String] - def inspect - "#<#{self.class.name} index=#{index} message=#{message.inspect}>" - end - - # @return [Message] - def message - @message ||= Message.new(data: @data['message']) - end - - # @return [Message] - def part - message - end - end - end - end -end diff --git a/lib/omniai/chat/response/part.rb b/lib/omniai/chat/response/part.rb deleted file mode 100644 index d790e25..0000000 --- a/lib/omniai/chat/response/part.rb +++ /dev/null @@ -1,38 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # Either a delta or message. - class Part < Resource - # @return [String] - def inspect - "#<#{self.class.name} role=#{role.inspect} content=#{content.inspect}>" - end - - # @return [String] - def role - @data['role'] || Role::USER - end - - # @return [String, nil] - def content - @data['content'] - end - - # @return [Array] - def tool_call_list - return [] unless @data['tool_calls'] - - @tool_call_list ||= @data['tool_calls'].map { |tool_call_data| ToolCall.new(data: tool_call_data) } - end - - # @param index [Integer] - # @return [ToolCall, nil] - def tool_call(index: 0) - tool_call_list[index] - end - end - end - end -end diff --git a/lib/omniai/chat/response/payload.rb b/lib/omniai/chat/response/payload.rb deleted file mode 100644 index 83f6081..0000000 --- a/lib/omniai/chat/response/payload.rb +++ /dev/null @@ -1,72 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A chunk or completion. - class Payload < Resource - # @return [String] - def inspect - "#<#{self.class.name} id=#{id.inspect} choices=#{choices.inspect}>" - end - - # @return [String] - def id - @data['id'] - end - - # @return [Time] - def created - Time.at(@data['created']) if @data['created'] - end - - # @return [Time] - def updated - Time.at(@data['updated']) if @data['updated'] - end - - # @return [String] - def model - @data['model'] - end - - # @return [Array] - def choices - raise NotImplementedError, "#{self.class.name}#choices undefined" - end - - # @param index [Integer] - # @return [DeltaChoice] - def choice(index: 0) - choices[index] - end - - # @param index [Integer] - # @return [Part] - def part(index: 0) - choice(index:).part - end - - # @return [Usage] - def usage - @usage ||= Usage.new(data: @data['usage']) if @data['usage'] - end - - # @return [String, nil] - def content - choice.content - end - - # @return [Boolean] - def content? - choice.content? - end - - # @return [Array] - def tool_call_list - choice.tool_call_list - end - end - end - end -end diff --git a/lib/omniai/chat/response/resource.rb b/lib/omniai/chat/response/resource.rb deleted file mode 100644 index 5ce9e61..0000000 --- a/lib/omniai/chat/response/resource.rb +++ /dev/null @@ -1,22 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A generic data to handle response. 
- class Resource - attr_accessor :data - - # @param data [Hash] - def initialize(data:) - @data = data - end - - # @return [String] - def inspect - "#<#{self.class.name} data=#{@data.inspect}>" - end - end - end - end -end diff --git a/lib/omniai/chat/response/stream.rb b/lib/omniai/chat/response/stream.rb deleted file mode 100644 index 7514fc3..0000000 --- a/lib/omniai/chat/response/stream.rb +++ /dev/null @@ -1,27 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A stream given when streaming. - class Stream - # @param response [HTTP::Response] - def initialize(response:) - @response = response - @parser = EventStreamParser::Parser.new - end - - # @yield [OmniAI::Chat::Chunk] - def stream! - @response.body.each do |chunk| - @parser.feed(chunk) do |_, data| - next if data.eql?('[DONE]') - - yield(Chunk.new(data: JSON.parse(data))) - end - end - end - end - end - end -end diff --git a/lib/omniai/chat/response/tool_call.rb b/lib/omniai/chat/response/tool_call.rb deleted file mode 100644 index 85074c5..0000000 --- a/lib/omniai/chat/response/tool_call.rb +++ /dev/null @@ -1,30 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A tool-call returned by the API. - class ToolCall < Resource - # @return [String] - def inspect - "#<#{self.class.name} id=#{id.inspect} type=#{type.inspect}>" - end - - # @return [String] - def id - @data['id'] - end - - # @return [String] - def type - @data['type'] - end - - # @return [Function] - def function - @function ||= Function.new(data: @data['function']) if @data['function'] - end - end - end - end -end diff --git a/lib/omniai/chat/response/usage.rb b/lib/omniai/chat/response/usage.rb deleted file mode 100644 index 6bcfaff..0000000 --- a/lib/omniai/chat/response/usage.rb +++ /dev/null @@ -1,35 +0,0 @@ -# frozen_string_literal: true - -module OmniAI - class Chat - module Response - # A usage returned by the API. - class Usage < Resource - # @return [String] - def inspect - properties = [ - "input_tokens=#{input_tokens}", - "output_tokens=#{output_tokens}", - "total_tokens=#{total_tokens}", - ] - "#<#{self.class.name} #{properties.join(' ')}>" - end - - # @return [Integer] - def input_tokens - @data['input_tokens'] || @data['prompt_tokens'] - end - - # @return [Integer] - def output_tokens - @data['output_tokens'] || @data['completion_tokens'] - end - - # @return [Integer] - def total_tokens - @data['total_tokens'] || (input_tokens + output_tokens) - end - end - end - end -end diff --git a/lib/omniai/chat/stream.rb b/lib/omniai/chat/stream.rb new file mode 100644 index 0000000..d2518dc --- /dev/null +++ b/lib/omniai/chat/stream.rb @@ -0,0 +1,33 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # Used when streaming to process chunks of data. + class Stream + # @param body [HTTP::Response::Body] + # @param context [Context, nil] + def initialize(body:, context: nil) + @body = body + @context = context + end + + # @yield [OmniAI::Chat::Chunk] + def stream! 
+ @body.each do |chunk| + parser.feed(chunk) do |_, data| + next if data.eql?('[DONE]') + + yield(Payload.deserialize(JSON.parse(data), context: @context)) + end + end + end + + protected + + # @return [EventStreamParser::Parser] + def parser + @parser ||= EventStreamParser::Parser.new + end + end + end +end diff --git a/lib/omniai/chat/tool_call.rb b/lib/omniai/chat/tool_call.rb new file mode 100644 index 0000000..8edf2c4 --- /dev/null +++ b/lib/omniai/chat/tool_call.rb @@ -0,0 +1,54 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # A tool-call that includes an ID / function. + class ToolCall + # @return [String] + attr_accessor :id + + # @return [Function] + attr_accessor :function + + # @param id [String] + # @param function [Function] + def initialize(id:, function:) + @id = id + @function = function + end + + # @return [String] + def inspect + "#<#{self.class.name} id=#{id.inspect} function=#{function.inspect}>" + end + + # @param data [Hash] + # @param context [Context] optional + # + # @return [Function] + def self.deserialize(data, context: nil) + deserialize = context&.deserializers&.[](:tool_call) + return deserialize.call(data, context:) if deserialize + + id = data['id'] + function = Function.deserialize(data['function'], context:) + + new(id:, function:) + end + + # @param context [Context] optional + # + # @return [Hash] + def serialize(context: nil) + serializer = context&.serializers&.[](:tool_call) + return serializer.call(self, context:) if serializer + + { + id: @id, + type: 'function', + function: @function.serialize(context:), + } + end + end + end +end diff --git a/lib/omniai/chat/tool_call_batch.rb b/lib/omniai/chat/tool_call_batch.rb new file mode 100644 index 0000000..0381a92 --- /dev/null +++ b/lib/omniai/chat/tool_call_batch.rb @@ -0,0 +1,15 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # A set of tool calls to be executed. + class ToolCallBatch + attr_accessor :entries + + # @return [Array] + def initialize(entries) + @entries = entries + end + end + end +end diff --git a/lib/omniai/chat/tool_message.rb b/lib/omniai/chat/tool_message.rb new file mode 100644 index 0000000..6c71fbc --- /dev/null +++ b/lib/omniai/chat/tool_message.rb @@ -0,0 +1,59 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # A specific message that contains the result of a tool call. + class ToolMessage + # @return [String] + attr_accessor :content + + # @return [ToolCall] + attr_accessor :tool_call + + # @param content [String] + # @param tool_call [ToolCall] + def initialize(content:, tool_call:) + @content = content + @tool_call = tool_call + end + + # @return [String] + def inspect + "#<#{self.class.name} content=#{content.inspect} tool_call=#{tool_call.inspect}>" + end + + # Usage: + # + # ToolCall.deserialize({ 'role' => 'tool', content: '{ 'temperature': 0 }' }) # => # + # + # @param data [Hash] + # @param context [Context] optional + # + # @return [ToolMessage] + def self.deserialize(data, context: nil) + deserialize = context&.deserializers&.[](:tool_message) + return deserialize.call(data, context:) if deserialize + + content = JSON.parse(data['content']) + tool_call = ToolCall.deserialize(data['tool_call'], context:) if data['tool_call'] + + new(content:, tool_call:) + end + + # Usage: + # + # message.serialize # => { role: :user, content: 'Hello!' } + # message.serialize # => { role: :user, content: [{ type: 'text', text: 'Hello!' 
}] } + # + # @param context [Context] optional + # + # @return [Hash] + def serialize(context: nil) + serializer = context&.serializers&.[](:tool_message) + return serializer.call(self, context:) if serializer + + { role: 'tool', content: JSON.generate(@content), tool_call_id: @tool_call&.id } + end + end + end +end diff --git a/lib/omniai/chat/usage.rb b/lib/omniai/chat/usage.rb new file mode 100644 index 0000000..8304945 --- /dev/null +++ b/lib/omniai/chat/usage.rb @@ -0,0 +1,60 @@ +# frozen_string_literal: true + +module OmniAI + class Chat + # The usage of a chat in terms of tokens (input / output / total). + class Usage + # @return [Integer] + attr_accessor :input_tokens + + # @return [Integer] + attr_accessor :output_tokens + + # @return [Integer] + attr_accessor :total_tokens + + # @param input_tokens [Integer] + # @param output_tokens [Integer] + # @param total_tokens [Integer] + def initialize(input_tokens:, output_tokens:, total_tokens:) + @input_tokens = input_tokens + @output_tokens = output_tokens + @total_tokens = total_tokens + end + + # @return [String] + def inspect + "#<#{self.class.name} input_tokens=#{input_tokens} output_tokens=#{output_tokens} total_tokens=#{total_tokens}>" + end + + # @param data [Hash] + # @param context [OmniAI::Context] optional + # + # @return [OmniAI::Chat::Usage] + def self.deserialize(data, context: nil) + deserialize = context&.deserializers&.[](:usage) + return deserialize.call(data, context:) if deserialize + + input_tokens = data['input_tokens'] + output_tokens = data['output_tokens'] + total_tokens = data['total_tokens'] + + new(input_tokens:, output_tokens:, total_tokens:) + end + + # @param context [OmniAI::Context] optional + # + # @return [Hash] + def serialize(context: nil) + serialize = context&.serializers&.[](:usage) + return serialize.call(self, context:) if serialize + + { + input_tokens:, + output_tokens:, + total_tokens:, + } + end + end + end +end diff --git a/lib/omniai/version.rb b/lib/omniai/version.rb index d0d4359..504d78d 100644 --- a/lib/omniai/version.rb +++ b/lib/omniai/version.rb @@ -1,5 +1,5 @@ # frozen_string_literal: true module OmniAI - VERSION = '1.7.0' + VERSION = '1.8.0' end diff --git a/spec/factories/chat/choice.rb b/spec/factories/chat/choice.rb new file mode 100644 index 0000000..dae8490 --- /dev/null +++ b/spec/factories/chat/choice.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_choice, class: 'OmniAI::Chat::Choice' do + initialize_with { new(**attributes) } + + message factory: :chat_message + index { 0 } + end +end diff --git a/spec/factories/chat/content.rb b/spec/factories/chat/content.rb new file mode 100644 index 0000000..62497ac --- /dev/null +++ b/spec/factories/chat/content.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_content, class: 'OmniAI::Chat::Content' do + initialize_with { new(**attributes) } + end +end diff --git a/spec/factories/chat/file.rb b/spec/factories/chat/file.rb new file mode 100644 index 0000000..36a2a4a --- /dev/null +++ b/spec/factories/chat/file.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_file, class: 'OmniAI::Chat::File' do + initialize_with { new(io, type) } + + io { StringIO.new('...') } + type { 'image/png' } + end +end diff --git a/spec/factories/chat/function.rb b/spec/factories/chat/function.rb new file mode 100644 index 0000000..6b75938 --- /dev/null +++ b/spec/factories/chat/function.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: 
true + +FactoryBot.define do + factory :chat_function, class: 'OmniAI::Chat::Function' do + initialize_with { new(**attributes) } + + name { 'temperature' } + arguments { { 'unit' => 'celsius' } } + end +end diff --git a/spec/factories/chat/media.rb b/spec/factories/chat/media.rb new file mode 100644 index 0000000..01c4d74 --- /dev/null +++ b/spec/factories/chat/media.rb @@ -0,0 +1,9 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_media, class: 'OmniAI::Chat::Media' do + initialize_with { new(type) } + + type { 'image/png' } + end +end diff --git a/spec/factories/chat/message.rb b/spec/factories/chat/message.rb new file mode 100644 index 0000000..2ab37bb --- /dev/null +++ b/spec/factories/chat/message.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_message, class: 'OmniAI::Chat::Message' do + initialize_with { new(**attributes) } + + role { 'user' } + content { [] } + end +end diff --git a/spec/factories/chat/prompt.rb b/spec/factories/chat/prompt.rb new file mode 100644 index 0000000..c6232c9 --- /dev/null +++ b/spec/factories/chat/prompt.rb @@ -0,0 +1,9 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_prompt, class: 'OmniAI::Chat::Prompt' do + initialize_with { new(**attributes) } + + messages { [] } + end +end diff --git a/spec/factories/chat/text.rb b/spec/factories/chat/text.rb new file mode 100644 index 0000000..1b4ba78 --- /dev/null +++ b/spec/factories/chat/text.rb @@ -0,0 +1,9 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_text, class: 'OmniAI::Chat::Text' do + initialize_with { new(text) } + + text { 'Hello!' } + end +end diff --git a/spec/factories/chat/tool_call.rb b/spec/factories/chat/tool_call.rb new file mode 100644 index 0000000..94a6324 --- /dev/null +++ b/spec/factories/chat/tool_call.rb @@ -0,0 +1,9 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_tool_call, class: 'OmniAI::Chat::ToolCall' do + initialize_with { new(**attributes) } + sequence(:id) + function factory: :chat_function + end +end diff --git a/spec/factories/chat/url.rb b/spec/factories/chat/url.rb new file mode 100644 index 0000000..d208f38 --- /dev/null +++ b/spec/factories/chat/url.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_url, class: 'OmniAI::Chat::URL' do + initialize_with { new(uri, type) } + + type { 'image/png' } + uri { 'https://localhost/hamster.png' } + end +end diff --git a/spec/factories/chat/usage.rb b/spec/factories/chat/usage.rb new file mode 100644 index 0000000..7c058dd --- /dev/null +++ b/spec/factories/chat/usage.rb @@ -0,0 +1,11 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :chat_usage, class: 'OmniAI::Chat::Usage' do + initialize_with { new(**attributes) } + + input_tokens { 2 } + output_tokens { 3 } + total_tokens { input_tokens + output_tokens } + end +end diff --git a/spec/factories/embed/usage.rb b/spec/factories/embed/usage.rb new file mode 100644 index 0000000..372628a --- /dev/null +++ b/spec/factories/embed/usage.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +FactoryBot.define do + factory :embed_usage, class: 'OmniAI::Embed::Usage' do + initialize_with { new(**attributes) } + + prompt_tokens { 2 } + total_tokens { 4 } + end +end diff --git a/spec/omniai/chat/choice_spec.rb b/spec/omniai/chat/choice_spec.rb new file mode 100644 index 0000000..eb08232 --- /dev/null +++ b/spec/omniai/chat/choice_spec.rb @@ -0,0 +1,70 @@ +# frozen_string_literal: true + 
+RSpec.describe OmniAI::Chat::Choice do + subject(:choice) { build(:chat_choice, message:, index: 0) } + + let(:message) { build(:chat_message, role: 'user', content: 'Hello!') } + + describe '#inspect' do + it { expect(choice.inspect).to eq(%(#<OmniAI::Chat::Choice index=0 message=#{message.inspect}>)) } + end + + describe '#index' do + it { expect(choice.index).to eq(0) } + end + + describe '#message' do + it { expect(choice.message).to eql(message) } + end + + describe '.deserialize' do + subject(:deserialize) { described_class.deserialize(data, context:) } + + let(:data) { { 'index' => 0, 'message' => { 'role' => 'user', 'content' => 'Hello!' } } } + + context 'with a deserializer' do + let(:context) do + OmniAI::Context.build do |context| + context.deserializers[:choice] = lambda { |data, *| + index = data['index'] + message = OmniAI::Chat::Message.deserialize(data['message'], context:) + described_class.new(message:, index:) + } + end + end + + it { expect(deserialize).to be_a(described_class) } + end + + context 'without a serializer' do + let(:context) { OmniAI::Context.build } + + it { expect(deserialize).to be_a(described_class) } + end + end + + describe '#serialize' do + subject(:serialize) { choice.serialize(context:) } + + context 'with a serializer' do + let(:context) do + OmniAI::Context.build do |context| + context.serializers[:choice] = lambda do |choice, *| + { + index: choice.index, + message: choice.message.serialize(context:), + } + end + end + end + + it { expect(serialize).to eq({ index: 0, message: { role: 'user', content: 'Hello!' } }) } + end + + context 'without a serializer' do + let(:context) { OmniAI::Context.build } + + it { expect(serialize).to eq({ index: 0, message: { role: 'user', content: 'Hello!' } }) } + end + end +end diff --git a/spec/omniai/chat/content_spec.rb b/spec/omniai/chat/content_spec.rb index 6962689..4cc1dcf 100644 --- a/spec/omniai/chat/content_spec.rb +++ b/spec/omniai/chat/content_spec.rb @@ -1,7 +1,7 @@ # frozen_string_literal: true RSpec.describe OmniAI::Chat::Content do - subject(:content) { described_class.new } + subject(:content) { build(:chat_content) } describe '.summarize' do subject(:summarize) { described_class.summarize(content) } diff --git a/spec/omniai/chat/file_spec.rb b/spec/omniai/chat/file_spec.rb index 8422ec4..2d750f0 100644 --- a/spec/omniai/chat/file_spec.rb +++ b/spec/omniai/chat/file_spec.rb @@ -1,7 +1,7 @@ # frozen_string_literal: true RSpec.describe OmniAI::Chat::File do - subject(:file) { described_class.new(io, type) } + subject(:file) { build(:chat_file, io:, type:) } let(:io) do Tempfile.new.tap do |tempfile| diff --git a/spec/omniai/chat/function_spec.rb b/spec/omniai/chat/function_spec.rb new file mode 100644 index 0000000..7ba1b00 --- /dev/null +++ b/spec/omniai/chat/function_spec.rb @@ -0,0 +1,72 @@ +# frozen_string_literal: true + +RSpec.describe OmniAI::Chat::Function do + subject(:function) { build(:chat_function, name:, arguments:) } + + let(:name) { 'temperature' } + let(:arguments) { { 'unit' => 'celsius' } } + + it { expect(function.name).to eq('temperature') } + it { expect(function.arguments).to eq({ 'unit' => 'celsius' }) } + + describe '#inspect' do + subject(:inspect) { function.inspect } + + it { is_expected.to eq '#<OmniAI::Chat::Function name="temperature" arguments={"unit"=>"celsius"}>' } + end + + describe '.deserialize' do + subject(:deserialize) { described_class.deserialize(data, context:) } + + let(:data) { { 'name' => 'temperature', 'arguments' => '{"unit": "celsius"}' } } + + context 'with a deserializer' do + let(:context) do + OmniAI::Context.build do |context| +
context.deserializers[:function] = lambda { |data, *| + name = data['name'] + arguments = JSON.parse(data['arguments']) + described_class.new(name:, arguments:) + } + end + end + + it { expect(deserialize).to be_a(described_class) } + it { expect(deserialize.name).to eq('temperature') } + it { expect(deserialize.arguments).to eq({ 'unit' => 'celsius' }) } + end + + context 'without a serializer' do + let(:context) { OmniAI::Context.build } + + it { expect(deserialize).to be_a(described_class) } + it { expect(deserialize.name).to eq('temperature') } + it { expect(deserialize.arguments).to eq({ 'unit' => 'celsius' }) } + end + end + + describe '#serialize' do + subject(:serialize) { function.serialize(context:) } + + context 'with a serializer' do + let(:context) do + OmniAI::Context.build do |context| + context.serializers[:function] = lambda do |function, *| + { + name: function.name, + arguments: JSON.generate(function.arguments), + } + end + end + end + + it { expect(serialize).to eq(name: 'temperature', arguments: '{"unit":"celsius"}') } + end + + context 'without a serializer' do + let(:context) { OmniAI::Context.build } + + it { expect(serialize).to eq(name: 'temperature', arguments: '{"unit":"celsius"}') } + end + end +end diff --git a/spec/omniai/chat/media_spec.rb b/spec/omniai/chat/media_spec.rb index fe854c4..70bccd2 100644 --- a/spec/omniai/chat/media_spec.rb +++ b/spec/omniai/chat/media_spec.rb @@ -1,7 +1,7 @@ # frozen_string_literal: true RSpec.describe OmniAI::Chat::Media do - subject(:media) { described_class.new(type) } + subject(:media) { build(:chat_media, type:) } let(:type) { 'text/plain' } diff --git a/spec/omniai/chat/message_spec.rb b/spec/omniai/chat/message_spec.rb index d028263..9618142 100644 --- a/spec/omniai/chat/message_spec.rb +++ b/spec/omniai/chat/message_spec.rb @@ -1,7 +1,7 @@ # frozen_string_literal: true RSpec.describe OmniAI::Chat::Message do - subject(:message) { described_class.new(role:, content:) } + subject(:message) { build(:chat_message, role:, content:) } let(:role) { OmniAI::Chat::Role::USER } let(:content) { [] } diff --git a/spec/omniai/chat/payload_spec.rb b/spec/omniai/chat/payload_spec.rb new file mode 100644 index 0000000..1342f23 --- /dev/null +++ b/spec/omniai/chat/payload_spec.rb @@ -0,0 +1,72 @@ +# frozen_string_literal: true + +RSpec.describe OmniAI::Chat::Payload do + subject(:payload) { described_class.new(choices:, usage:) } + + let(:usage) { build(:chat_usage) } + let(:choices) { [] } + + describe '#inspect' do + it { expect(payload.inspect).to eql("#<OmniAI::Chat::Payload choices=[] usage=#{usage.inspect}>") } + end + + describe '.deserialize' do + subject(:deserialize) { described_class.deserialize(data, context:) } + + let(:data) do + { + 'choices' => [], + 'usage' => { + 'input_tokens' => 0, + 'output_tokens' => 0, + 'total_tokens' => 0, + }, + } + end + + context 'with a deserializer' do + let(:context) do + OmniAI::Context.build do |context| + context.deserializers[:payload] = lambda { |data, *| + choices = data['choices'].map { |choice| OmniAI::Chat::Choice.deserialize(choice, context:) } + usage = OmniAI::Chat::Usage.deserialize(data['usage'], context:) + described_class.new(choices:, usage:) + } + end + end + + it { expect(deserialize).to be_a(described_class) } + end + + context 'without a serializer' do + let(:context) { OmniAI::Context.build } + + it { expect(deserialize).to be_a(described_class) } + end + end + + describe '#serialize' do + subject(:serialize) { payload.serialize(context:) } + + context 'with a serializer' do + let(:context) do + OmniAI::Context.build do |context| +
context.serializers[:payload] = lambda do |payload, *| + { + choices: payload.choices.map { |choice| choice.serialize(context:) }, + usage: payload.usage&.serialize(context:), + } + end + end + end + + it { expect(serialize).to eq(choices: [], usage: { input_tokens: 2, output_tokens: 3, total_tokens: 5 }) } + end + + context 'without a serializer' do + let(:context) { OmniAI::Context.build } + + it { expect(serialize).to eq(choices: [], usage: { input_tokens: 2, output_tokens: 3, total_tokens: 5 }) } + end + end +end diff --git a/spec/omniai/chat/prompt_spec.rb b/spec/omniai/chat/prompt_spec.rb index cbc2505..271d6b3 100644 --- a/spec/omniai/chat/prompt_spec.rb +++ b/spec/omniai/chat/prompt_spec.rb @@ -1,7 +1,7 @@ # frozen_string_literal: true RSpec.describe OmniAI::Chat::Prompt do - subject(:prompt) { described_class.new(messages:) } + subject(:prompt) { build(:chat_prompt, messages:) } let(:messages) { [] } diff --git a/spec/omniai/chat/response/choice_spec.rb b/spec/omniai/chat/response/choice_spec.rb deleted file mode 100644 index 4c7c072..0000000 --- a/spec/omniai/chat/response/choice_spec.rb +++ /dev/null @@ -1,19 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe OmniAI::Chat::Response::Choice do - subject(:choice) { described_class.new(data:) } - - let(:data) { { 'index' => 0 } } - - describe '#index' do - it { expect(choice.index).to eq(0) } - end - - describe '#part' do - it { expect { choice.part }.to raise_error(NotImplementedError) } - end - - describe '#tool_call_list' do - it { expect { choice.tool_call_list }.to raise_error(NotImplementedError) } - end -end diff --git a/spec/omniai/chat/response/chunk_spec.rb b/spec/omniai/chat/response/chunk_spec.rb deleted file mode 100644 index 1554ca9..0000000 --- a/spec/omniai/chat/response/chunk_spec.rb +++ /dev/null @@ -1,39 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe OmniAI::Chat::Response::Chunk do - subject(:chunk) { described_class.new(data:) } - - let(:data) do - { - 'id' => 'fake_id', - 'model' => 'fake_model', - 'created' => 0, - 'updated' => 0, - 'choices' => [], - } - end - - describe '#id' do - it { expect(chunk.id).to eq('fake_id') } - end - - describe '#model' do - it { expect(chunk.model).to eq('fake_model') } - end - - describe '#created' do - it { expect(chunk.created).to be_a(Time) } - end - - describe '#updated' do - it { expect(chunk.updated).to be_a(Time) } - end - - describe '#choices' do - it { expect(chunk.choices).to be_empty } - end - - describe '#inspect' do - it { expect(chunk.inspect).to eq('#') } - end -end diff --git a/spec/omniai/chat/response/completion_spec.rb b/spec/omniai/chat/response/completion_spec.rb deleted file mode 100644 index 3d9966f..0000000 --- a/spec/omniai/chat/response/completion_spec.rb +++ /dev/null @@ -1,48 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe OmniAI::Chat::Response::Completion do - subject(:completion) { described_class.new(data:) } - - let(:data) do - { - 'id' => 'fake_id', - 'model' => 'fake_model', - 'created' => 0, - 'updated' => 0, - 'choices' => [], - 'usage' => { - 'input_tokens' => 0, - 'output_tokens' => 0, - 'total_tokens' => 0, - }, - } - end - - describe '#id' do - it { expect(completion.id).to eq('fake_id') } - end - - describe '#model' do - it { expect(completion.model).to eq('fake_model') } - end - - describe '#created' do - it { expect(completion.created).to be_a(Time) } - end - - describe '#updated' do - it { expect(completion.updated).to be_a(Time) } - end - - describe '#usage' do - it { expect(completion.usage).to 
be_a(OmniAI::Chat::Response::Usage) }
-  end
-
-  describe '#choices' do
-    it { expect(completion.choices).to be_empty }
-  end
-
-  describe '#inspect' do
-    it { expect(completion.inspect).to eql('#') }
-  end
-end
diff --git a/spec/omniai/chat/response/delta_choice_spec.rb b/spec/omniai/chat/response/delta_choice_spec.rb
deleted file mode 100644
index fddb0ed..0000000
--- a/spec/omniai/chat/response/delta_choice_spec.rb
+++ /dev/null
@@ -1,22 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::DeltaChoice do
-  subject(:choice) { described_class.new(data:) }
-
-  let(:data) { { 'index' => 0, 'delta' => { 'role' => 'user', 'content' => 'Hello!' } } }
-
-  it { expect(choice.index).to eq(0) }
-  it { expect(choice.delta).not_to be_nil }
-  it { expect(choice.delta.role).to eq('user') }
-  it { expect(choice.delta.content).to eq('Hello!') }
-
-  describe '#inspect' do
-    let(:delta) { OmniAI::Chat::Response::Delta.new(data: data['delta']) }
-
-    it { expect(choice.inspect).to eq(%(#)) }
-  end
-
-  describe '#part' do
-    it { expect(choice.part).to be_a(OmniAI::Chat::Response::Delta) }
-  end
-end
diff --git a/spec/omniai/chat/response/delta_spec.rb b/spec/omniai/chat/response/delta_spec.rb
deleted file mode 100644
index 9d71080..0000000
--- a/spec/omniai/chat/response/delta_spec.rb
+++ /dev/null
@@ -1,19 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::Delta do
-  subject(:delta) { described_class.new(data:) }
-
-  let(:data) { { 'role' => 'user', 'content' => 'Hello!' } }
-
-  describe '#role' do
-    it { expect(delta.role).to eq('user') }
-  end
-
-  describe '#content' do
-    it { expect(delta.content).to eq('Hello!') }
-  end
-
-  describe '#inspect' do
-    it { expect(delta.inspect).to eq('#') }
-  end
-end
diff --git a/spec/omniai/chat/response/function_spec.rb b/spec/omniai/chat/response/function_spec.rb
deleted file mode 100644
index 235c37d..0000000
--- a/spec/omniai/chat/response/function_spec.rb
+++ /dev/null
@@ -1,16 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::Function do
-  subject(:function) { described_class.new(data:) }
-
-  let(:data) { { 'name' => 'temperature', 'arguments' => '{ "unit": "celsius" }' } }
-
-  it { expect(function.name).to eq('temperature') }
-  it { expect(function.arguments).to eq({ 'unit' => 'celsius' }) }
-
-  describe '#inspect' do
-    subject(:inspect) { function.inspect }
-
-    it { is_expected.to eq '#"celsius"}>' }
-  end
-end
diff --git a/spec/omniai/chat/response/message_choice_spec.rb b/spec/omniai/chat/response/message_choice_spec.rb
deleted file mode 100644
index 41413d0..0000000
--- a/spec/omniai/chat/response/message_choice_spec.rb
+++ /dev/null
@@ -1,22 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::MessageChoice do
-  subject(:choice) { described_class.new(data:) }
-
-  let(:data) { { 'index' => 0, 'message' => { 'role' => 'user', 'content' => 'Hello!' } } }
-
-  it { expect(choice.index).to eq(0) }
-  it { expect(choice.message).not_to be_nil }
-  it { expect(choice.message.role).to eq('user') }
-  it { expect(choice.message.content).to eq('Hello!') }
-
-  describe '#inspect' do
-    let(:message) { OmniAI::Chat::Response::Message.new(data: data['message']) }
-
-    it { expect(choice.inspect).to eq(%(#)) }
-  end
-
-  describe '#part' do
-    it { expect(choice.part).to be_a(OmniAI::Chat::Response::Message) }
-  end
-end
diff --git a/spec/omniai/chat/response/message_spec.rb b/spec/omniai/chat/response/message_spec.rb
deleted file mode 100644
index 1e4d591..0000000
--- a/spec/omniai/chat/response/message_spec.rb
+++ /dev/null
@@ -1,19 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::Message do
-  subject(:message) { described_class.new(data:) }
-
-  let(:data) { { 'role' => 'user', 'content' => 'Hello!' } }
-
-  describe '#role' do
-    it { expect(message.role).to eq('user') }
-  end
-
-  describe '#content' do
-    it { expect(message.content).to eq('Hello!') }
-  end
-
-  describe '#inspect' do
-    it { expect(message.inspect).to eq('#') }
-  end
-end
diff --git a/spec/omniai/chat/response/part_spec.rb b/spec/omniai/chat/response/part_spec.rb
deleted file mode 100644
index a5fe10e..0000000
--- a/spec/omniai/chat/response/part_spec.rb
+++ /dev/null
@@ -1,38 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::Part do
-  subject(:part) { described_class.new(data:) }
-
-  let(:data) do
-    {
-      'role' => 'system',
-      'content' => 'Hello!',
-      'tool_calls' => [
-        {
-          'id' => 'fake_tool_call_id',
-          'type' => 'function',
-          'function' => {
-            'name' => 'temperature',
-            'arguments' => '{"unit":"celsius"}',
-          },
-        },
-      ],
-    }
-  end
-
-  describe '#role' do
-    it { expect(part.role).to eq('system') }
-  end
-
-  describe '#content' do
-    it { expect(part.content).to eq('Hello!') }
-  end
-
-  describe '#tool_call_list' do
-    it { expect(part.tool_call_list).not_to be_empty }
-  end
-
-  describe '#tool_call' do
-    it { expect(part.tool_call).to be_a(OmniAI::Chat::Response::ToolCall) }
-  end
-end
diff --git a/spec/omniai/chat/response/payload_spec.rb b/spec/omniai/chat/response/payload_spec.rb
deleted file mode 100644
index 918dcc9..0000000
--- a/spec/omniai/chat/response/payload_spec.rb
+++ /dev/null
@@ -1,62 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::Payload do
-  subject(:completion) { described_class.new(data:) }
-
-  let(:data) do
-    {
-      'id' => 'fake_id',
-      'model' => 'fake_model',
-      'created' => 0,
-      'updated' => 0,
-      'choices' => [],
-      'usage' => {
-        'input_tokens' => 0,
-        'output_tokens' => 0,
-        'total_tokens' => 0,
-      },
-    }
-  end
-
-  describe '#id' do
-    it { expect(completion.id).to eq('fake_id') }
-  end
-
-  describe '#model' do
-    it { expect(completion.model).to eq('fake_model') }
-  end
-
-  describe '#created' do
-    it { expect(completion.created).to be_a(Time) }
-
-    context 'without created' do
-      let(:data) { {} }
-
-      it { expect(completion.created).to be_nil }
-    end
-  end
-
-  describe '#updated' do
-    it { expect(completion.updated).to be_a(Time) }
-
-    context 'without updated' do
-      let(:data) { {} }
-
-      it { expect(completion.updated).to be_nil }
-    end
-  end
-
-  describe '#usage' do
-    it { expect(completion.usage).to be_a(OmniAI::Chat::Response::Usage) }
-
-    context 'without usage' do
-      let(:data) { {} }
-
-      it { expect(completion.usage).to be_nil }
-    end
-  end
-
-  describe '#choices' do
-    it { expect { completion.choices }.to raise_error(NotImplementedError) }
-  end
-end
diff --git a/spec/omniai/chat/response/resource_spec.rb b/spec/omniai/chat/response/resource_spec.rb
deleted file mode 100644
index 7b69d52..0000000
--- a/spec/omniai/chat/response/resource_spec.rb
+++ /dev/null
@@ -1,17 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::Resource do
-  subject(:resource) { described_class.new(data: { name: 'Ringo' }) }
-
-  describe '#data' do
-    it 'returns data' do
-      expect(resource.data).to eq({ name: 'Ringo' })
-    end
-  end
-
-  describe '#inspect' do
-    it 'returns inspect string' do
-      expect(resource.inspect).to eq('#"Ringo"}>')
-    end
-  end
-end
diff --git a/spec/omniai/chat/response/tool_call_spec.rb b/spec/omniai/chat/response/tool_call_spec.rb
deleted file mode 100644
index 77aad0c..0000000
--- a/spec/omniai/chat/response/tool_call_spec.rb
+++ /dev/null
@@ -1,23 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::ToolCall do
-  subject(:tool_call) { described_class.new(data:) }
-
-  let(:data) do
-    {
-      'id' => 'fake_tool_call_id',
-      'type' => 'function',
-      'function' => { 'name' => 'temperature', 'arguments' => '{"unit":"celsius"}' },
-    }
-  end
-
-  it { expect(tool_call.id).to eq('fake_tool_call_id') }
-  it { expect(tool_call.type).to eq('function') }
-  it { expect(tool_call.function).to be_a(OmniAI::Chat::Response::Function) }
-
-  describe '#inspect' do
-    subject(:inspect) { tool_call.inspect }
-
-    it { expect(inspect).to eq('#') }
-  end
-end
diff --git a/spec/omniai/chat/response/usage_spec.rb b/spec/omniai/chat/response/usage_spec.rb
deleted file mode 100644
index 7b7d1b6..0000000
--- a/spec/omniai/chat/response/usage_spec.rb
+++ /dev/null
@@ -1,38 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe OmniAI::Chat::Response::Usage do
-  subject(:usage) { described_class.new(data:) }
-
-  let(:data) { { 'input_tokens' => 2, 'output_tokens' => 3, 'total_tokens' => 5 } }
-
-  context 'with input_tokens / output_tokens' do
-    let(:data) do
-      {
-        'input_tokens' => 2,
-        'output_tokens' => 3,
-      }
-    end
-
-    it { expect(usage.input_tokens).to eq(2) }
-    it { expect(usage.output_tokens).to eq(3) }
-    it { expect(usage.total_tokens).to eq(5) }
-  end
-
-  context 'with prompt_tokens / completion_tokens / total_tokens' do
-    let(:data) do
-      {
-        'prompt_tokens' => 2,
-        'completion_tokens' => 3,
-        'total_tokens' => 5,
-      }
-    end
-
-    it { expect(usage.input_tokens).to eq(2) }
-    it { expect(usage.output_tokens).to eq(3) }
-    it { expect(usage.total_tokens).to eq(5) }
-  end
-
-  describe '#inspect' do
-    it { expect(usage.inspect).to eq('#') }
-  end
-end
diff --git a/spec/omniai/chat/text_spec.rb b/spec/omniai/chat/text_spec.rb
index 9a22a12..f028af5 100644
--- a/spec/omniai/chat/text_spec.rb
+++ b/spec/omniai/chat/text_spec.rb
@@ -1,7 +1,7 @@
 # frozen_string_literal: true
 
 RSpec.describe OmniAI::Chat::Text do
-  subject(:text) { described_class.new('Hello!') }
+  subject(:text) { build(:chat_text, text: 'Hello!') }
 
   describe '#text' do
     it { expect(text.text).to eq('Hello!') }
diff --git a/spec/omniai/chat/tool_call_spec.rb b/spec/omniai/chat/tool_call_spec.rb
new file mode 100644
index 0000000..44e1993
--- /dev/null
+++ b/spec/omniai/chat/tool_call_spec.rb
@@ -0,0 +1,98 @@
+# frozen_string_literal: true
+
+RSpec.describe OmniAI::Chat::ToolCall do
+  subject(:tool_call) { build(:chat_tool_call, id:, function:) }
+
+  let(:id) { 'fake_tool_call_id' }
+  let(:function) { build(:chat_function, name: 'temperature', arguments: { 'unit' => 'celsius' }) }
+
+  describe '#id' do
+    it { expect(tool_call.id).to eq(id) }
+  end
+
+  describe '#function' do
+    it { expect(tool_call.function).to eq(function) }
+  end
+
+  describe '#inspect' do
+    subject(:inspect) { tool_call.inspect }
+
+    it { is_expected.to eq(%(#)) }
+  end
+
+  describe '.deserialize' do
+    subject(:deserialize) { described_class.deserialize(data, context:) }
+
+    let(:data) do
+      {
+        'id' => 'fake_tool_call_id',
+        'function' => {
+          'name' => 'temperature',
+          'arguments' => '{"unit":"celsius"}',
+        },
+      }
+    end
+
+    context 'with a deserializer' do
+      let(:context) do
+        OmniAI::Context.build do |context|
+          context.deserializers[:tool_call] = lambda { |data, *|
+            id = data['id']
+            function = OmniAI::Chat::Function.deserialize(data['function'])
+            described_class.new(id:, function:)
+          }
+        end
+      end
+
+      it { expect(deserialize).to be_a(described_class) }
+      it { expect(deserialize.id).to eq('fake_tool_call_id') }
+      it { expect(deserialize.function).to be_a(OmniAI::Chat::Function) }
+    end
+
+    context 'without a serializer' do
+      let(:context) { OmniAI::Context.build }
+
+      it { expect(deserialize).to be_a(described_class) }
+      it { expect(deserialize.id).to eq('fake_tool_call_id') }
+      it { expect(deserialize.function).to be_a(OmniAI::Chat::Function) }
+    end
+  end
+
+  describe '#serialize' do
+    subject(:serialize) { tool_call.serialize(context:) }
+
+    context 'with a serializer' do
+      let(:context) do
+        OmniAI::Context.build do |context|
+          context.serializers[:tool_call] = lambda do |tool_call, *|
+            {
+              id: tool_call.id,
+              type: 'function',
+              function: tool_call.function.serialize(context:),
+            }
+          end
+        end
+      end
+
+      it do
+        expect(serialize).to eq(
+          id: 'fake_tool_call_id',
+          type: 'function',
+          function: { name: 'temperature', arguments: '{"unit":"celsius"}' }
+        )
+      end
+    end
+
+    context 'without a serializer' do
+      let(:context) { OmniAI::Context.build }
+
+      it do
+        expect(serialize).to eq(
+          id: 'fake_tool_call_id',
+          type: 'function',
+          function: { name: 'temperature', arguments: '{"unit":"celsius"}' }
+        )
+      end
+    end
+  end
+end
diff --git a/spec/omniai/chat/url_spec.rb b/spec/omniai/chat/url_spec.rb
index 37f808e..e605b31 100644
--- a/spec/omniai/chat/url_spec.rb
+++ b/spec/omniai/chat/url_spec.rb
@@ -1,7 +1,7 @@
 # frozen_string_literal: true
 
 RSpec.describe OmniAI::Chat::URL do
-  subject(:url) { described_class.new(uri, type) }
+  subject(:url) { build(:chat_url, uri:, type:) }
 
   let(:type) { 'image/png' }
   let(:uri) { 'https://localhost/hamster.png' }
diff --git a/spec/omniai/chat/usage_spec.rb b/spec/omniai/chat/usage_spec.rb
new file mode 100644
index 0000000..36b4dc7
--- /dev/null
+++ b/spec/omniai/chat/usage_spec.rb
@@ -0,0 +1,80 @@
+# frozen_string_literal: true
+
+RSpec.describe OmniAI::Chat::Usage do
+  subject(:usage) { build(:chat_usage, input_tokens: 2, output_tokens: 3, total_tokens: 5) }
+
+  describe '#input_tokens' do
+    it { expect(usage.input_tokens).to eq(2) }
+  end
+
+  describe '#output_tokens' do
+    it { expect(usage.output_tokens).to eq(3) }
+  end
+
+  describe '#total_tokens' do
+    it { expect(usage.total_tokens).to eq(5) }
+  end
+
+  describe '#inspect' do
+    it { expect(usage.inspect).to eq('#') }
+  end
+
+  describe '.deserialize' do
+    subject(:deserialize) { described_class.deserialize(data, context:) }
+
+    let(:data) { { 'input_tokens' => 2, 'output_tokens' => 3, 'total_tokens' => 5 } }
+
+    context 'with a deserializer' do
+      let(:context) do
+        OmniAI::Context.build do |context|
+          context.deserializers[:usage] = lambda { |data, *|
+            input_tokens = data['input_tokens']
+            output_tokens = data['output_tokens']
+            total_tokens = data['total_tokens']
+            described_class.new(input_tokens:, output_tokens:, total_tokens:)
+          }
+        end
+      end
+
+      it { expect(deserialize).to be_a(described_class) }
+      it { expect(deserialize.input_tokens).to eq(2) }
+      it { expect(deserialize.output_tokens).to eq(3) }
+      it { expect(deserialize.total_tokens).to eq(5) }
+    end
+
+    context 'without a deserializer' do
+      let(:context) { OmniAI::Context.build }
+
+      it { expect(deserialize).to be_a(described_class) }
+      it { expect(deserialize.input_tokens).to eq(2) }
+      it { expect(deserialize.output_tokens).to eq(3) }
+      it { expect(deserialize.total_tokens).to eq(5) }
+    end
+  end
+
+  describe '#serialize' do
+    subject(:serialize) { usage.serialize(context:) }
+
+    context 'with a serializer' do
+      let(:context) do
+        OmniAI::Context.build do |context|
+          context.serializers[:usage] = lambda do |usage, *|
+            {
+              input_tokens: usage.input_tokens,
+              output_tokens: usage.output_tokens,
+              total_tokens: usage.total_tokens,
+            }
+          end
+        end
+      end
+
+      it { is_expected.to eq(input_tokens: 2, output_tokens: 3, total_tokens: 5) }
+    end
+
+    context 'without a serializer' do
+      let(:context) { OmniAI::Context.build }
+
+      it { is_expected.to eq(input_tokens: 2, output_tokens: 3, total_tokens: 5) }
+    end
+  end
+end
diff --git a/spec/omniai/chat_spec.rb b/spec/omniai/chat_spec.rb
index 3f7dc09..7028611 100644
--- a/spec/omniai/chat_spec.rb
+++ b/spec/omniai/chat_spec.rb
@@ -86,13 +86,14 @@ def payload
           index: 0,
           message: {
             role: 'system',
-            content: '{ "name": "Ringo" }',
+            content: 'Ringo!',
           },
         }],
       })
     end
 
-    it { expect(process!).to be_a(OmniAI::Chat::Response::Completion) }
+    it { expect(process!).to be_a(OmniAI::Chat::Response) }
+    it { expect(process!.content).to eql('Ringo!') }
   end
 
   context 'when UNPROCESSABLE' do
@@ -134,7 +135,7 @@ def payload
         chunks = []
         allow(stream).to receive(:call) { |chunk| chunks << chunk }
        process!
-        expect(chunks.map { |chunk| chunk.choice.delta.content }).to eql(%w[A B])
+        expect(chunks.map(&:content)).to eql(%w[A B])
      end
    end
 
diff --git a/spec/omniai/embed/usage_spec.rb b/spec/omniai/embed/usage_spec.rb
index 39f22b8..c195046 100644
--- a/spec/omniai/embed/usage_spec.rb
+++ b/spec/omniai/embed/usage_spec.rb
@@ -1,7 +1,7 @@
 # frozen_string_literal: true
 
 RSpec.describe OmniAI::Embed::Usage do
-  subject(:usage) { described_class.new(prompt_tokens:, total_tokens:) }
+  subject(:usage) { build(:embed_usage, prompt_tokens:, total_tokens:) }
 
 let(:prompt_tokens) { 2 }
 let(:total_tokens) { 4 }
diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb
index 4b58a67..8f9f0ce 100644
--- a/spec/spec_helper.rb
+++ b/spec/spec_helper.rb
@@ -10,6 +10,8 @@
 
 require 'omniai'
 
+Dir["#{__dir__}/support/**/*.rb"].each { |file| require file }
+
 RSpec.configure do |config|
   config.expect_with :rspec do |c|
     c.syntax = :expect
diff --git a/spec/support/factory_bot.rb b/spec/support/factory_bot.rb
new file mode 100644
index 0000000..c057de9
--- /dev/null
+++ b/spec/support/factory_bot.rb
@@ -0,0 +1,9 @@
+# frozen_string_literal: true
+
+require 'factory_bot'
+
+FactoryBot.find_definitions
+
+RSpec.configure do |config|
+  config.include FactoryBot::Syntax::Methods
+end