FEATURE: AI Bot Gemini support. (#402)

It also corrects the tool-support syntax, which was previously incorrect.

Gemini doesn't want us to include messages about previous tool invocations, so I had to shuffle some code around to send the response it generated from those invocations instead. To support this, I created the "multi_turn" context type, which bundles all of the context involved in the interaction.
This commit is contained in:
Roman Rizzi 2024-01-04 18:15:34 -03:00 committed by GitHub
parent aa56baad37
commit 971e03bdf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 191 additions and 130 deletions

View File

@ -12,7 +12,7 @@ import copyConversation from "../discourse/lib/copy-conversation";
const AUTO_COPY_THRESHOLD = 4; const AUTO_COPY_THRESHOLD = 4;
function isGPTBot(user) { function isGPTBot(user) {
return user && [-110, -111, -112, -113, -114].includes(user.id); return user && [-110, -111, -112, -113, -114, -115].includes(user.id);
} }
function attachHeaderIcon(api) { function attachHeaderIcon(api) {

View File

@ -197,6 +197,7 @@ en:
gpt-3: gpt-3:
5-turbo: "GPT-3.5" 5-turbo: "GPT-3.5"
claude-2: "Claude 2" claude-2: "Claude 2"
gemini-pro: "Gemini"
mixtral-8x7B-Instruct-V0: mixtral-8x7B-Instruct-V0:
"1": "Mixtral-8x7B V0.1" "1": "Mixtral-8x7B V0.1"
sentiments: sentiments:

View File

@ -274,6 +274,7 @@ discourse_ai:
- gpt-4 - gpt-4
- gpt-4-turbo - gpt-4-turbo
- claude-2 - claude-2
- gemini-pro
- mixtral-8x7B-Instruct-V0.1 - mixtral-8x7B-Instruct-V0.1
ai_bot_add_to_header: ai_bot_add_to_header:
default: true default: true

View File

@ -125,6 +125,8 @@ module DiscourseAi
"gpt-3.5-turbo-16k" "gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
"mistralai/Mixtral-8x7B-Instruct-v0.1" "mistralai/Mixtral-8x7B-Instruct-v0.1"
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"gemini-pro"
else else
nil nil
end end
@ -146,9 +148,10 @@ module DiscourseAi
<summary>#{summary}</summary> <summary>#{summary}</summary>
<p>#{details}</p> <p>#{details}</p>
</details> </details>
HTML HTML
placeholder << custom_raw << "\n" if custom_raw placeholder << custom_raw if custom_raw
placeholder placeholder
end end

View File

@ -10,12 +10,14 @@ module DiscourseAi
CLAUDE_V2_ID = -112 CLAUDE_V2_ID = -112
GPT4_TURBO_ID = -113 GPT4_TURBO_ID = -113
MIXTRAL_ID = -114 MIXTRAL_ID = -114
GEMINI_ID = -115
BOTS = [ BOTS = [
[GPT4_ID, "gpt4_bot", "gpt-4"], [GPT4_ID, "gpt4_bot", "gpt-4"],
[GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"], [GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"],
[CLAUDE_V2_ID, "claude_bot", "claude-2"], [CLAUDE_V2_ID, "claude_bot", "claude-2"],
[GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"], [GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"],
[MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"], [MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"],
[GEMINI_ID, "gemini_bot", "gemini-pro"],
] ]
def self.map_bot_model_to_user_id(model_name) def self.map_bot_model_to_user_id(model_name)
@ -30,6 +32,8 @@ module DiscourseAi
CLAUDE_V2_ID CLAUDE_V2_ID
in "mixtral-8x7B-Instruct-V0.1" in "mixtral-8x7B-Instruct-V0.1"
MIXTRAL_ID MIXTRAL_ID
in "gemini-pro"
GEMINI_ID
else else
nil nil
end end

View File

@ -37,7 +37,6 @@ module DiscourseAi
result = [] result = []
first = true
context.each do |raw, username, custom_prompt| context.each do |raw, username, custom_prompt|
custom_prompt_translation = custom_prompt_translation =
Proc.new do |message| Proc.new do |message|
@ -51,25 +50,22 @@ module DiscourseAi
custom_context[:name] = message[1] if custom_context[:type] != "assistant" custom_context[:name] = message[1] if custom_context[:type] != "assistant"
result << custom_context custom_context
end end
end end
if custom_prompt.present? if custom_prompt.present?
if first result << {
custom_prompt.reverse_each(&custom_prompt_translation) type: "multi_turn",
first = false content: custom_prompt.reverse_each.map(&custom_prompt_translation).compact,
else }
tool_call_and_tool = custom_prompt.first(2)
tool_call_and_tool.reverse_each(&custom_prompt_translation)
end
else else
context = { context = {
content: raw, content: raw,
type: (available_bot_usernames.include?(username) ? "assistant" : "user"), type: (available_bot_usernames.include?(username) ? "assistant" : "user"),
} }
context[:name] = username if context[:type] == "user" context[:name] = clean_username(username) if context[:type] == "user"
result << context result << context
end end

View File

@ -65,7 +65,8 @@ module DiscourseAi
def conversation_context def conversation_context
return [] if prompt[:conversation_context].blank? return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context]) flattened_context = flatten_context(prompt[:conversation_context])
trimmed_context = trim_context(flattened_context)
trimmed_context.reverse.map do |context| trimmed_context.reverse.map do |context|
if context[:type] == "tool_call" if context[:type] == "tool_call"

View File

@ -40,7 +40,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank? return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
trimmed_context = trim_context(clean_context) flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(flattened_context)
trimmed_context trimmed_context
.reverse .reverse

View File

@ -164,6 +164,27 @@ module DiscourseAi
#{tools}</tools> #{tools}</tools>
TEXT TEXT
end end
def flatten_context(context)
found_first_multi_turn = false
context
.map do |a_context|
if a_context[:type] == "multi_turn"
if found_first_multi_turn
# Only take tool and tool_call_id from subsequent multi-turn interactions.
# Drop assistant responses
a_context[:content].last(2)
else
found_first_multi_turn = true
a_context[:content]
end
else
a_context
end
end
.flatten
end
end end
end end
end end

View File

@ -15,6 +15,9 @@ module DiscourseAi
end end
def translate def translate
# Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } }
gemini_prompt = [ gemini_prompt = [
{ {
role: "user", role: "user",
@ -22,7 +25,7 @@ module DiscourseAi
text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"), text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
}, },
}, },
{ role: "model", parts: { text: "Ok." } }, noop_model_response,
] ]
if prompt[:examples] if prompt[:examples]
@ -34,7 +37,13 @@ module DiscourseAi
gemini_prompt.concat(conversation_context) if prompt[:conversation_context] gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
gemini_prompt << { role: "user", parts: { text: prompt[:input] } } if prompt[:input]
gemini_prompt << noop_model_response.dup if gemini_prompt.last[:role] == "user"
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
end
gemini_prompt
end end
def tools def tools
@ -42,16 +51,23 @@ module DiscourseAi
translated_tools = translated_tools =
prompt[:tools].map do |t| prompt[:tools].map do |t|
required_fields = [] tool = t.slice(:name, :description)
tool = t.dup
tool[:parameters] = t[:parameters].map do |p| if t[:parameters]
required_fields << p[:name] if p[:required] tool[:parameters] = t[:parameters].reduce(
{ type: "object", required: [], properties: {} },
) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
p.except(:required) memo[:properties][name] = p.except(:name, :required, :item_type)
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
memo
end
end end
tool.merge(required: required_fields) tool
end end
[{ function_declarations: translated_tools }] [{ function_declarations: translated_tools }]
@ -60,23 +76,42 @@ module DiscourseAi
def conversation_context def conversation_context
return [] if prompt[:conversation_context].blank? return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context]) flattened_context = flatten_context(prompt[:conversation_context])
trimmed_context = trim_context(flattened_context)
trimmed_context.reverse.map do |context| trimmed_context.reverse.map do |context|
translated = {} if context[:type] == "tool_call"
translated[:role] = (context[:type] == "user" ? "user" : "model") function = JSON.parse(context[:content], symbolize_names: true)
part = {} {
role: "model",
if context[:type] == "tool" parts: {
part["functionResponse"] = { name: context[:name], content: context[:content] } functionCall: {
name: function[:name],
args: function[:arguments],
},
},
}
elsif context[:type] == "tool"
{
role: "function",
parts: {
functionResponse: {
name: context[:name],
response: {
content: context[:content],
},
},
},
}
else else
part[:text] = context[:content] {
role: context[:type] == "assistant" ? "model" : "user",
parts: {
text: context[:content],
},
}
end end
translated[:parts] = [part]
translated
end end
end end
@ -89,6 +124,19 @@ module DiscourseAi
def calculate_message_token(context) def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end end
private
def flatten_context(context)
context.map do |a_context|
if a_context[:type] == "multi_turn"
# Drop old tool calls and only keep bot response.
a_context[:content].find { |c| c[:type] == "assistant" }
else
a_context
end
end
end
end end
end end
end end

