From 971e03bdf2e2717ea7d429576702d2cbdf2ca9a2 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Thu, 4 Jan 2024 18:15:34 -0300 Subject: [PATCH] FEATURE: AI Bot Gemini support. (#402) It also corrects the syntax around tool support, which was wrong. Gemini doesn't want us to include messages about previous tool invocations, so I had to shuffle around some code to send the response it generated from those invocations instead. For this, I created the "multi_turn" context, which bundles all the context involved in the interaction. --- .../initializers/ai-bot-replies.js | 2 +- config/locales/client.en.yml | 1 + config/settings.yml | 1 + lib/ai_bot/bot.rb | 5 +- lib/ai_bot/entry_point.rb | 4 + lib/ai_bot/playground.rb | 16 ++-- lib/completions/dialects/chat_gpt.rb | 3 +- lib/completions/dialects/claude.rb | 3 +- lib/completions/dialects/dialect.rb | 21 +++++ lib/completions/dialects/gemini.rb | 88 ++++++++++++++----- lib/completions/dialects/llama2_classic.rb | 4 +- lib/completions/dialects/mixtral.rb | 3 +- lib/completions/dialects/orca_style.rb | 3 +- lib/completions/endpoints/base.rb | 25 +++--- lib/completions/endpoints/gemini.rb | 31 +++---- lib/completions/llm.rb | 2 + spec/lib/completions/dialects/gemini_spec.rb | 41 +++++---- spec/lib/completions/endpoints/gemini_spec.rb | 25 +++--- spec/lib/modules/ai_bot/bot_spec.rb | 1 + spec/lib/modules/ai_bot/playground_spec.rb | 42 ++------- 20 files changed, 191 insertions(+), 130 deletions(-) diff --git a/assets/javascripts/initializers/ai-bot-replies.js b/assets/javascripts/initializers/ai-bot-replies.js index 5ea5e359..ff6551d1 100644 --- a/assets/javascripts/initializers/ai-bot-replies.js +++ b/assets/javascripts/initializers/ai-bot-replies.js @@ -12,7 +12,7 @@ import copyConversation from "../discourse/lib/copy-conversation"; const AUTO_COPY_THRESHOLD = 4; function isGPTBot(user) { - return user && [-110, -111, -112, -113, -114].includes(user.id); + return user && [-110, -111, -112, -113, -114, -115].includes(user.id); } function attachHeaderIcon(api) { diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index f46df038..073902b4 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -197,6 +197,7 @@ en: gpt-3: 5-turbo: "GPT-3.5" claude-2: "Claude 2" + gemini-pro: "Gemini" mixtral-8x7B-Instruct-V0: "1": "Mixtral-8x7B V0.1" sentiments: diff --git a/config/settings.yml b/config/settings.yml index 2c7a2367..8096e4cd 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -274,6 +274,7 @@ discourse_ai: - gpt-4 - gpt-4-turbo - claude-2 + - gemini-pro - mixtral-8x7B-Instruct-V0.1 ai_bot_add_to_header: default: true diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index ecdede94..6ae77e4b 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -125,6 +125,8 @@ module DiscourseAi "gpt-3.5-turbo-16k" when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID "mistralai/Mixtral-8x7B-Instruct-v0.1" + when DiscourseAi::AiBot::EntryPoint::GEMINI_ID + "gemini-pro" else nil end @@ -146,9 +148,10 @@ module DiscourseAi #{summary}

#{details}

+ HTML - placeholder << custom_raw << "\n" if custom_raw + placeholder << custom_raw if custom_raw placeholder end diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb index eaa750b3..50637240 100644 --- a/lib/ai_bot/entry_point.rb +++ b/lib/ai_bot/entry_point.rb @@ -10,12 +10,14 @@ module DiscourseAi CLAUDE_V2_ID = -112 GPT4_TURBO_ID = -113 MIXTRAL_ID = -114 + GEMINI_ID = -115 BOTS = [ [GPT4_ID, "gpt4_bot", "gpt-4"], [GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"], [CLAUDE_V2_ID, "claude_bot", "claude-2"], [GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"], [MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"], + [GEMINI_ID, "gemini_bot", "gemini-pro"], ] def self.map_bot_model_to_user_id(model_name) @@ -30,6 +32,8 @@ module DiscourseAi CLAUDE_V2_ID in "mixtral-8x7B-Instruct-V0.1" MIXTRAL_ID + in "gemini-pro" + GEMINI_ID else nil end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 7869cfda..61b88d07 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -37,7 +37,6 @@ module DiscourseAi result = [] - first = true context.each do |raw, username, custom_prompt| custom_prompt_translation = Proc.new do |message| @@ -51,25 +50,22 @@ module DiscourseAi custom_context[:name] = message[1] if custom_context[:type] != "assistant" - result << custom_context + custom_context end end if custom_prompt.present? - if first - custom_prompt.reverse_each(&custom_prompt_translation) - first = false - else - tool_call_and_tool = custom_prompt.first(2) - tool_call_and_tool.reverse_each(&custom_prompt_translation) - end + result << { + type: "multi_turn", + content: custom_prompt.reverse_each.map(&custom_prompt_translation).compact, + } else context = { content: raw, type: (available_bot_usernames.include?(username) ? "assistant" : "user"), } - context[:name] = username if context[:type] == "user" + context[:name] = clean_username(username) if context[:type] == "user" result << context end diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 0335f53f..78eb8b98 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -65,7 +65,8 @@ module DiscourseAi def conversation_context return [] if prompt[:conversation_context].blank? - trimmed_context = trim_context(prompt[:conversation_context]) + flattened_context = flatten_context(prompt[:conversation_context]) + trimmed_context = trim_context(flattened_context) trimmed_context.reverse.map do |context| if context[:type] == "tool_call" diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index d7ecf0b0..e760ab70 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -40,7 +40,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index b0e6c6a3..37768348 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -164,6 +164,27 @@ module DiscourseAi #{tools} TEXT end + + def flatten_context(context) + found_first_multi_turn = false + + context + .map do |a_context| + if a_context[:type] == "multi_turn" + if found_first_multi_turn + # Only take tool and tool_call_id from subsequent multi-turn interactions. + # Drop assistant responses + a_context[:content].last(2) + else + found_first_multi_turn = true + a_context[:content] + end + else + a_context + end + end + .flatten + end end end end diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 72f3a8ff..06d1f7df 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -15,6 +15,9 @@ module DiscourseAi end def translate + # Gemini complains if we don't alternate model/user roles. + noop_model_response = { role: "model", parts: { text: "Ok." } } + gemini_prompt = [ { role: "user", @@ -22,7 +25,7 @@ module DiscourseAi text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"), }, }, - { role: "model", parts: { text: "Ok." } }, + noop_model_response, ] if prompt[:examples] @@ -34,7 +37,13 @@ module DiscourseAi gemini_prompt.concat(conversation_context) if prompt[:conversation_context] - gemini_prompt << { role: "user", parts: { text: prompt[:input] } } + if prompt[:input] + gemini_prompt << noop_model_response.dup if gemini_prompt.last[:role] == "user" + + gemini_prompt << { role: "user", parts: { text: prompt[:input] } } + end + + gemini_prompt end def tools @@ -42,16 +51,23 @@ module DiscourseAi translated_tools = prompt[:tools].map do |t| - required_fields = [] - tool = t.dup + tool = t.slice(:name, :description) - tool[:parameters] = t[:parameters].map do |p| - required_fields << p[:name] if p[:required] + if t[:parameters] + tool[:parameters] = t[:parameters].reduce( + { type: "object", required: [], properties: {} }, + ) do |memo, p| + name = p[:name] + memo[:required] << name if p[:required] - p.except(:required) + memo[:properties][name] = p.except(:name, :required, :item_type) + + memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type] + memo + end end - tool.merge(required: required_fields) + tool end [{ function_declarations: translated_tools }] @@ -60,23 +76,42 @@ module DiscourseAi def conversation_context return [] if prompt[:conversation_context].blank? - trimmed_context = trim_context(prompt[:conversation_context]) + flattened_context = flatten_context(prompt[:conversation_context]) + trimmed_context = trim_context(flattened_context) trimmed_context.reverse.map do |context| - translated = {} - translated[:role] = (context[:type] == "user" ? "user" : "model") + if context[:type] == "tool_call" + function = JSON.parse(context[:content], symbolize_names: true) - part = {} - - if context[:type] == "tool" - part["functionResponse"] = { name: context[:name], content: context[:content] } + { + role: "model", + parts: { + functionCall: { + name: function[:name], + args: function[:arguments], + }, + }, + } + elsif context[:type] == "tool" + { + role: "function", + parts: { + functionResponse: { + name: context[:name], + response: { + content: context[:content], + }, + }, + }, + } else - part[:text] = context[:content] + { + role: context[:type] == "assistant" ? "model" : "user", + parts: { + text: context[:content], + }, + } end - - translated[:parts] = [part] - - translated end end @@ -89,6 +124,19 @@ module DiscourseAi def calculate_message_token(context) self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) end + + private + + def flatten_context(context) + context.map do |a_context| + if a_context[:type] == "multi_turn" + # Drop old tool calls and only keep bot response. + a_context[:content].find { |c| c[:type] == "assistant" } + else + a_context + end + end + end end end end diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb index 0a53c44b..63284c59 100644 --- a/lib/completions/dialects/llama2_classic.rb +++ b/lib/completions/dialects/llama2_classic.rb @@ -40,8 +40,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb index 36a2fd43..4ac60d19 100644 --- a/lib/completions/dialects/mixtral.rb +++ b/lib/completions/dialects/mixtral.rb @@ -40,7 +40,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb index 74a356f3..fa742402 100644 --- a/lib/completions/dialects/orca_style.rb +++ b/lib/completions/dialects/orca_style.rb @@ -37,7 +37,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 72635e12..04d731b6 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -273,20 +273,19 @@ module DiscourseAi function_buffer.at("tool_id").inner_html = tool_name end - _read_parameters = - read_function - .at("parameters") - &.elements - .to_a - .each do |elem| - if paramenter = function_buffer.at(elem.name)&.text - function_buffer.at(elem.name).inner_html = paramenter - else - param_node = read_function.at(elem.name) - function_buffer.at("parameters").add_child(param_node) - function_buffer.at("parameters").add_child("\n") - end + read_function + .at("parameters") + &.elements + .to_a + .each do |elem| + if paramenter = function_buffer.at(elem.name)&.text + function_buffer.at(elem.name).inner_html = paramenter + else + param_node = read_function.at(elem.name) + function_buffer.at("parameters").add_child(param_node) + function_buffer.at("parameters").add_child("\n") end + end function_buffer end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 9a1e3711..f0b7a508 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -42,10 +42,12 @@ module DiscourseAi end def prepare_payload(prompt, model_params, dialect) + tools = dialect.tools + default_options .merge(contents: prompt) .tap do |payload| - payload[:tools] = dialect.tools if dialect.tools.present? + payload[:tools] = tools if tools.present? payload[:generationConfig].merge!(model_params) if model_params.present? end end @@ -57,8 +59,12 @@ module DiscourseAi end def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true) - + parsed = + if @streaming_mode + response_raw + else + JSON.parse(response_raw, symbolize_names: true) + end response_h = parsed.dig(:candidates, 0, :content, :parts, 0) @has_function_call ||= response_h.dig(:functionCall).present? @@ -66,20 +72,11 @@ module DiscourseAi end def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - if line == "," - nil - elsif line.starts_with?("[") - line[1..-1] - elsif line.ends_with?("]") - line[0..-1] - else - line - end - end - .compact_blank + begin + JSON.parse(decoded_chunk, symbolize_names: true) + rescue JSON::ParserError + [] + end end def extract_prompt_for_tokenizer(prompt) diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 621e9811..e8cb14f3 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -62,6 +62,8 @@ module DiscourseAi # { type: "user", name: "user1", content: "This is a new message by a user" }, # { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, # { type: "tool", name: "tool_id", content: "I'm a tool result" }, + # { type: "tool_call_id", name: "tool_id", content: { name: "tool", args: { ...tool_args } } }, + # { type: "multi_turn", content: [assistant_reply_from_a_tool, tool_call, tool_call_id] } # ] # # - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example: diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb index 0a5266b7..cc80611b 100644 --- a/spec/lib/completions/dialects/gemini_spec.rb +++ b/spec/lib/completions/dialects/gemini_spec.rb @@ -98,18 +98,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do expect(translated_context).to eq( [ { - role: "model", - parts: [ - { - "functionResponse" => { - name: context.last[:name], + role: "function", + parts: { + functionResponse: { + name: context.last[:name], + response: { content: context.last[:content], }, }, - ], + }, }, - { role: "model", parts: [{ text: context.second[:content] }] }, - { role: "user", parts: [{ text: context.first[:content] }] }, + { role: "model", parts: { text: context.second[:content] } }, + { role: "user", parts: { text: context.first[:content] } }, ], ) end @@ -121,7 +121,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do translated_context = dialect.conversation_context - expect(translated_context.last.dig(:parts, 0, :text).length).to be < + expect(translated_context.last.dig(:parts, :text).length).to be < context.last[:content].length end end @@ -133,16 +133,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do { name: "get_weather", description: "Get the weather in a city", - parameters: [ - { name: "location", type: "string", description: "the city name" }, - { - name: "unit", - type: "string", - description: "the unit of measurement celcius c or fahrenheit f", - enum: %w[c f], + parameters: { + type: "object", + required: %w[location unit], + properties: { + "location" => { + type: "string", + description: "the city name", + }, + "unit" => { + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + }, }, - ], - required: %w[location unit], + }, }, ], } diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index 195351e3..4e1d4b59 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -16,16 +16,21 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do { name: "get_weather", description: "Get the weather in a city", - parameters: [ - { name: "location", type: "string", description: "the city name" }, - { - name: "unit", - type: "string", - description: "the unit of measurement celcius c or fahrenheit f", - enum: %w[c f], + parameters: { + type: "object", + required: %w[location unit], + properties: { + "location" => { + type: "string", + description: "the city name", + }, + "unit" => { + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + }, }, - ], - required: %w[location unit], + }, } end @@ -126,7 +131,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do end end - chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("") + chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]").split("") WebMock .stub_request( diff --git a/spec/lib/modules/ai_bot/bot_spec.rb b/spec/lib/modules/ai_bot/bot_spec.rb index d40f3a47..a5ac0185 100644 --- a/spec/lib/modules/ai_bot/bot_spec.rb +++ b/spec/lib/modules/ai_bot/bot_spec.rb @@ -45,6 +45,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do #{tool.summary}

