FEATURE: AI Bot Gemini support. (#402)

It also corrects the syntax around tool support, which was wrong.

Gemini doesn't want us to include messages about previous tool invocations, so I had to shuffle around some code to send the response it generated from those invocations instead. For this, I created the "multi_turn" context, which bundles all the context involved in the interaction.
This commit is contained in:
Roman Rizzi 2024-01-04 18:15:34 -03:00 committed by GitHub
parent aa56baad37
commit 971e03bdf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 191 additions and 130 deletions

View File

@ -12,7 +12,7 @@ import copyConversation from "../discourse/lib/copy-conversation";
const AUTO_COPY_THRESHOLD = 4;
function isGPTBot(user) {
return user && [-110, -111, -112, -113, -114].includes(user.id);
return user && [-110, -111, -112, -113, -114, -115].includes(user.id);
}
function attachHeaderIcon(api) {

View File

@ -197,6 +197,7 @@ en:
gpt-3:
5-turbo: "GPT-3.5"
claude-2: "Claude 2"
gemini-pro: "Gemini"
mixtral-8x7B-Instruct-V0:
"1": "Mixtral-8x7B V0.1"
sentiments:

View File

@ -274,6 +274,7 @@ discourse_ai:
- gpt-4
- gpt-4-turbo
- claude-2
- gemini-pro
- mixtral-8x7B-Instruct-V0.1
ai_bot_add_to_header:
default: true

View File

@ -125,6 +125,8 @@ module DiscourseAi
"gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
"mistralai/Mixtral-8x7B-Instruct-v0.1"
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"gemini-pro"
else
nil
end
@ -146,9 +148,10 @@ module DiscourseAi
<summary>#{summary}</summary>
<p>#{details}</p>
</details>
HTML
placeholder << custom_raw << "\n" if custom_raw
placeholder << custom_raw if custom_raw
placeholder
end

View File

@ -10,12 +10,14 @@ module DiscourseAi
CLAUDE_V2_ID = -112
GPT4_TURBO_ID = -113
MIXTRAL_ID = -114
GEMINI_ID = -115
BOTS = [
[GPT4_ID, "gpt4_bot", "gpt-4"],
[GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"],
[CLAUDE_V2_ID, "claude_bot", "claude-2"],
[GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"],
[MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"],
[GEMINI_ID, "gemini_bot", "gemini-pro"],
]
def self.map_bot_model_to_user_id(model_name)
@ -30,6 +32,8 @@ module DiscourseAi
CLAUDE_V2_ID
in "mixtral-8x7B-Instruct-V0.1"
MIXTRAL_ID
in "gemini-pro"
GEMINI_ID
else
nil
end

View File

@ -37,7 +37,6 @@ module DiscourseAi
result = []
first = true
context.each do |raw, username, custom_prompt|
custom_prompt_translation =
Proc.new do |message|
@ -51,25 +50,22 @@ module DiscourseAi
custom_context[:name] = message[1] if custom_context[:type] != "assistant"
result << custom_context
custom_context
end
end
if custom_prompt.present?
if first
custom_prompt.reverse_each(&custom_prompt_translation)
first = false
else
tool_call_and_tool = custom_prompt.first(2)
tool_call_and_tool.reverse_each(&custom_prompt_translation)
end
result << {
type: "multi_turn",
content: custom_prompt.reverse_each.map(&custom_prompt_translation).compact,
}
else
context = {
content: raw,
type: (available_bot_usernames.include?(username) ? "assistant" : "user"),
}
context[:name] = username if context[:type] == "user"
context[:name] = clean_username(username) if context[:type] == "user"
result << context
end

View File

@ -65,7 +65,8 @@ module DiscourseAi
def conversation_context
return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
flattened_context = flatten_context(prompt[:conversation_context])
trimmed_context = trim_context(flattened_context)
trimmed_context.reverse.map do |context|
if context[:type] == "tool_call"

View File

@ -40,7 +40,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
trimmed_context = trim_context(clean_context)
flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(flattened_context)
trimmed_context
.reverse

View File

@ -164,6 +164,27 @@ module DiscourseAi
#{tools}</tools>
TEXT
end
def flatten_context(context)
found_first_multi_turn = false
context
.map do |a_context|
if a_context[:type] == "multi_turn"
if found_first_multi_turn
# Only take the tool_call and tool entries from subsequent multi-turn interactions.
# Drop assistant responses.
a_context[:content].last(2)
else
found_first_multi_turn = true
a_context[:content]
end
else
a_context
end
end
.flatten
end
end
end
end

View File

@ -15,6 +15,9 @@ module DiscourseAi
end
def translate
# Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } }
gemini_prompt = [
{
role: "user",
@ -22,7 +25,7 @@ module DiscourseAi
text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
},
},
{ role: "model", parts: { text: "Ok." } },
noop_model_response,
]
if prompt[:examples]
@ -34,7 +37,13 @@ module DiscourseAi
gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
if prompt[:input]
gemini_prompt << noop_model_response.dup if gemini_prompt.last[:role] == "user"
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
end
gemini_prompt
end
def tools
@ -42,16 +51,23 @@ module DiscourseAi
translated_tools =
prompt[:tools].map do |t|
required_fields = []
tool = t.dup
tool = t.slice(:name, :description)
tool[:parameters] = t[:parameters].map do |p|
required_fields << p[:name] if p[:required]
if t[:parameters]
tool[:parameters] = t[:parameters].reduce(
{ type: "object", required: [], properties: {} },
) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
p.except(:required)
memo[:properties][name] = p.except(:name, :required, :item_type)
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
memo
end
end
tool.merge(required: required_fields)
tool
end
[{ function_declarations: translated_tools }]
@ -60,23 +76,42 @@ module DiscourseAi
def conversation_context
return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
flattened_context = flatten_context(prompt[:conversation_context])
trimmed_context = trim_context(flattened_context)
trimmed_context.reverse.map do |context|
translated = {}
translated[:role] = (context[:type] == "user" ? "user" : "model")
if context[:type] == "tool_call"
function = JSON.parse(context[:content], symbolize_names: true)
part = {}
if context[:type] == "tool"
part["functionResponse"] = { name: context[:name], content: context[:content] }
{
role: "model",
parts: {
functionCall: {
name: function[:name],
args: function[:arguments],
},
},
}
elsif context[:type] == "tool"
{
role: "function",
parts: {
functionResponse: {
name: context[:name],
response: {
content: context[:content],
},
},
},
}
else
part[:text] = context[:content]
{
role: context[:type] == "assistant" ? "model" : "user",
parts: {
text: context[:content],
},
}
end
translated[:parts] = [part]
translated
end
end
@ -89,6 +124,19 @@ module DiscourseAi
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
private
def flatten_context(context)
context.map do |a_context|
if a_context[:type] == "multi_turn"
# Drop old tool calls and only keep bot response.
a_context[:content].find { |c| c[:type] == "assistant" }
else
a_context
end
end
end
end
end
end

View File

@ -40,8 +40,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
trimmed_context = trim_context(clean_context)
flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(flattened_context)
trimmed_context
.reverse

View File

@ -40,7 +40,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
trimmed_context = trim_context(clean_context)
flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(flattened_context)
trimmed_context
.reverse

View File

@ -37,7 +37,8 @@ module DiscourseAi
return "" if prompt[:conversation_context].blank?
clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
trimmed_context = trim_context(clean_context)
flattened_context = flatten_context(clean_context)
trimmed_context = trim_context(flattened_context)
trimmed_context
.reverse

View File

@ -273,20 +273,19 @@ module DiscourseAi
function_buffer.at("tool_id").inner_html = tool_name
end
_read_parameters =
read_function
.at("parameters")
&.elements
.to_a
.each do |elem|
if paramenter = function_buffer.at(elem.name)&.text
function_buffer.at(elem.name).inner_html = paramenter
else
param_node = read_function.at(elem.name)
function_buffer.at("parameters").add_child(param_node)
function_buffer.at("parameters").add_child("\n")
end
read_function
.at("parameters")
&.elements
.to_a
.each do |elem|
if paramenter = function_buffer.at(elem.name)&.text
function_buffer.at(elem.name).inner_html = paramenter
else
param_node = read_function.at(elem.name)
function_buffer.at("parameters").add_child(param_node)
function_buffer.at("parameters").add_child("\n")
end
end
function_buffer
end

View File

@ -42,10 +42,12 @@ module DiscourseAi
end
def prepare_payload(prompt, model_params, dialect)
tools = dialect.tools
default_options
.merge(contents: prompt)
.tap do |payload|
payload[:tools] = dialect.tools if dialect.tools.present?
payload[:tools] = tools if tools.present?
payload[:generationConfig].merge!(model_params) if model_params.present?
end
end
@ -57,8 +59,12 @@ module DiscourseAi
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true)
parsed =
if @streaming_mode
response_raw
else
JSON.parse(response_raw, symbolize_names: true)
end
response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
@has_function_call ||= response_h.dig(:functionCall).present?
@ -66,20 +72,11 @@ module DiscourseAi
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
.map do |line|
if line == ","
nil
elsif line.starts_with?("[")
line[1..-1]
elsif line.ends_with?("]")
line[0..-1]
else
line
end
end
.compact_blank
begin
JSON.parse(decoded_chunk, symbolize_names: true)
rescue JSON::ParserError
[]
end
end
def extract_prompt_for_tokenizer(prompt)

View File

@ -62,6 +62,8 @@ module DiscourseAi
# { type: "user", name: "user1", content: "This is a new message by a user" },
# { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
# { type: "tool", name: "tool_id", content: "I'm a tool result" },
# { type: "tool_call", name: "tool_id", content: { name: "tool", args: { ...tool_args } } },
# { type: "multi_turn", content: [assistant_reply_from_a_tool, tool_call, tool_result] }
# ]
#
# - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:

View File

@ -98,18 +98,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(translated_context).to eq(
[
{
role: "model",
parts: [
{
"functionResponse" => {
name: context.last[:name],
role: "function",
parts: {
functionResponse: {
name: context.last[:name],
response: {
content: context.last[:content],
},
},
],
},
},
{ role: "model", parts: [{ text: context.second[:content] }] },
{ role: "user", parts: [{ text: context.first[:content] }] },
{ role: "model", parts: { text: context.second[:content] } },
{ role: "user", parts: { text: context.first[:content] } },
],
)
end
@ -121,7 +121,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
translated_context = dialect.conversation_context
expect(translated_context.last.dig(:parts, 0, :text).length).to be <
expect(translated_context.last.dig(:parts, :text).length).to be <
context.last[:content].length
end
end
@ -133,16 +133,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name" },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
parameters: {
type: "object",
required: %w[location unit],
properties: {
"location" => {
type: "string",
description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
],
required: %w[location unit],
},
},
],
}

View File

@ -16,16 +16,21 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name" },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
parameters: {
type: "object",
required: %w[location unit],
properties: {
"location" => {
type: "string",
description: "the city name",
},
"unit" => {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
],
required: %w[location unit],
},
}
end
@ -126,7 +131,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
end
end
chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]").split("")
WebMock
.stub_request(

View File

@ -45,6 +45,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
<summary>#{tool.summary}</summary>
<p></p>
</details>
HTML
context = {}

View File

@ -125,44 +125,18 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(context).to contain_exactly(
*[
{ type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: custom_prompt.third.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
{
type: "multi_turn",
content: [
{ type: "assistant", content: custom_prompt.third.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
],
},
{ type: "user", name: user.username, content: first_post.raw },
],
)
end
end
it "include replies generated from tools only once" do
custom_prompt = [
[
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
[
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" }.to_json }.to_json,
"time",
"tool_call",
],
["I replied this thanks to the time command", bot_user.username],
]
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt)
context = playground.conversation_context(third_post)
expect(context).to contain_exactly(
*[
{ type: "user", name: user.username, content: third_post.raw },
{ type: "assistant", content: custom_prompt.third.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
{ type: "tool_call", content: custom_prompt.second.first, name: "time" },
{ type: "tool", name: "time", content: custom_prompt.first.first },
],
)
end
end
end