View File

@ -40,8 +40,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank? return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(clean_context) trimmed_context = trim_context(flattened_context)
trimmed_context trimmed_context
.reverse .reverse

View File

@ -40,7 +40,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank? return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
trimmed_context = trim_context(clean_context) flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(flattened_context)
trimmed_context trimmed_context
.reverse .reverse

View File

@ -37,7 +37,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank? return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" } clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
trimmed_context = trim_context(clean_context) flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(flattened_context)
trimmed_context trimmed_context
.reverse .reverse

View File

@ -273,20 +273,19 @@ module DiscourseAi
function_buffer.at("tool_id").inner_html = tool_name function_buffer.at("tool_id").inner_html = tool_name
end end
_read_parameters = read_function
read_function .at("parameters")
.at("parameters") &.elements
&.elements .to_a
.to_a .each do |elem|
.each do |elem| if paramenter = function_buffer.at(elem.name)&.text
if paramenter = function_buffer.at(elem.name)&.text function_buffer.at(elem.name).inner_html = paramenter
function_buffer.at(elem.name).inner_html = paramenter else
else param_node = read_function.at(elem.name)
param_node = read_function.at(elem.name) function_buffer.at("parameters").add_child(param_node)
function_buffer.at("parameters").add_child(param_node) function_buffer.at("parameters").add_child("\n")
function_buffer.at("parameters").add_child("\n")
end
end end
end
function_buffer function_buffer
end end

