FEATURE: AI Bot Gemini support. (#402)
It also corrects the syntax around tool support, which was wrong. Gemini doesn't want us to include messages about previous tool invocations, so I had to shuffle some code around to send the response it generated from those invocations instead. For this, I created the "multi_turn" context type, which bundles all the context involved in a single tool interaction.
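To make the new shape concrete, here is a sketch of one "multi_turn" entry, using the illustrative time-tool values from the updated playground spec further down (not literal plugin output):

require "json"

# Illustrative only: one "multi_turn" entry bundles the assistant reply,
# the tool call, and the tool result of a single interaction.
multi_turn = {
  type: "multi_turn",
  content: [
    { type: "assistant", content: "I replied this thanks to the time command" },
    {
      type: "tool_call",
      name: "time",
      content: { name: "time", arguments: { timezone: "Buenos Aires" }.to_json }.to_json,
    },
    {
      type: "tool",
      name: "time",
      content: { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
    },
  ],
}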
parent aa56baad37
commit 971e03bdf2
@@ -12,7 +12,7 @@ import copyConversation from "../discourse/lib/copy-conversation";
 const AUTO_COPY_THRESHOLD = 4;

 function isGPTBot(user) {
-  return user && [-110, -111, -112, -113, -114].includes(user.id);
+  return user && [-110, -111, -112, -113, -114, -115].includes(user.id);
 }

 function attachHeaderIcon(api) {
@@ -197,6 +197,7 @@ en:
       gpt-3:
         5-turbo: "GPT-3.5"
       claude-2: "Claude 2"
+      gemini-pro: "Gemini"
       mixtral-8x7B-Instruct-V0:
         "1": "Mixtral-8x7B V0.1"
     sentiments:
@@ -274,6 +274,7 @@ discourse_ai:
       - gpt-4
       - gpt-4-turbo
       - claude-2
+      - gemini-pro
       - mixtral-8x7B-Instruct-V0.1
   ai_bot_add_to_header:
     default: true
@@ -125,6 +125,8 @@ module DiscourseAi
           "gpt-3.5-turbo-16k"
         when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
           "mistralai/Mixtral-8x7B-Instruct-v0.1"
+        when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
+          "gemini-pro"
         else
           nil
         end
@@ -146,9 +148,10 @@ module DiscourseAi
           <summary>#{summary}</summary>
           <p>#{details}</p>
           </details>
+
         HTML

-      placeholder << custom_raw << "\n" if custom_raw
+      placeholder << custom_raw if custom_raw

       placeholder
     end
@@ -10,12 +10,14 @@ module DiscourseAi
     CLAUDE_V2_ID = -112
     GPT4_TURBO_ID = -113
     MIXTRAL_ID = -114
+    GEMINI_ID = -115
     BOTS = [
       [GPT4_ID, "gpt4_bot", "gpt-4"],
       [GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"],
       [CLAUDE_V2_ID, "claude_bot", "claude-2"],
       [GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"],
       [MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"],
+      [GEMINI_ID, "gemini_bot", "gemini-pro"],
     ]

     def self.map_bot_model_to_user_id(model_name)
@@ -30,6 +32,8 @@ module DiscourseAi
         CLAUDE_V2_ID
       in "mixtral-8x7B-Instruct-V0.1"
         MIXTRAL_ID
+      in "gemini-pro"
+        GEMINI_ID
       else
         nil
       end
@@ -37,7 +37,6 @@ module DiscourseAi

       result = []

-      first = true
       context.each do |raw, username, custom_prompt|
         custom_prompt_translation =
           Proc.new do |message|
@@ -51,25 +50,22 @@ module DiscourseAi

             custom_context[:name] = message[1] if custom_context[:type] != "assistant"

-            result << custom_context
+            custom_context
           end
         end

         if custom_prompt.present?
-          if first
-            custom_prompt.reverse_each(&custom_prompt_translation)
-            first = false
-          else
-            tool_call_and_tool = custom_prompt.first(2)
-            tool_call_and_tool.reverse_each(&custom_prompt_translation)
-          end
+          result << {
+            type: "multi_turn",
+            content: custom_prompt.reverse_each.map(&custom_prompt_translation).compact,
+          }
         else
           context = {
             content: raw,
             type: (available_bot_usernames.include?(username) ? "assistant" : "user"),
           }

-          context[:name] = username if context[:type] == "user"
+          context[:name] = clean_username(username) if context[:type] == "user"

           result << context
         end
@@ -65,7 +65,8 @@ module DiscourseAi
       def conversation_context
         return [] if prompt[:conversation_context].blank?

-        trimmed_context = trim_context(prompt[:conversation_context])
+        flattened_context = flatten_context(prompt[:conversation_context])
+        trimmed_context = trim_context(flattened_context)

         trimmed_context.reverse.map do |context|
           if context[:type] == "tool_call"
@@ -40,7 +40,8 @@ module DiscourseAi
        return "" if prompt[:conversation_context].blank?

        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
-       trimmed_context = trim_context(clean_context)
+       flattened_context = flatten_context(clean_context)
+       trimmed_context = trim_context(flattened_context)

        trimmed_context
          .reverse
@@ -164,6 +164,27 @@ module DiscourseAi
         #{tools}</tools>
       TEXT
       end
+
+      def flatten_context(context)
+        found_first_multi_turn = false
+
+        context
+          .map do |a_context|
+            if a_context[:type] == "multi_turn"
+              if found_first_multi_turn
+                # Only take tool and tool_call_id from subsequent multi-turn interactions.
+                # Drop assistant responses
+                a_context[:content].last(2)
+              else
+                found_first_multi_turn = true
+                a_context[:content]
+              end
+            else
+              a_context
+            end
+          end
+          .flatten
+      end
     end
   end
 end
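As a sketch of what this flattening does, here is a standalone copy of the method above run against illustrative data. The context is stored newest-first, so the first multi_turn seen is the most recent interaction and keeps its assistant reply:

def flatten_context(context)
  found_first_multi_turn = false

  context
    .map do |a_context|
      if a_context[:type] == "multi_turn"
        if found_first_multi_turn
          # Subsequent (older) multi-turn interactions keep only tool_call + tool.
          a_context[:content].last(2)
        else
          found_first_multi_turn = true
          a_context[:content]
        end
      else
        a_context
      end
    end
    .flatten
end

context = [
  { type: "multi_turn", content: [{ type: "assistant" }, { type: "tool_call" }, { type: "tool" }] },
  { type: "multi_turn", content: [{ type: "assistant" }, { type: "tool_call" }, { type: "tool" }] },
  { type: "user", content: "hello" },
]

p flatten_context(context).map { |c| c[:type] }
# => ["assistant", "tool_call", "tool", "tool_call", "tool", "user"]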
@@ -15,6 +15,9 @@ module DiscourseAi
       end

       def translate
+        # Gemini complains if we don't alternate model/user roles.
+        noop_model_response = { role: "model", parts: { text: "Ok." } }
+
         gemini_prompt = [
           {
             role: "user",
@@ -22,7 +25,7 @@ module DiscourseAi
               text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
             },
           },
-          { role: "model", parts: { text: "Ok." } },
+          noop_model_response,
         ]

         if prompt[:examples]
@@ -34,7 +37,13 @@ module DiscourseAi

         gemini_prompt.concat(conversation_context) if prompt[:conversation_context]

-        gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
+        if prompt[:input]
+          gemini_prompt << noop_model_response.dup if gemini_prompt.last[:role] == "user"
+
+          gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
+        end

         gemini_prompt
       end

       def tools
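A sketch of the translated prompt this produces, with illustrative instruction and input text. Gemini requires strictly alternating user/model roles, so the noop "Ok." model turn is inserted wherever two user turns would otherwise be adjacent:

# Illustrative shape of a translated Gemini prompt.
gemini_prompt = [
  { role: "user", parts: { text: "You are a helpful bot.\nReply succinctly." } },
  { role: "model", parts: { text: "Ok." } },  # noop_model_response
  { role: "user", parts: { text: "What time is it in Buenos Aires?" } },
]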
@@ -42,16 +51,23 @@ module DiscourseAi

         translated_tools =
           prompt[:tools].map do |t|
-            required_fields = []
-            tool = t.dup
+            tool = t.slice(:name, :description)

-            tool[:parameters] = t[:parameters].map do |p|
-              required_fields << p[:name] if p[:required]
+            if t[:parameters]
+              tool[:parameters] = t[:parameters].reduce(
+                { type: "object", required: [], properties: {} },
+              ) do |memo, p|
+                name = p[:name]
+                memo[:required] << name if p[:required]

-              p.except(:required)
+                memo[:properties][name] = p.except(:name, :required, :item_type)
+
+                memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
+                memo
             end
+            end

-            tool.merge(required: required_fields)
+            tool
           end

         [{ function_declarations: translated_tools }]
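A sketch of the translation, assuming a hypothetical single-parameter tool in the plugin's internal parameter-array format:

# Illustrative: the plugin's internal tool format (an array of parameter
# hashes) is reduced into Gemini's object-style schema.
internal_tool = {
  name: "get_weather",
  description: "Get the weather in a city",
  parameters: [
    { name: "location", type: "string", description: "the city name", required: true },
  ],
}

# After the reduce above, the declaration sent to Gemini looks like:
# { name: "get_weather",
#   description: "Get the weather in a city",
#   parameters: {
#     type: "object",
#     required: ["location"],
#     properties: { "location" => { type: "string", description: "the city name" } },
#   } }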
@@ -60,23 +76,42 @@ module DiscourseAi
       def conversation_context
         return [] if prompt[:conversation_context].blank?

-        trimmed_context = trim_context(prompt[:conversation_context])
+        flattened_context = flatten_context(prompt[:conversation_context])
+        trimmed_context = trim_context(flattened_context)

         trimmed_context.reverse.map do |context|
-          translated = {}
-          translated[:role] = (context[:type] == "user" ? "user" : "model")
+          if context[:type] == "tool_call"
+            function = JSON.parse(context[:content], symbolize_names: true)
+
-          part = {}
-
-          if context[:type] == "tool"
-            part["functionResponse"] = { name: context[:name], content: context[:content] }
+            {
+              role: "model",
+              parts: {
+                functionCall: {
+                  name: function[:name],
+                  args: function[:arguments],
+                },
+              },
+            }
+          elsif context[:type] == "tool"
+            {
+              role: "function",
+              parts: {
+                functionResponse: {
+                  name: context[:name],
+                  response: {
+                    content: context[:content],
+                  },
+                },
+              },
+            }
           else
-            part[:text] = context[:content]
+            {
+              role: context[:type] == "assistant" ? "model" : "user",
+              parts: {
+                text: context[:content],
+              },
+            }
           end
-
-          translated[:parts] = [part]
-
-          translated
         end
       end
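As a sketch, the three branches above produce messages shaped like this (values illustrative):

# tool_call entries become model-side functionCall messages:
tool_call_message = {
  role: "model",
  parts: { functionCall: { name: "time", args: { timezone: "Buenos Aires" } } },
}

# tool results become function-role functionResponse messages:
tool_result_message = {
  role: "function",
  parts: { functionResponse: { name: "time", response: { content: "2023-12-14 17:24:00 -0300" } } },
}

# plain user/assistant turns become simple text parts:
plain_message = { role: "user", parts: { text: "What time is it in Buenos Aires?" } }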
@@ -89,6 +124,19 @@ module DiscourseAi
       def calculate_message_token(context)
         self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end
+
+      private
+
+      def flatten_context(context)
+        context.map do |a_context|
+          if a_context[:type] == "multi_turn"
+            # Drop old tool calls and only keep bot response.
+            a_context[:content].find { |c| c[:type] == "assistant" }
+          else
+            a_context
+          end
+        end
+      end
     end
   end
 end
@@ -40,8 +40,8 @@ module DiscourseAi
        return "" if prompt[:conversation_context].blank?

        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
-
-       trimmed_context = trim_context(clean_context)
+       flattened_context = flatten_context(clean_context)
+       trimmed_context = trim_context(flattened_context)

        trimmed_context
          .reverse
@@ -40,7 +40,8 @@ module DiscourseAi
        return "" if prompt[:conversation_context].blank?

        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
-       trimmed_context = trim_context(clean_context)
+       flattened_context = flatten_context(clean_context)
+       trimmed_context = trim_context(flattened_context)

        trimmed_context
          .reverse
@@ -37,7 +37,8 @@ module DiscourseAi
        return "" if prompt[:conversation_context].blank?

        clean_context = prompt[:conversation_context].select { |cc| cc[:type] != "tool_call" }
-       trimmed_context = trim_context(clean_context)
+       flattened_context = flatten_context(clean_context)
+       trimmed_context = trim_context(flattened_context)

        trimmed_context
          .reverse
@@ -273,20 +273,19 @@ module DiscourseAi
           function_buffer.at("tool_id").inner_html = tool_name
         end

-        _read_parameters =
-          read_function
-            .at("parameters")
-            &.elements
-            .to_a
-            .each do |elem|
-              if paramenter = function_buffer.at(elem.name)&.text
-                function_buffer.at(elem.name).inner_html = paramenter
-              else
-                param_node = read_function.at(elem.name)
-                function_buffer.at("parameters").add_child(param_node)
-                function_buffer.at("parameters").add_child("\n")
-              end
+        read_function
+          .at("parameters")
+          &.elements
+          .to_a
+          .each do |elem|
+            if paramenter = function_buffer.at(elem.name)&.text
+              function_buffer.at(elem.name).inner_html = paramenter
+            else
+              param_node = read_function.at(elem.name)
+              function_buffer.at("parameters").add_child(param_node)
+              function_buffer.at("parameters").add_child("\n")
+            end
           end

         function_buffer
       end
@@ -42,10 +42,12 @@ module DiscourseAi
       end

       def prepare_payload(prompt, model_params, dialect)
+        tools = dialect.tools
+
         default_options
           .merge(contents: prompt)
           .tap do |payload|
-            payload[:tools] = dialect.tools if dialect.tools.present?
+            payload[:tools] = tools if tools.present?
             payload[:generationConfig].merge!(model_params) if model_params.present?
           end
       end
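A sketch of the request body this assembles (values illustrative; the exact generationConfig contents depend on default_options and model_params, which the merge! above implies):

# Illustrative payload posted to the Gemini API.
payload = {
  contents: [
    { role: "user", parts: { text: "What time is it in Buenos Aires?" } },
  ],
  tools: [{ function_declarations: [] }],  # only included when the dialect defines tools
  generationConfig: { temperature: 0.7 },  # model_params merged in when present
}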
@@ -57,8 +59,12 @@ module DiscourseAi
       end

       def extract_completion_from(response_raw)
-        parsed = JSON.parse(response_raw, symbolize_names: true)
-
+        parsed =
+          if @streaming_mode
+            response_raw
+          else
+            JSON.parse(response_raw, symbolize_names: true)
+          end
         response_h = parsed.dig(:candidates, 0, :content, :parts, 0)

         @has_function_call ||= response_h.dig(:functionCall).present?
@@ -66,20 +72,11 @@ module DiscourseAi
       end

       def partials_from(decoded_chunk)
-        decoded_chunk
-          .split("\n")
-          .map do |line|
-            if line == ","
-              nil
-            elsif line.starts_with?("[")
-              line[1..-1]
-            elsif line.ends_with?("]")
-              line[0..-1]
-            else
-              line
-            end
-          end
-          .compact_blank
+        begin
+          JSON.parse(decoded_chunk, symbolize_names: true)
+        rescue JSON::ParserError
+          []
+        end
       end

       def extract_prompt_for_tokenizer(prompt)
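A sketch of the new behavior, as a standalone copy of the method with illustrative buffers. Gemini streams candidates as one growing JSON array, so each buffered chunk is parsed wholesale; an incomplete buffer simply yields no partials until more data arrives:

require "json"

def partials_from(decoded_chunk)
  JSON.parse(decoded_chunk, symbolize_names: true)
rescue JSON::ParserError
  [] # buffer is not valid JSON yet; wait for more data
end

p partials_from("[\n{\"candidates\": []}")    # => [] (incomplete array)
p partials_from("[\n{\"candidates\": []}\n]") # => [{:candidates=>[]}]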
@@ -62,6 +62,8 @@ module DiscourseAi
  #     { type: "user", name: "user1", content: "This is a new message by a user" },
  #     { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
  #     { type: "tool", name: "tool_id", content: "I'm a tool result" },
+  #     { type: "tool_call_id", name: "tool_id", content: { name: "tool", args: { ...tool_args } } },
+  #     { type: "multi_turn", content: [assistant_reply_from_a_tool, tool_call, tool_call_id] }
  #   ]
  #
  # - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:
@@ -98,18 +98,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
      expect(translated_context).to eq(
        [
          {
-           role: "model",
-           parts: [
-             {
-               "functionResponse" => {
-                 name: context.last[:name],
-                 content: context.last[:content],
-               },
-             },
-           ],
+           role: "function",
+           parts: {
+             functionResponse: {
+               name: context.last[:name],
+               response: {
+                 content: context.last[:content],
+               },
+             },
+           },
          },
-         { role: "model", parts: [{ text: context.second[:content] }] },
-         { role: "user", parts: [{ text: context.first[:content] }] },
+         { role: "model", parts: { text: context.second[:content] } },
+         { role: "user", parts: { text: context.first[:content] } },
        ],
      )
    end
@@ -121,7 +121,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do

      translated_context = dialect.conversation_context

-     expect(translated_context.last.dig(:parts, 0, :text).length).to be <
+     expect(translated_context.last.dig(:parts, :text).length).to be <
        context.last[:content].length
    end
  end
@@ -133,16 +133,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
        {
          name: "get_weather",
          description: "Get the weather in a city",
-         parameters: [
-           { name: "location", type: "string", description: "the city name" },
-           {
-             name: "unit",
-             type: "string",
-             description: "the unit of measurement celcius c or fahrenheit f",
-             enum: %w[c f],
-           },
-         ],
-         required: %w[location unit],
+         parameters: {
+           type: "object",
+           required: %w[location unit],
+           properties: {
+             "location" => {
+               type: "string",
+               description: "the city name",
+             },
+             "unit" => {
+               type: "string",
+               description: "the unit of measurement celcius c or fahrenheit f",
+               enum: %w[c f],
+             },
+           },
+         },
        },
      ]
    }
@@ -16,16 +16,21 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
      {
        name: "get_weather",
        description: "Get the weather in a city",
-       parameters: [
-         { name: "location", type: "string", description: "the city name" },
-         {
-           name: "unit",
-           type: "string",
-           description: "the unit of measurement celcius c or fahrenheit f",
-           enum: %w[c f],
-         },
-       ],
-       required: %w[location unit],
+       parameters: {
+         type: "object",
+         required: %w[location unit],
+         properties: {
+           "location" => {
+             type: "string",
+             description: "the city name",
+           },
+           "unit" => {
+             type: "string",
+             description: "the unit of measurement celcius c or fahrenheit f",
+             enum: %w[c f],
+           },
+         },
+       },
      }
    end
@@ -126,7 +131,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
      end
    end

-   chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
+   chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]").split("")

    WebMock
      .stub_request(
@@ -45,6 +45,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
      <summary>#{tool.summary}</summary>
      <p></p>
      </details>
+
    HTML

    context = {}
@@ -125,44 +125,18 @@ RSpec.describe DiscourseAi::AiBot::Playground do
      expect(context).to contain_exactly(
        *[
          { type: "user", name: user.username, content: third_post.raw },
-         { type: "assistant", content: custom_prompt.third.first },
-         { type: "tool_call", content: custom_prompt.second.first, name: "time" },
-         { type: "tool", name: "time", content: custom_prompt.first.first },
+         {
+           type: "multi_turn",
+           content: [
+             { type: "assistant", content: custom_prompt.third.first },
+             { type: "tool_call", content: custom_prompt.second.first, name: "time" },
+             { type: "tool", name: "time", content: custom_prompt.first.first },
+           ],
+         },
          { type: "user", name: user.username, content: first_post.raw },
        ],
      )
    end
  end
-
- it "include replies generated from tools only once" do
-   custom_prompt = [
-     [
-       { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
-       "time",
-       "tool",
-     ],
-     [
-       { name: "time", arguments: { name: "time", timezone: "Buenos Aires" }.to_json }.to_json,
-       "time",
-       "tool_call",
-     ],
-     ["I replied this thanks to the time command", bot_user.username],
-   ]
-   PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
-   PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt)
-
-   context = playground.conversation_context(third_post)
-
-   expect(context).to contain_exactly(
-     *[
-       { type: "user", name: user.username, content: third_post.raw },
-       { type: "assistant", content: custom_prompt.third.first },
-       { type: "tool_call", content: custom_prompt.second.first, name: "time" },
-       { type: "tool", name: "time", content: custom_prompt.first.first },
-       { type: "tool_call", content: custom_prompt.second.first, name: "time" },
-       { type: "tool", name: "time", content: custom_prompt.first.first },
-     ],
-   )
- end
 end