FIX: correct gemini streaming implementation (#632)

This also implements image support and gemini-flash support.
Sam 2024-05-22 16:35:29 +10:00 committed by GitHub
parent 06137ac706
commit d5c23f01ff
9 changed files with 277 additions and 119 deletions

@ -13,6 +13,7 @@ en:
       claude_2: Claude 2
       gemini_pro: Gemini Pro
       gemini_1_5_pro: Gemini 1.5 Pro
+      gemini_1_5_flash: Gemini 1.5 Flash
       claude_3_opus: Claude 3 Opus
       claude_3_sonnet: Claude 3 Sonnet
       claude_3_haiku: Claude 3 Haiku

@ -178,7 +178,7 @@ module DiscourseAi
"ollama:mistral"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro"
"google:gemini-1.5-pro"
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake"
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID

@ -9,6 +9,7 @@ module DiscourseAi
{ id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
{ id: "gemini-pro", name: "discourse_automation.ai_models.gemini_pro" },
{ id: "gemini-1.5-pro", name: "discourse_automation.ai_models.gemini_1_5_pro" },
{ id: "gemini-1.5-flash", name: "discourse_automation.ai_models.gemini_1_5_flash" },
{ id: "claude-2", name: "discourse_automation.ai_models.claude_2" },
{ id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" },
{ id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" },

@ -6,7 +6,7 @@ module DiscourseAi
     class Gemini < Dialect
       class << self
         def can_translate?(model_name)
-          %w[gemini-pro gemini-1.5-pro].include?(model_name)
+          %w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name)
         end
       end
@ -26,7 +26,13 @@ module DiscourseAi
         interleving_messages = []
         previous_message = nil
+        system_instruction = nil

         messages.each do |message|
+          if message[:role] == "system"
+            system_instruction = message[:content]
+            next
+          end
+
           if previous_message
             if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
                 message[:role] == "user"
@ -37,7 +43,7 @@ module DiscourseAi
           previous_message = message
         end

-        interleving_messages
+        { messages: interleving_messages, system_instruction: system_instruction }
       end

       def tools
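Note on the hunk above: `translate` no longer returns a bare array of Gemini messages. System prompts are pulled out during interleaving and returned alongside the messages, so the endpoint can forward them through the v1beta `systemInstruction` field instead of faking a user/model "Ok." preamble. A minimal standalone sketch of the new contract (illustrative only, not the plugin's code):

```ruby
# Sketch of the new translate contract: system messages are extracted
# rather than being injected as a fake user/model exchange.
messages = [
  { role: "system", content: "You are a helpful bot" },
  { role: "user", content: "Hello" },
]

system_instruction = nil
translated = []

messages.each do |message|
  if message[:role] == "system"
    system_instruction = message[:content]
    next
  end
  translated << { role: message[:role], parts: [{ text: message[:content] }] }
end

result = { messages: translated, system_instruction: system_instruction }
# => { messages: [{ role: "user", parts: [{ text: "Hello" }] }],
#      system_instruction: "You are a helpful bot" }
```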
@ -70,7 +76,7 @@ module DiscourseAi
       def max_prompt_tokens
         return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

-        if model_name == "gemini-1.5-pro"
+        if model_name.start_with?("gemini-1.5")
           # technically we support 1 million tokens, but we're being conservative
           800_000
         else
@ -84,44 +90,80 @@ module DiscourseAi
         self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end

+      def beta_api?
+        @beta_api ||= model_name.start_with?("gemini-1.5")
+      end
+
       def system_msg(msg)
-        { role: "user", parts: { text: msg[:content] } }
+        if beta_api?
+          { role: "system", content: msg[:content] }
+        else
+          { role: "user", parts: { text: msg[:content] } }
+        end
       end

       def model_msg(msg)
-        { role: "model", parts: { text: msg[:content] } }
+        if beta_api?
+          { role: "model", parts: [{ text: msg[:content] }] }
+        else
+          { role: "model", parts: { text: msg[:content] } }
+        end
       end

       def user_msg(msg)
-        { role: "user", parts: { text: msg[:content] } }
+        if beta_api?
+          # support new format with multiple parts
+          result = { role: "user", parts: [{ text: msg[:content] }] }
+          upload_parts = uploaded_parts(msg)
+          result[:parts].concat(upload_parts) if upload_parts.present?
+          result
+        else
+          { role: "user", parts: { text: msg[:content] } }
+        end
       end

+      def uploaded_parts(message)
+        encoded_uploads = prompt.encoded_uploads(message)
+        result = []
+        if encoded_uploads.present?
+          encoded_uploads.each do |details|
+            result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
+          end
+        end
+        result
+      end
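`uploaded_parts` above maps each upload Discourse has already base64-encoded into Gemini's `inlineData` part, which `user_msg` appends after the text part. A sketch of the resulting part shape, assuming a hypothetical local JPEG in place of a Discourse upload:

```ruby
require "base64"

# Sketch only: build a v1beta user message with an inline image part.
# "cat.jpg" is a hypothetical local file; the plugin sources the bytes
# from Discourse uploads instead.
image_bytes = File.binread("cat.jpg")

user_message = {
  role: "user",
  parts: [
    { text: "describe this image" },
    { inlineData: { mimeType: "image/jpeg", data: Base64.strict_encode64(image_bytes) } },
  ],
}
```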
       def tool_call_msg(msg)
         call_details = JSON.parse(msg[:content], symbolize_names: true)
-        {
-          role: "model",
-          parts: {
-            functionCall: {
-              name: msg[:name] || call_details[:name],
-              args: call_details[:arguments],
-            },
-          },
-        }
+        part = {
+          functionCall: {
+            name: msg[:name] || call_details[:name],
+            args: call_details[:arguments],
+          },
+        }
+
+        if beta_api?
+          { role: "model", parts: [part] }
+        else
+          { role: "model", parts: part }
+        end
       end

       def tool_msg(msg)
-        {
-          role: "function",
-          parts: {
-            functionResponse: {
-              name: msg[:name] || msg[:id],
-              response: {
-                content: msg[:content],
-              },
-            },
-          },
-        }
+        part = {
+          functionResponse: {
+            name: msg[:name] || msg[:id],
+            response: {
+              content: msg[:content],
+            },
+          },
+        }
+
+        if beta_api?
+          { role: "function", parts: [part] }
+        else
+          { role: "function", parts: part }
+        end
       end
     end
   end

@ -54,13 +54,24 @@ module DiscourseAi
         if llm_model
           url = llm_model.url
         else
-          mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
+          mapped_model = model
+          if model == "gemini-1.5-pro"
+            mapped_model = "gemini-1.5-pro-latest"
+          elsif model == "gemini-1.5-flash"
+            mapped_model = "gemini-1.5-flash-latest"
+          elsif model == "gemini-1.0-pro"
+            mapped_model = "gemini-pro-latest"
+          end
           url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
         end

         key = llm_model&.api_key || SiteSetting.ai_gemini_api_key

-        url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"
+        if @streaming_mode
+          url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"
+        else
+          url = "#{url}:generateContent?key=#{key}"
+        end

         URI(url)
       end
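The effect of the hunk above: streaming requests now opt into Server-Sent Events via `alt=sse`, which is the format the new decoder below expects, while blocking requests keep the plain `generateContent` endpoint. Roughly the URLs produced, with a placeholder key:

```ruby
# Sketch of the URLs this method now builds (key is a placeholder).
key = "AIza..." # placeholder API key
base = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest"

streaming_url = "#{base}:streamGenerateContent?key=#{key}&alt=sse"
blocking_url  = "#{base}:generateContent?key=#{key}"
```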
@ -68,12 +79,14 @@ module DiscourseAi
       def prepare_payload(prompt, model_params, dialect)
         tools = dialect.tools

-        default_options
-          .merge(contents: prompt)
-          .tap do |payload|
-            payload[:tools] = tools if tools.present?
-            payload[:generationConfig].merge!(model_params) if model_params.present?
-          end
+        payload = default_options.merge(contents: prompt[:messages])
+        payload[:systemInstruction] = {
+          role: "system",
+          parts: [{ text: prompt[:system_instruction].to_s }],
+        } if prompt[:system_instruction].present?
+        payload[:tools] = tools if tools.present?
+        payload[:generationConfig].merge!(model_params) if model_params.present?
+        payload
       end
def prepare_request(payload)
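With the dialect now returning a hash, `prepare_payload` sends the translated messages as `contents` and promotes the extracted system prompt to the v1beta `systemInstruction` field. A sketch of a resulting request body (values illustrative):

```ruby
# Sketch of a request body produced by prepare_payload (values illustrative).
payload = {
  generationConfig: { temperature: 0.7 },
  contents: [{ role: "user", parts: [{ text: "Hello" }] }],
  systemInstruction: {
    role: "system",
    parts: [{ text: "You are a helpful bot" }],
  },
}
```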
@ -96,11 +109,55 @@ module DiscourseAi
       end

       def partials_from(decoded_chunk)
-        begin
-          JSON.parse(decoded_chunk, symbolize_names: true)
-        rescue JSON::ParserError
-          []
-        end
+        decoded_chunk
       end

+      def chunk_to_string(chunk)
+        chunk.to_s
+      end
+
+      class Decoder
+        def initialize
+          @buffer = +""
+        end
+
+        def decode(str)
+          @buffer << str
+
+          lines = @buffer.split(/\r?\n\r?\n/)
+
+          keep_last = false
+
+          decoded =
+            lines
+              .map do |line|
+                if line.start_with?("data: {")
+                  begin
+                    JSON.parse(line[6..-1], symbolize_names: true)
+                  rescue JSON::ParserError
+                    keep_last = line
+                    nil
+                  end
+                else
+                  keep_last = line
+                  nil
+                end
+              end
+              .compact
+
+          if keep_last
+            @buffer = +(keep_last)
+          else
+            @buffer = +""
+          end
+
+          decoded
+        end
+      end
+
+      def decode(chunk)
+        @decoder ||= Decoder.new
+        @decoder.decode(chunk)
+      end
+
       def extract_prompt_for_tokenizer(prompt)
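This decoder is the heart of the streaming fix: with `alt=sse`, Gemini sends `data: {json}` events separated by blank lines, and HTTP chunk boundaries can fall anywhere, including mid-event. Complete events are parsed out while any trailing partial event stays buffered for the next chunk. A self-contained sketch of the same buffering idea:

```ruby
require "json"

# Standalone sketch of the SSE buffering approach used by the Decoder above:
# accumulate chunks, split on blank lines, parse complete "data: {...}" events,
# and keep any trailing partial event buffered for the next chunk.
buffer = +""
events = []

# Chunk boundaries deliberately land mid-token and mid-event.
chunks = ["da", "ta: {\"a\":1}\r\n\r\nda", "ta: {\"a\":2}\r\n\r\n"]

chunks.each do |chunk|
  buffer << chunk
  blocks = buffer.split(/\r?\n\r?\n/, -1)
  buffer = blocks.pop.to_s # trailing partial (or empty) block stays buffered

  blocks.each do |block|
    next unless block.start_with?("data: {")
    events << JSON.parse(block.delete_prefix("data: "), symbolize_names: true)
  end
end

events # => [{ a: 1 }, { a: 2 }]
```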

@ -56,7 +56,7 @@ module DiscourseAi
             gpt-4-vision-preview
             gpt-4o
           ],
-          google: %w[gemini-pro gemini-1.5-pro],
+          google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
         }.tap do |h|
           h[:ollama] = ["mistral"] if Rails.env.development?
           h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?

@ -13,6 +13,7 @@ module DiscourseAi
           Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
           Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
           Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
+          Models::Gemini.new("google:gemini-1.5-flash", max_tokens: 800_000),
         ]

         claude_prov = "anthropic"

@ -3,16 +3,15 @@
 require_relative "dialect_context"

 RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
-  let(:model_name) { "gemini-pro" }
+  let(:model_name) { "gemini-1.5-pro" }
   let(:context) { DialectContext.new(described_class, model_name) }

   describe "#translate" do
     it "translates a prompt written in our generic format to the Gemini format" do
-      gemini_version = [
-        { role: "user", parts: { text: context.system_insts } },
-        { role: "model", parts: { text: "Ok." } },
-        { role: "user", parts: { text: context.simple_user_input } },
-      ]
+      gemini_version = {
+        messages: [{ role: "user", parts: [{ text: context.simple_user_input }] }],
+        system_instruction: context.system_insts,
+      }

       translated = context.system_user_scenario
@ -21,73 +20,77 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
it "injects model after tool call" do
expect(context.image_generation_scenario).to eq(
[
{ role: "user", parts: { text: context.system_insts } },
{ parts: { text: "Ok." }, role: "model" },
{ parts: { text: "draw a cat" }, role: "user" },
{ parts: { functionCall: { args: { picture: "Cat" }, name: "draw" } }, role: "model" },
{
parts: {
functionResponse: {
name: "tool_id",
response: {
content: "\"I'm a tool result\"",
},
},
{
messages: [
{ role: "user", parts: [{ text: "draw a cat" }] },
{
role: "model",
parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
},
role: "function",
},
{ parts: { text: "Ok." }, role: "model" },
{ parts: { text: "draw another cat" }, role: "user" },
],
{
role: "function",
parts: [
{
functionResponse: {
name: "tool_id",
response: {
content: "\"I'm a tool result\"",
},
},
},
],
},
{ role: "model", parts: { text: "Ok." } },
{ role: "user", parts: [{ text: "draw another cat" }] },
],
system_instruction: context.system_insts,
},
)
end
it "translates tool_call and tool messages" do
expect(context.multi_turn_scenario).to eq(
[
{ role: "user", parts: { text: context.system_insts } },
{ role: "model", parts: { text: "Ok." } },
{ role: "user", parts: { text: "This is a message by a user" } },
{
role: "model",
parts: {
text: "I'm a previous bot reply, that's why there's no user",
{
messages: [
{ role: "user", parts: [{ text: "This is a message by a user" }] },
{
role: "model",
parts: [{ text: "I'm a previous bot reply, that's why there's no user" }],
},
},
{ role: "user", parts: { text: "This is a new message by a user" } },
{
role: "model",
parts: {
functionCall: {
name: "get_weather",
args: {
location: "Sydney",
unit: "c",
{ role: "user", parts: [{ text: "This is a new message by a user" }] },
{
role: "model",
parts: [
{ functionCall: { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
],
},
{
role: "function",
parts: [
{
functionResponse: {
name: "get_weather",
response: {
content: "\"I'm a tool result\"",
},
},
},
},
],
},
},
{
role: "function",
parts: {
functionResponse: {
name: "get_weather",
response: {
content: "I'm a tool result".to_json,
},
},
},
},
],
],
system_instruction:
"I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
},
)
end
it "trims content if it's getting too long" do
# testing truncation on 800k tokens is slow use model with less
context = DialectContext.new(described_class, "gemini-pro")
translated = context.long_user_input_scenario(length: 5_000)
expect(translated.last[:role]).to eq("user")
expect(translated.last.dig(:parts, :text).length).to be <
expect(translated[:messages].last[:role]).to eq("user")
expect(translated[:messages].last.dig(:parts, :text).length).to be <
context.long_message_text(length: 5_000).length
end
end

@ -132,39 +132,92 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
   fab!(:user)

+  let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
+  let(:upload100x100) do
+    UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
+  end
+
   let(:gemini_mock) { GeminiMock.new(endpoint) }

   let(:compliance) do
     EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
   end

-  describe "#perform_completion!" do
-    context "when using regular mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(gemini_mock)
-        end
-      end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(gemini_mock)
-        end
-      end
-    end
-
-    describe "when using streaming mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.streaming_mode_simple_prompt(gemini_mock)
-        end
-      end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.streaming_mode_tools(gemini_mock)
-        end
-      end
-    end
-  end
+  it "Supports Vision API" do
+    SiteSetting.ai_gemini_api_key = "ABC"
+
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are image bot",
+        messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+      )
+
+    encoded = prompt.encoded_uploads(prompt.messages.last)
+
+    response = gemini_mock.response("World").to_json
+    req_body = nil
+
+    llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-pro")
+    url =
+      "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=ABC"
+
+    stub_request(:post, url).with(
+      body:
+        proc do |_req_body|
+          req_body = _req_body
+          true
+        end,
+    ).to_return(status: 200, body: response)
+
+    response = llm.generate(prompt, user: user)
+
+    expect(response).to eq("World")
+
+    expected_prompt = {
+      "generationConfig" => {
+      },
+      "contents" => [
+        {
+          "role" => "user",
+          "parts" => [
+            { "text" => "hello" },
+            { "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
+          ],
+        },
+      ],
+      "systemInstruction" => {
+        "role" => "system",
+        "parts" => [{ "text" => "You are image bot" }],
+      },
+    }
+
+    expect(JSON.parse(req_body)).to eq(expected_prompt)
+  end
+
+  it "Can correctly handle streamed responses even if they are chunked badly" do
+    SiteSetting.ai_gemini_api_key = "ABC"
+
+    data = +""
+    data << "da|ta: |"
+    data << gemini_mock.response("Hello").to_json
+    data << "\r\n\r\ndata: "
+    data << gemini_mock.response(" |World").to_json
+    data << "\r\n\r\ndata: "
+    data << gemini_mock.response(" Sam").to_json
+
+    split = data.split("|")
+
+    llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-flash")
+    url =
+      "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:streamGenerateContent?alt=sse&key=ABC"
+
+    output = +""
+    gemini_mock.with_chunk_array_support do
+      stub_request(:post, url).to_return(status: 200, body: split)
+
+      llm.generate("Hello", user: user) { |partial| output << partial }
+    end
+
+    expect(output).to eq("Hello World Sam")
+  end
 end