View File

@ -42,10 +42,12 @@ module DiscourseAi
end end
def prepare_payload(prompt, model_params, dialect) def prepare_payload(prompt, model_params, dialect)
tools = dialect.tools
default_options default_options
.merge(contents: prompt) .merge(contents: prompt)
.tap do |payload| .tap do |payload|
payload[:tools] = dialect.tools if dialect.tools.present? payload[:tools] = tools if tools.present?
payload[:generationConfig].merge!(model_params) if model_params.present? payload[:generationConfig].merge!(model_params) if model_params.present?
end end
end end
@ -57,8 +59,12 @@ module DiscourseAi
end end
def extract_completion_from(response_raw) def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true) parsed =
if @streaming_mode
response_raw
else
JSON.parse(response_raw, symbolize_names: true)
end
response_h = parsed.dig(:candidates, 0, :content, :parts, 0) response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
@has_function_call ||= response_h.dig(:functionCall).present? @has_function_call ||= response_h.dig(:functionCall).present?
@ -66,20 +72,11 @@ module DiscourseAi
end end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)
decoded_chunk begin
.split("\n") JSON.parse(decoded_chunk, symbolize_names: true)
.map do |line| rescue JSON::ParserError
if line == "," []
nil end
elsif line.starts_with?("[")
line[1..-1]
elsif line.ends_with?("]")
line[0..-1]
else
line
end
end
.compact_blank
end end
def extract_prompt_for_tokenizer(prompt) def extract_prompt_for_tokenizer(prompt)

