Skip to content

Commit

Permalink
Properly implement serialize / deserialize
Browse files Browse the repository at this point in the history
  • Loading branch information
ksylvest committed Aug 9, 2024
1 parent a9e3e09 commit 0559fc9
Show file tree
Hide file tree
Showing 72 changed files with 1,077 additions and 831 deletions.
1 change: 1 addition & 0 deletions Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ source 'https://rubygems.org'

gemspec

gem 'factory_bot'
gem 'logger'
gem 'rake'
gem 'rspec'
Expand Down
38 changes: 30 additions & 8 deletions Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
PATH
remote: .
specs:
omniai (1.7.0)
omniai (1.8.0)
event_stream_parser
http
zeitwerk

GEM
remote: https://rubygems.org/
specs:
activesupport (7.1.3.4)
base64
bigdecimal
concurrent-ruby (~> 1.0, >= 1.0.2)
connection_pool (>= 2.2.5)
drb
i18n (>= 1.6, < 2)
minitest (>= 5.1)
mutex_m
tzinfo (~> 2.0)
addressable (2.8.7)
public_suffix (>= 2.0.2, < 7.0)
ast (2.4.2)
base64 (0.2.0)
bigdecimal (3.1.8)
concurrent-ruby (1.3.3)
connection_pool (2.4.1)
crack (1.0.0)
bigdecimal
rexml
diff-lcs (1.5.1)
docile (1.4.0)
docile (1.4.1)
domain_name (0.6.20240107)
drb (2.2.1)
event_stream_parser (1.0.0)
factory_bot (6.4.6)
activesupport (>= 5.0.0)
ffi (1.17.0)
ffi (1.17.0-aarch64-linux-gnu)
ffi (1.17.0-aarch64-linux-musl)
Expand All @@ -35,7 +50,7 @@ GEM
ffi-compiler (1.3.2)
ffi (>= 1.15.5)
rake
hashdiff (1.1.0)
hashdiff (1.1.1)
http (5.2.0)
addressable (~> 2.8)
base64 (~> 0.1)
Expand All @@ -45,17 +60,21 @@ GEM
http-cookie (1.0.6)
domain_name (~> 0.5)
http-form_data (2.3.0)
i18n (1.14.5)
concurrent-ruby (~> 1.0)
json (2.7.2)
language_server-protocol (3.17.0.3)
llhttp-ffi (0.5.0)
ffi-compiler (~> 1.0)
rake (~> 13.0)
logger (1.6.0)
parallel (1.25.1)
parser (3.3.4.0)
minitest (5.24.1)
mutex_m (0.2.0)
parallel (1.26.1)
parser (3.3.4.2)
ast (~> 2.4.1)
racc
public_suffix (6.0.0)
public_suffix (6.0.1)
racc (1.8.1)
rainbow (3.1.1)
rake (13.2.1)
Expand Down Expand Up @@ -88,11 +107,11 @@ GEM
rubocop-ast (>= 1.31.1, < 2.0)
ruby-progressbar (~> 1.7)
unicode-display_width (>= 2.4.0, < 3.0)
rubocop-ast (1.31.3)
rubocop-ast (1.32.0)
parser (>= 3.3.1.0)
rubocop-rake (0.6.0)
rubocop (~> 1.0)
rubocop-rspec (3.0.3)
rubocop-rspec (3.0.4)
rubocop (~> 1.61)
ruby-progressbar (1.13.0)
simplecov (0.22.0)
Expand All @@ -102,6 +121,8 @@ GEM
simplecov-html (0.12.3)
simplecov_json_formatter (0.1.4)
strscan (3.1.0)
tzinfo (2.0.6)
concurrent-ruby (~> 1.0)
unicode-display_width (2.5.0)
webmock (3.23.1)
addressable (>= 2.8.0)
Expand All @@ -124,6 +145,7 @@ PLATFORMS
x86_64-linux-musl