+ HTML context = {} diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 0c86cbd5..255d21af 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -125,44 +125,18 @@ RSpec.describe DiscourseAi::AiBot::Playground do expect(context).to contain_exactly( *[ { type: "user", name: user.username, content: third_post.raw }, - { type: "assistant", content: custom_prompt.third.first }, - { type: "tool_call", content: custom_prompt.second.first, name: "time" }, - { type: "tool", name: "time", content: custom_prompt.first.first }, + { + type: "multi_turn", + content: [ + { type: "assistant", content: custom_prompt.third.first }, + { type: "tool_call", content: custom_prompt.second.first, name: "time" }, + { type: "tool", name: "time", content: custom_prompt.first.first }, + ], + }, { type: "user", name: user.username, content: first_post.raw }, ], ) end end - - it "include replies generated from tools only once" do - custom_prompt = [ - [ - { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json, - "time", - "tool", - ], - [ - { name: "time", arguments: { name: "time", timezone: "Buenos Aires" }.to_json }.to_json, - "time", - "tool_call", - ], - ["I replied this thanks to the time command", bot_user.username], - ] - PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt) - PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt) - - context = playground.conversation_context(third_post) - - expect(context).to contain_exactly( - *[ - { type: "user", name: user.username, content: third_post.raw }, - { type: "assistant", content: custom_prompt.third.first }, - { type: "tool_call", content: custom_prompt.second.first, name: "time" }, - { type: "tool", name: "time", content: custom_prompt.first.first }, - { type: "tool_call", content: custom_prompt.second.first, name: "time" }, - { type: "tool", name: "time", content: custom_prompt.first.first }, - ], - ) - end end end