View File

@ -62,6 +62,8 @@ module DiscourseAi
# { type: "user", name: "user1", content: "This is a new message by a user" }, # { type: "user", name: "user1", content: "This is a new message by a user" },
# { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, # { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
# { type: "tool", name: "tool_id", content: "I'm a tool result" }, # { type: "tool", name: "tool_id", content: "I'm a tool result" },
# { type: "tool_call_id", name: "tool_id", content: { name: "tool", args: { ...tool_args } } },
# { type: "multi_turn", content: [assistant_reply_from_a_tool, tool_call, tool_call_id] }
# ] # ]
# #
# - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example: # - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:

View File

@ -98,18 +98,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(translated_context).to eq( expect(translated_context).to eq(
[ [
{ {
role: "model", role: "function",
parts: [ parts: {
{ functionResponse: {
"functionResponse" => { name: context.last[:name],
name: context.last[:name], response: {
content: context.last[:content], content: context.last[:content],
}, },
}, },
], },
}, },
{ role: "model", parts: [{ text: context.second[:content] }] }, { role: "model", parts: { text: context.second[:content] } },
{ role: "user", parts: [{ text: context.first[:content] }] }, { role: "user", parts: { text: context.first[:content] } },
], ],
) )
end end
@ -121,7 +121,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
translated_context = dialect.conversation_context translated_context = dialect.conversation_context
expect(translated_context.last.dig(:parts, 0, :text).length).to be < expect(translated_context.last.dig(:parts, :text).length).to be <
context.last[:content].length context.last[:content].length
end end
end end
@ -133,16 +133,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{ {
name: "get_weather", name: "get_weather",
description: "Get the weather in a city", description: "Get the weather in a city",
parameters: [ parameters: {
{ name: "location", type: "string", description: "the city name" }, type: "object",
{ required: %w[location unit],
name: "unit", properties: {
type: "string", "location" => {
description: "the unit of measurement celcius c or fahrenheit f", type: "string",
enum: %w[c f], description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
}, },
], },
required: %w[location unit],
}, },
], ],
} }

View File

@ -16,16 +16,21 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
{ {
name: "get_weather", name: "get_weather",
description: "Get the weather in a city", description: "Get the weather in a city",
parameters: [ parameters: {
{ name: "location", type: "string", description: "the city name" }, type: "object",
{ required: %w[location unit],
name: "unit", properties: {
type: "string", "location" => {
description: "the unit of measurement celcius c or fahrenheit f", type: "string",
enum: %w[c f], description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
}, },
], },
required: %w[location unit],
} }
end end
@ -126,7 +131,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
end end
end end
chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("") chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]").split("")
WebMock WebMock
.stub_request( .stub_request(

View File

@ -45,6 +45,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
<summary>#{tool.summary}</summary> <summary>#{tool.summary}</summary>
<p></p> <p></p>
</details> </details>
HTML HTML
context = {} context = {}

View File

@ -125,44 +125,18 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(context).to contain_exactly( expect(context).to contain_exactly(
*[ *[
{ type: "user", name: user.username, content: third_post.raw }, { type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: custom_prompt.third.first }, {
{ type: "tool_call", content: custom_prompt.second.first, name: "time" }, type: "multi_turn",
{ type: "tool", name: "time", content: custom_prompt.first.first }, content: [
{ type: "assistant", content: custom_prompt.third.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
],
},
{ type: "user", name: user.username, content: first_post.raw }, { type: "user", name: user.username, content: first_post.raw },
], ],
) )
end end
end end
it "include replies generated from tools only once" do
custom_prompt = [
[
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
[
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" }.to_json }.to_json,
"time",
"tool_call",
],
["I replied this thanks to the time command", bot_user.username],
]
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt)
context = playground.conversation_context(third_post)
expect(context).to contain_exactly(
*[
{ type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: custom_prompt.third.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
],
)
end
end end
end end