diff --git a/assets/javascripts/initializers/ai-bot-replies.js b/assets/javascripts/initializers/ai-bot-replies.js
index 5ea5e359..ff6551d1 100644
--- a/assets/javascripts/initializers/ai-bot-replies.js
+++ b/assets/javascripts/initializers/ai-bot-replies.js
@@ -12,7 +12,7 @@ import copyConversation from "../discourse/lib/copy-conversation";
const AUTO_COPY_THRESHOLD = 4;
function isGPTBot(user) {
- return user && [-110, -111, -112, -113, -114].includes(user.id);
+ return user && [-110, -111, -112, -113, -114, -115].includes(user.id);
}
function attachHeaderIcon(api) {
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index f46df038..073902b4 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -197,6 +197,7 @@ en:
gpt-3:
5-turbo: "GPT-3.5"
claude-2: "Claude 2"
+ gemini-pro: "Gemini"
mixtral-8x7B-Instruct-V0:
"1": "Mixtral-8x7B V0.1"
sentiments:
diff --git a/config/settings.yml b/config/settings.yml
index 2c7a2367..8096e4cd 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -274,6 +274,7 @@ discourse_ai:
- gpt-4
- gpt-4-turbo
- claude-2
+ - gemini-pro
- mixtral-8x7B-Instruct-V0.1
ai_bot_add_to_header:
default: true
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index ecdede94..6ae77e4b 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -125,6 +125,8 @@ module DiscourseAi
"gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
"mistralai/Mixtral-8x7B-Instruct-v0.1"
+ when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
+ "gemini-pro"
else
nil
end
@@ -146,9 +148,10 @@ module DiscourseAi
#{details}
+ HTML - placeholder << custom_raw << "\n" if custom_raw + placeholder << custom_raw if custom_raw placeholder end diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb index eaa750b3..50637240 100644 --- a/lib/ai_bot/entry_point.rb +++ b/lib/ai_bot/entry_point.rb @@ -10,12 +10,14 @@ module DiscourseAi CLAUDE_V2_ID = -112 GPT4_TURBO_ID = -113 MIXTRAL_ID = -114 + GEMINI_ID = -115 BOTS = [ [GPT4_ID, "gpt4_bot", "gpt-4"], [GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"], [CLAUDE_V2_ID, "claude_bot", "claude-2"], [GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"], [MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"], + [GEMINI_ID, "gemini_bot", "gemini-pro"], ] def self.map_bot_model_to_user_id(model_name) @@ -30,6 +32,8 @@ module DiscourseAi CLAUDE_V2_ID in "mixtral-8x7B-Instruct-V0.1" MIXTRAL_ID + in "gemini-pro" + GEMINI_ID else nil end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 7869cfda..61b88d07 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -37,7 +37,6 @@ module DiscourseAi result = [] - first = true context.each do |raw, username, custom_prompt| custom_prompt_translation = Proc.new do |message| @@ -51,25 +50,22 @@ module DiscourseAi custom_context[:name] = message[1] if custom_context[:type] != "assistant" - result << custom_context + custom_context end end if custom_prompt.present? - if first - custom_prompt.reverse_each(&custom_prompt_translation) - first = false - else - tool_call_and_tool = custom_prompt.first(2) - tool_call_and_tool.reverse_each(&custom_prompt_translation) - end + result << { + type: "multi_turn", + content: custom_prompt.reverse_each.map(&custom_prompt_translation).compact, + } else context = { content: raw, type: (available_bot_usernames.include?(username) ? "assistant" : "user"), } - context[:name] = username if context[:type] == "user" + context[:name] = clean_username(username) if context[:type] == "user" result << context end diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 0335f53f..78eb8b98 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -65,7 +65,8 @@ module DiscourseAi def conversation_context return [] if prompt[:conversation_context].blank? - trimmed_context = trim_context(prompt[:conversation_context]) + flattened_context = flatten_context(prompt[:conversation_context]) + trimmed_context = trim_context(flattened_context) trimmed_context.reverse.map do |context| if context[:type] == "tool_call" diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index d7ecf0b0..e760ab70 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -40,7 +40,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index b0e6c6a3..37768348 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -164,6 +164,27 @@ module DiscourseAi #{tools} TEXT end + + def flatten_context(context) + found_first_multi_turn = false + + context + .map do |a_context| + if a_context[:type] == "multi_turn" + if found_first_multi_turn + # Only take tool and tool_call_id from subsequent multi-turn interactions. + # Drop assistant responses + a_context[:content].last(2) + else + found_first_multi_turn = true + a_context[:content] + end + else + a_context + end + end + .flatten + end end end end diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 72f3a8ff..06d1f7df 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -15,6 +15,9 @@ module DiscourseAi end def translate + # Gemini complains if we don't alternate model/user roles. + noop_model_response = { role: "model", parts: { text: "Ok." } } + gemini_prompt = [ { role: "user", @@ -22,7 +25,7 @@ module DiscourseAi text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"), }, }, - { role: "model", parts: { text: "Ok." } }, + noop_model_response, ] if prompt[:examples] @@ -34,7 +37,13 @@ module DiscourseAi gemini_prompt.concat(conversation_context) if prompt[:conversation_context] - gemini_prompt << { role: "user", parts: { text: prompt[:input] } } + if prompt[:input] + gemini_prompt << noop_model_response.dup if gemini_prompt.last[:role] == "user" + + gemini_prompt << { role: "user", parts: { text: prompt[:input] } } + end + + gemini_prompt end def tools @@ -42,16 +51,23 @@ module DiscourseAi translated_tools = prompt[:tools].map do |t| - required_fields = [] - tool = t.dup + tool = t.slice(:name, :description) - tool[:parameters] = t[:parameters].map do |p| - required_fields << p[:name] if p[:required] + if t[:parameters] + tool[:parameters] = t[:parameters].reduce( + { type: "object", required: [], properties: {} }, + ) do |memo, p| + name = p[:name] + memo[:required] << name if p[:required] - p.except(:required) + memo[:properties][name] = p.except(:name, :required, :item_type) + + memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type] + memo + end end - tool.merge(required: required_fields) + tool end [{ function_declarations: translated_tools }] @@ -60,23 +76,42 @@ module DiscourseAi def conversation_context return [] if prompt[:conversation_context].blank? - trimmed_context = trim_context(prompt[:conversation_context]) + flattened_context = flatten_context(prompt[:conversation_context]) + trimmed_context = trim_context(flattened_context) trimmed_context.reverse.map do |context| - translated = {} - translated[:role] = (context[:type] == "user" ? "user" : "model") + if context[:type] == "tool_call" + function = JSON.parse(context[:content], symbolize_names: true) - part = {} - - if context[:type] == "tool" - part["functionResponse"] = { name: context[:name], content: context[:content] } + { + role: "model", + parts: { + functionCall: { + name: function[:name], + args: function[:arguments], + }, + }, + } + elsif context[:type] == "tool" + { + role: "function", + parts: { + functionResponse: { + name: context[:name], + response: { + content: context[:content], + }, + }, + }, + } else - part[:text] = context[:content] + { + role: context[:type] == "assistant" ? "model" : "user", + parts: { + text: context[:content], + }, + } end - - translated[:parts] = [part] - - translated end end @@ -89,6 +124,19 @@ module DiscourseAi def calculate_message_token(context) self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) end + + private + + def flatten_context(context) + context.map do |a_context| + if a_context[:type] == "multi_turn" + # Drop old tool calls and only keep bot response. + a_context[:content].find { |c| c[:type] == "assistant" } + else + a_context + end + end + end end end end diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb index 0a53c44b..63284c59 100644 --- a/lib/completions/dialects/llama2_classic.rb +++ b/lib/completions/dialects/llama2_classic.rb @@ -40,8 +40,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb index 36a2fd43..4ac60d19 100644 --- a/lib/completions/dialects/mixtral.rb +++ b/lib/completions/dialects/mixtral.rb @@ -40,7 +40,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb index 74a356f3..fa742402 100644 --- a/lib/completions/dialects/orca_style.rb +++ b/lib/completions/dialects/orca_style.rb @@ -37,7 +37,8 @@ module DiscourseAi return "" if prompt[:conversation_context].blank? clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } - trimmed_context = trim_context(clean_context) + flattened_context = flatten_context(clean_context) + trimmed_context = trim_context(flattened_context) trimmed_context .reverse diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 72635e12..04d731b6 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -273,20 +273,19 @@ module DiscourseAi function_buffer.at("tool_id").inner_html = tool_name end - _read_parameters = - read_function - .at("parameters") - &.elements - .to_a - .each do |elem| - if paramenter = function_buffer.at(elem.name)&.text - function_buffer.at(elem.name).inner_html = paramenter - else - param_node = read_function.at(elem.name) - function_buffer.at("parameters").add_child(param_node) - function_buffer.at("parameters").add_child("\n") - end + read_function + .at("parameters") + &.elements + .to_a + .each do |elem| + if paramenter = function_buffer.at(elem.name)&.text + function_buffer.at(elem.name).inner_html = paramenter + else + param_node = read_function.at(elem.name) + function_buffer.at("parameters").add_child(param_node) + function_buffer.at("parameters").add_child("\n") end + end function_buffer end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 9a1e3711..f0b7a508 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -42,10 +42,12 @@ module DiscourseAi end def prepare_payload(prompt, model_params, dialect) + tools = dialect.tools + default_options .merge(contents: prompt) .tap do |payload| - payload[:tools] = dialect.tools if dialect.tools.present? + payload[:tools] = tools if tools.present? payload[:generationConfig].merge!(model_params) if model_params.present? end end @@ -57,8 +59,12 @@ module DiscourseAi end def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true) - + parsed = + if @streaming_mode + response_raw + else + JSON.parse(response_raw, symbolize_names: true) + end response_h = parsed.dig(:candidates, 0, :content, :parts, 0) @has_function_call ||= response_h.dig(:functionCall).present? @@ -66,20 +72,11 @@ module DiscourseAi end def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - if line == "," - nil - elsif line.starts_with?("[") - line[1..-1] - elsif line.ends_with?("]") - line[0..-1] - else - line - end - end - .compact_blank + begin + JSON.parse(decoded_chunk, symbolize_names: true) + rescue JSON::ParserError + [] + end end def extract_prompt_for_tokenizer(prompt) diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 621e9811..e8cb14f3 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -62,6 +62,8 @@ module DiscourseAi # { type: "user", name: "user1", content: "This is a new message by a user" }, # { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, # { type: "tool", name: "tool_id", content: "I'm a tool result" }, + # { type: "tool_call_id", name: "tool_id", content: { name: "tool", args: { ...tool_args } } }, + # { type: "multi_turn", content: [assistant_reply_from_a_tool, tool_call, tool_call_id] } # ] # # - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example: diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb index 0a5266b7..cc80611b 100644 --- a/spec/lib/completions/dialects/gemini_spec.rb +++ b/spec/lib/completions/dialects/gemini_spec.rb @@ -98,18 +98,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do expect(translated_context).to eq( [ { - role: "model", - parts: [ - { - "functionResponse" => { - name: context.last[:name], + role: "function", + parts: { + functionResponse: { + name: context.last[:name], + response: { content: context.last[:content], }, }, - ], + }, }, - { role: "model", parts: [{ text: context.second[:content] }] }, - { role: "user", parts: [{ text: context.first[:content] }] }, + { role: "model", parts: { text: context.second[:content] } }, + { role: "user", parts: { text: context.first[:content] } }, ], ) end @@ -121,7 +121,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do translated_context = dialect.conversation_context - expect(translated_context.last.dig(:parts, 0, :text).length).to be < + expect(translated_context.last.dig(:parts, :text).length).to be < context.last[:content].length end end @@ -133,16 +133,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do { name: "get_weather", description: "Get the weather in a city", - parameters: [ - { name: "location", type: "string", description: "the city name" }, - { - name: "unit", - type: "string", - description: "the unit of measurement celcius c or fahrenheit f", - enum: %w[c f], + parameters: { + type: "object", + required: %w[location unit], + properties: { + "location" => { + type: "string", + description: "the city name", + }, + "unit" => { + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + }, }, - ], - required: %w[location unit], + }, }, ], } diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index 195351e3..4e1d4b59 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -16,16 +16,21 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do { name: "get_weather", description: "Get the weather in a city", - parameters: [ - { name: "location", type: "string", description: "the city name" }, - { - name: "unit", - type: "string", - description: "the unit of measurement celcius c or fahrenheit f", - enum: %w[c f], + parameters: { + type: "object", + required: %w[location unit], + properties: { + "location" => { + type: "string", + description: "the city name", + }, + "unit" => { + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + }, }, - ], - required: %w[location unit], + }, } end @@ -126,7 +131,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do end end - chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("") + chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]").split("") WebMock .stub_request( diff --git a/spec/lib/modules/ai_bot/bot_spec.rb b/spec/lib/modules/ai_bot/bot_spec.rb index d40f3a47..a5ac0185 100644 --- a/spec/lib/modules/ai_bot/bot_spec.rb +++ b/spec/lib/modules/ai_bot/bot_spec.rb @@ -45,6 +45,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do