DEPENDENCIES
factory_bot
logger
omniai!
rake
Expand Down
17 changes: 5 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@ Clients that support chat (e.g. Anthropic w/ "Claude", Google w/ "Gemini", Mistr
Generating a completion is as simple as sending in the text:

```ruby
completion = client.chat('Tell me a joke.')
completion.choice.message.content # 'Why don't scientists trust atoms? They make up everything!'
response = client.chat('Tell me a joke.')
response.content # 'Why don't scientists trust atoms? They make up everything!'
```

#### Completions using a Complex Prompt

More complex completions are generated using a block w/ various system / user messages:

```ruby
completion = client.chat do |prompt|
response = client.chat do |prompt|
prompt.system 'You are a helpful assistant with an expertise in animals.'
prompt.user do |message|
message.text 'What animals are in the attached photos?'
Expand All @@ -145,7 +145,7 @@ completion = client.chat do |prompt|
message.file('./hamster.jpeg', "image/jpeg")
end
end
completion.choice.message.content # 'They are photos of a cat, a cat, and a hamster.'
response.content # 'They are photos of a cat, a cat, and a hamster.'
```

#### Completions using Streaming via Proc
Expand All @@ -154,7 +154,7 @@ A real-time stream of messages can be generated by passing in a proc:

```ruby
stream = proc do |chunk|
print(chunk.choice.delta.content) # '...'
print(chunk.content) # '...'
end
client.chat('Tell me a joke.', stream:)
```
Expand Down Expand Up @@ -315,10 +315,3 @@ Type 'exit' or 'quit' to abort.
0.0
...
```

0.0
...

```
```
44 changes: 21 additions & 23 deletions lib/omniai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def process!

protected

# Override to provide an context for serializers / deserializes for a provider.
#
# @return [Context, nil]
def context
nil
end

# Used to spawn another chat with the same configuration using different messages.
#
# @param prompt [OmniAI::Chat::Prompt]
Expand Down Expand Up @@ -114,7 +121,7 @@ def path
end

# @param response [HTTP::Response]
# @return [OmniAI::Chat::Response::Completion]
# @return [OmniAI::Chat::Response]
def parse!(response:)
if @stream
stream!(response:)
Expand All @@ -124,27 +131,28 @@ def parse!(response:)
end

# @param response [HTTP::Response]
# @return [OmniAI::Chat::Response::Completion]
# @return [OmniAI::Chat::Response]
def complete!(response:)
completion = self.class::Response::Completion.new(data: response.parse)
completion = Response.new(data: response.parse, context:)

if @tools && completion.tool_call_list.any?
spawn!([
*@prompt.serialize,
*completion.choices.map(&:message).map(&:data),
*(completion.tool_call_list.map { |tool_call| execute_tool_call(tool_call) }),
])
spawn!(
@prompt.dup.tap do |prompt|
prompt.messages += completion.messages
prompt.messages += completion.tool_call_list.map { |tool_call| execute_tool_call(tool_call) }
end
)
else
completion
end
end

# @param response [HTTP::Response]
# @return [OmniAI::Chat::Response::Stream]
# @return [OmniAI::Chat::Stream]
def stream!(response:)
raise Error, "#{self.class.name}#stream! unstreamable" unless @stream

self.class::Response::Stream.new(response:).stream! do |chunk|
Stream.new(body: response.body, context:).stream! do |chunk|
case @stream
when IO, StringIO
if chunk.content?
Expand All @@ -167,24 +175,14 @@ def request!
end

# @param tool_call [OmniAI::Chat::ToolCall]
# @return [ToolMessage]
def execute_tool_call(tool_call)
function = tool_call.function

tool = @tools.find { |entry| function.name == entry.name } || raise(ToolCallLookupError, tool_call)
result = tool.call(function.arguments)
content = tool.call(function.arguments)

prepare_tool_call_message(tool_call:, content: result)
end

# @param tool_call [OmniAI::Chat::ToolCall]
# @param content [String]
def prepare_tool_call_message(tool_call:, content:)
{
role: Role::TOOL,
name: tool_call.function.name,
tool_call_id: tool_call.id,
content:,
}
ToolMessage.new(tool_call:, content:)
end
end
end
55 changes: 55 additions & 0 deletions lib/omniai/chat/choice.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# frozen_string_literal: true

module OmniAI
class Chat
# For for.
class Choice
# @return [Integer]
attr_accessor :index

# @return [Message]
attr_accessor :message

# @param message [Message]
def initialize(message:, index: 0)
@message = message
@index = index
end

# @return [String]
def inspect
"#<#{self.class.name} index=#{@index} message=#{@message.inspect}>"
end

# @param data [Hash]
# @param context [OmniAI::Context] optional
#
# @return [Choice]
def self.deserialize(data, context: nil)
deserialize = context&.deserializers&.[](:choice)
return deserialize.call(data, context:) if deserialize

index = data['index']
message = Message.deserialize(data['message'] || data['delta'], context:)
new(message:, index:)
end

# @param context [OmniAI::Context] optional
# @return [Hash]
def serialize(context: nil)
serialize = context&.serializers&.[](:choice)
return serialize.call(self, context:) if serialize

{
index:,
message: message.serialize(context:),
}
end

# @return [Array<Content>, String]
def content
message.content
end
end
end
end
12 changes: 8 additions & 4 deletions lib/omniai/chat/content.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,24 @@ def self.summarize(content)
#
# @return [String]
def serialize(context: nil)
raise NotImplementedError, ' # {self.class}#serialize undefined'
raise NotImplementedError, "#{self.class}#serialize undefined"
end

# @param data [hash]
# @param data [Hash, Array, String]
# @param context [Context] optional
#
# @return [Content]
def self.deserialize(data, context: nil)
raise ArgumentError, "untyped data=#{data.inspect}" unless data.key?('type')
return data.map { |data| deserialize(data, context:) } if data.is_a?(Array)

deserialize = context&.deserializers&.[](:content)
return deserialize.call(data, context:) if deserialize

return data if data.is_a?(String)

case data['type']
when 'text' then Text.deserialize(data, context:)
when /(.*)_url/ then URL.deserialize(data, context:)
else raise ArgumentError, "unknown type=#{data['type'].inspect}"
end
end
end
Expand Down
53 changes: 53 additions & 0 deletions lib/omniai/chat/function.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# frozen_string_literal: true

module OmniAI
class Chat
# A function that includes a name / arguments.
class Function
# @return [String]
attr_accessor :name

# @return [Hash]
attr_accessor :arguments

# @param name [String]
# @param arguments [Hash]
def initialize(name:, arguments:)
@name = name
@arguments = arguments
end

# @return [String]
def inspect
"#<#{self.class.name} name=#{name.inspect} arguments=#{arguments.inspect}>"
end

# @param data [Hash]
# @param context [Context] optional
#
# @return [Function]
def self.deserialize(data, context: nil)
deserialize = context&.deserializers&.[](:function)
return deserialize.call(data, context:) if deserialize

name = data['name']
arguments = JSON.parse(data['arguments']) if data['arguments']

new(name:, arguments:)
end

# @param context [Context] optional
#
# @return [Hash]
def serialize(context: nil)
serializer = context&.serializers&.[](:function)
return serializer.call(self, context:) if serializer

{
name: @name,
arguments: (JSON.generate(@arguments) if @arguments),
}
end
end
end
end
Loading

0 comments on commit 0559fc9

Please sign in to comment.