FIX: correct gemini streaming implementation (#632)
This also implements image support and gemini-flash support
This commit is contained in:
parent
06137ac706
commit
d5c23f01ff
|
@ -13,6 +13,7 @@ en:
|
|||
claude_2: Claude 2
|
||||
gemini_pro: Gemini Pro
|
||||
gemini_1_5_pro: Gemini 1.5 Pro
|
||||
gemini_1_5_flash: Gemini 1.5 Flash
|
||||
claude_3_opus: Claude 3 Opus
|
||||
claude_3_sonnet: Claude 3 Sonnet
|
||||
claude_3_haiku: Claude 3 Haiku
|
||||
|
|
|
@ -178,7 +178,7 @@ module DiscourseAi
|
|||
"ollama:mistral"
|
||||
end
|
||||
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
|
||||
"google:gemini-pro"
|
||||
"google:gemini-1.5-pro"
|
||||
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
|
||||
"fake:fake"
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
|
||||
|
|
|
@ -9,6 +9,7 @@ module DiscourseAi
|
|||
{ id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
|
||||
{ id: "gemini-pro", name: "discourse_automation.ai_models.gemini_pro" },
|
||||
{ id: "gemini-1.5-pro", name: "discourse_automation.ai_models.gemini_1_5_pro" },
|
||||
{ id: "gemini-1.5-flash", name: "discourse_automation.ai_models.gemini_1_5_flash" },
|
||||
{ id: "claude-2", name: "discourse_automation.ai_models.claude_2" },
|
||||
{ id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" },
|
||||
{ id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" },
|
||||
|
|
|
@ -6,7 +6,7 @@ module DiscourseAi
|
|||
class Gemini < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[gemini-pro gemini-1.5-pro].include?(model_name)
|
||||
%w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -26,7 +26,13 @@ module DiscourseAi
|
|||
interleving_messages = []
|
||||
previous_message = nil
|
||||
|
||||
system_instruction = nil
|
||||
|
||||
messages.each do |message|
|
||||
if message[:role] == "system"
|
||||
system_instruction = message[:content]
|
||||
next
|
||||
end
|
||||
if previous_message
|
||||
if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
|
||||
message[:role] == "user"
|
||||
|
@ -37,7 +43,7 @@ module DiscourseAi
|
|||
previous_message = message
|
||||
end
|
||||
|
||||
interleving_messages
|
||||
{ messages: interleving_messages, system_instruction: system_instruction }
|
||||
end
|
||||
|
||||
def tools
|
||||
|
@ -70,7 +76,7 @@ module DiscourseAi
|
|||
def max_prompt_tokens
|
||||
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
|
||||
|
||||
if model_name == "gemini-1.5-pro"
|
||||
if model_name.start_with?("gemini-1.5")
|
||||
# technically we support 1 million tokens, but we're being conservative
|
||||
800_000
|
||||
else
|
||||
|
@ -84,44 +90,80 @@ module DiscourseAi
|
|||
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
|
||||
def beta_api?
|
||||
@beta_api ||= model_name.start_with?("gemini-1.5")
|
||||
end
|
||||
|
||||
def system_msg(msg)
|
||||
{ role: "user", parts: { text: msg[:content] } }
|
||||
if beta_api?
|
||||
{ role: "system", content: msg[:content] }
|
||||
else
|
||||
{ role: "user", parts: { text: msg[:content] } }
|
||||
end
|
||||
end
|
||||
|
||||
def model_msg(msg)
|
||||
{ role: "model", parts: { text: msg[:content] } }
|
||||
if beta_api?
|
||||
{ role: "model", parts: [{ text: msg[:content] }] }
|
||||
else
|
||||
{ role: "model", parts: { text: msg[:content] } }
|
||||
end
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
{ role: "user", parts: { text: msg[:content] } }
|
||||
if beta_api?
|
||||
# support new format with multiple parts
|
||||
result = { role: "user", parts: [{ text: msg[:content] }] }
|
||||
upload_parts = uploaded_parts(msg)
|
||||
result[:parts].concat(upload_parts) if upload_parts.present?
|
||||
result
|
||||
else
|
||||
{ role: "user", parts: { text: msg[:content] } }
|
||||
end
|
||||
end
|
||||
|
||||
def uploaded_parts(message)
|
||||
encoded_uploads = prompt.encoded_uploads(message)
|
||||
result = []
|
||||
if encoded_uploads.present?
|
||||
encoded_uploads.each do |details|
|
||||
result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
|
||||
end
|
||||
end
|
||||
result
|
||||
end
|
||||
|
||||
def tool_call_msg(msg)
|
||||
call_details = JSON.parse(msg[:content], symbolize_names: true)
|
||||
|
||||
{
|
||||
role: "model",
|
||||
parts: {
|
||||
functionCall: {
|
||||
name: msg[:name] || call_details[:name],
|
||||
args: call_details[:arguments],
|
||||
},
|
||||
part = {
|
||||
functionCall: {
|
||||
name: msg[:name] || call_details[:name],
|
||||
args: call_details[:arguments],
|
||||
},
|
||||
}
|
||||
|
||||
if beta_api?
|
||||
{ role: "model", parts: [part] }
|
||||
else
|
||||
{ role: "model", parts: part }
|
||||
end
|
||||
end
|
||||
|
||||
def tool_msg(msg)
|
||||
{
|
||||
role: "function",
|
||||
parts: {
|
||||
functionResponse: {
|
||||
name: msg[:name] || msg[:id],
|
||||
response: {
|
||||
content: msg[:content],
|
||||
},
|
||||
part = {
|
||||
functionResponse: {
|
||||
name: msg[:name] || msg[:id],
|
||||
response: {
|
||||
content: msg[:content],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if beta_api?
|
||||
{ role: "function", parts: [part] }
|
||||
else
|
||||
{ role: "function", parts: part }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -54,13 +54,24 @@ module DiscourseAi
|
|||
if llm_model
|
||||
url = llm_model.url
|
||||
else
|
||||
mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
|
||||
mapped_model = model
|
||||
if model == "gemini-1.5-pro"
|
||||
mapped_model = "gemini-1.5-pro-latest"
|
||||
elsif model == "gemini-1.5-flash"
|
||||
mapped_model = "gemini-1.5-flash-latest"
|
||||
elsif model == "gemini-1.0-pro"
|
||||
mapped_model = "gemini-pro-latest"
|
||||
end
|
||||
url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
|
||||
end
|
||||
|
||||
key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
|
||||
|
||||
url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"
|
||||
if @streaming_mode
|
||||
url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"
|
||||
else
|
||||
url = "#{url}:generateContent?key=#{key}"
|
||||
end
|
||||
|
||||
URI(url)
|
||||
end
|
||||
|
@ -68,12 +79,14 @@ module DiscourseAi
|
|||
def prepare_payload(prompt, model_params, dialect)
|
||||
tools = dialect.tools
|
||||
|
||||
default_options
|
||||
.merge(contents: prompt)
|
||||
.tap do |payload|
|
||||
payload[:tools] = tools if tools.present?
|
||||
payload[:generationConfig].merge!(model_params) if model_params.present?
|
||||
end
|
||||
payload = default_options.merge(contents: prompt[:messages])
|
||||
payload[:systemInstruction] = {
|
||||
role: "system",
|
||||
parts: [{ text: prompt[:system_instruction].to_s }],
|
||||
} if prompt[:system_instruction].present?
|
||||
payload[:tools] = tools if tools.present?
|
||||
payload[:generationConfig].merge!(model_params) if model_params.present?
|
||||
payload
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
|
@ -96,11 +109,55 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
begin
|
||||
JSON.parse(decoded_chunk, symbolize_names: true)
|
||||
rescue JSON::ParserError
|
||||
[]
|
||||
decoded_chunk
|
||||
end
|
||||
|
||||
def chunk_to_string(chunk)
|
||||
chunk.to_s
|
||||
end
|
||||
|
||||
class Decoder
|
||||
def initialize
|
||||
@buffer = +""
|
||||
end
|
||||
|
||||
def decode(str)
|
||||
@buffer << str
|
||||
|
||||
lines = @buffer.split(/\r?\n\r?\n/)
|
||||
|
||||
keep_last = false
|
||||
|
||||
decoded =
|
||||
lines
|
||||
.map do |line|
|
||||
if line.start_with?("data: {")
|
||||
begin
|
||||
JSON.parse(line[6..-1], symbolize_names: true)
|
||||
rescue JSON::ParserError
|
||||
keep_last = line
|
||||
nil
|
||||
end
|
||||
else
|
||||
keep_last = line
|
||||
nil
|
||||
end
|
||||
end
|
||||
.compact
|
||||
|
||||
if keep_last
|
||||
@buffer = +(keep_last)
|
||||
else
|
||||
@buffer = +""
|
||||
end
|
||||
|
||||
decoded
|
||||
end
|
||||
end
|
||||
|
||||
def decode(chunk)
|
||||
@decoder ||= Decoder.new
|
||||
@decoder.decode(chunk)
|
||||
end
|
||||
|
||||
def extract_prompt_for_tokenizer(prompt)
|
||||
|
|
|
@ -56,7 +56,7 @@ module DiscourseAi
|
|||
gpt-4-vision-preview
|
||||
gpt-4o
|
||||
],
|
||||
google: %w[gemini-pro gemini-1.5-pro],
|
||||
google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
|
||||
}.tap do |h|
|
||||
h[:ollama] = ["mistral"] if Rails.env.development?
|
||||
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
|
||||
|
|
|
@ -13,6 +13,7 @@ module DiscourseAi
|
|||
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
|
||||
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
|
||||
Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
|
||||
Models::Gemini.new("google:gemini-1.5-flash", max_tokens: 800_000),
|
||||
]
|
||||
|
||||
claude_prov = "anthropic"
|
||||
|
|
|
@ -3,16 +3,15 @@
|
|||
require_relative "dialect_context"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||
let(:model_name) { "gemini-pro" }
|
||||
let(:model_name) { "gemini-1.5-pro" }
|
||||
let(:context) { DialectContext.new(described_class, model_name) }
|
||||
|
||||
describe "#translate" do
|
||||
it "translates a prompt written in our generic format to the Gemini format" do
|
||||
gemini_version = [
|
||||
{ role: "user", parts: { text: context.system_insts } },
|
||||
{ role: "model", parts: { text: "Ok." } },
|
||||
{ role: "user", parts: { text: context.simple_user_input } },
|
||||
]
|
||||
gemini_version = {
|
||||
messages: [{ role: "user", parts: [{ text: context.simple_user_input }] }],
|
||||
system_instruction: context.system_insts,
|
||||
}
|
||||
|
||||
translated = context.system_user_scenario
|
||||
|
||||
|
@ -21,73 +20,77 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
|||
|
||||
it "injects model after tool call" do
|
||||
expect(context.image_generation_scenario).to eq(
|
||||
[
|
||||
{ role: "user", parts: { text: context.system_insts } },
|
||||
{ parts: { text: "Ok." }, role: "model" },
|
||||
{ parts: { text: "draw a cat" }, role: "user" },
|
||||
{ parts: { functionCall: { args: { picture: "Cat" }, name: "draw" } }, role: "model" },
|
||||
{
|
||||
parts: {
|
||||
functionResponse: {
|
||||
name: "tool_id",
|
||||
response: {
|
||||
content: "\"I'm a tool result\"",
|
||||
},
|
||||
},
|
||||
{
|
||||
messages: [
|
||||
{ role: "user", parts: [{ text: "draw a cat" }] },
|
||||
{
|
||||
role: "model",
|
||||
parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
|
||||
},
|
||||
role: "function",
|
||||
},
|
||||
{ parts: { text: "Ok." }, role: "model" },
|
||||
{ parts: { text: "draw another cat" }, role: "user" },
|
||||
],
|
||||
{
|
||||
role: "function",
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: "tool_id",
|
||||
response: {
|
||||
content: "\"I'm a tool result\"",
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{ role: "model", parts: { text: "Ok." } },
|
||||
{ role: "user", parts: [{ text: "draw another cat" }] },
|
||||
],
|
||||
system_instruction: context.system_insts,
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
it "translates tool_call and tool messages" do
|
||||
expect(context.multi_turn_scenario).to eq(
|
||||
[
|
||||
{ role: "user", parts: { text: context.system_insts } },
|
||||
{ role: "model", parts: { text: "Ok." } },
|
||||
{ role: "user", parts: { text: "This is a message by a user" } },
|
||||
{
|
||||
role: "model",
|
||||
parts: {
|
||||
text: "I'm a previous bot reply, that's why there's no user",
|
||||
{
|
||||
messages: [
|
||||
{ role: "user", parts: [{ text: "This is a message by a user" }] },
|
||||
{
|
||||
role: "model",
|
||||
parts: [{ text: "I'm a previous bot reply, that's why there's no user" }],
|
||||
},
|
||||
},
|
||||
{ role: "user", parts: { text: "This is a new message by a user" } },
|
||||
{
|
||||
role: "model",
|
||||
parts: {
|
||||
functionCall: {
|
||||
name: "get_weather",
|
||||
args: {
|
||||
location: "Sydney",
|
||||
unit: "c",
|
||||
{ role: "user", parts: [{ text: "This is a new message by a user" }] },
|
||||
{
|
||||
role: "model",
|
||||
parts: [
|
||||
{ functionCall: { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
|
||||
],
|
||||
},
|
||||
{
|
||||
role: "function",
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: "get_weather",
|
||||
response: {
|
||||
content: "\"I'm a tool result\"",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
role: "function",
|
||||
parts: {
|
||||
functionResponse: {
|
||||
name: "get_weather",
|
||||
response: {
|
||||
content: "I'm a tool result".to_json,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
],
|
||||
system_instruction:
|
||||
"I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
it "trims content if it's getting too long" do
|
||||
# testing truncation on 800k tokens is slow use model with less
|
||||
context = DialectContext.new(described_class, "gemini-pro")
|
||||
translated = context.long_user_input_scenario(length: 5_000)
|
||||
|
||||
expect(translated.last[:role]).to eq("user")
|
||||
expect(translated.last.dig(:parts, :text).length).to be <
|
||||
expect(translated[:messages].last[:role]).to eq("user")
|
||||
expect(translated[:messages].last.dig(:parts, :text).length).to be <
|
||||
context.long_message_text(length: 5_000).length
|
||||
end
|
||||
end
|
||||
|
|
|
@ -132,39 +132,92 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
|||
|
||||
fab!(:user)
|
||||
|
||||
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
|
||||
let(:upload100x100) do
|
||||
UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
|
||||
end
|
||||
|
||||
let(:gemini_mock) { GeminiMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(gemini_mock)
|
||||
end
|
||||
end
|
||||
it "Supports Vision API" do
|
||||
SiteSetting.ai_gemini_api_key = "ABC"
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(gemini_mock)
|
||||
end
|
||||
end
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are image bot",
|
||||
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
|
||||
)
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
|
||||
response = gemini_mock.response("World").to_json
|
||||
|
||||
req_body = nil
|
||||
|
||||
llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-pro")
|
||||
url =
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=ABC"
|
||||
|
||||
stub_request(:post, url).with(
|
||||
body:
|
||||
proc do |_req_body|
|
||||
req_body = _req_body
|
||||
true
|
||||
end,
|
||||
).to_return(status: 200, body: response)
|
||||
|
||||
response = llm.generate(prompt, user: user)
|
||||
|
||||
expect(response).to eq("World")
|
||||
|
||||
expected_prompt = {
|
||||
"generationConfig" => {
|
||||
},
|
||||
"contents" => [
|
||||
{
|
||||
"role" => "user",
|
||||
"parts" => [
|
||||
{ "text" => "hello" },
|
||||
{ "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
|
||||
],
|
||||
},
|
||||
],
|
||||
"systemInstruction" => {
|
||||
"role" => "system",
|
||||
"parts" => [{ "text" => "You are image bot" }],
|
||||
},
|
||||
}
|
||||
|
||||
expect(JSON.parse(req_body)).to eq(expected_prompt)
|
||||
end
|
||||
|
||||
it "Can correctly handle streamed responses even if they are chunked badly" do
|
||||
SiteSetting.ai_gemini_api_key = "ABC"
|
||||
|
||||
data = +""
|
||||
data << "da|ta: |"
|
||||
data << gemini_mock.response("Hello").to_json
|
||||
data << "\r\n\r\ndata: "
|
||||
data << gemini_mock.response(" |World").to_json
|
||||
data << "\r\n\r\ndata: "
|
||||
data << gemini_mock.response(" Sam").to_json
|
||||
|
||||
split = data.split("|")
|
||||
|
||||
llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-flash")
|
||||
url =
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:streamGenerateContent?alt=sse&key=ABC"
|
||||
|
||||
output = +""
|
||||
gemini_mock.with_chunk_array_support do
|
||||
stub_request(:post, url).to_return(status: 200, body: split)
|
||||
llm.generate("Hello", user: user) { |partial| output << partial }
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(gemini_mock)
|
||||
end
|
||||
end
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.streaming_mode_tools(gemini_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
expect(output).to eq("Hello World Sam")
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue