From d5c23f01ff69cd6641160ba14030c2c31e092fd5 Mon Sep 17 00:00:00 2001
From: Sam
Date: Wed, 22 May 2024 16:35:29 +1000
Subject: [PATCH] FIX: correct gemini streaming implementation (#632)

This also implements image support and gemini-1.5-flash support.

---
 config/locales/client.en.yml                  |   1 +
 lib/ai_bot/bot.rb                             |   2 +-
 lib/automation.rb                             |   1 +
 lib/completions/dialects/gemini.rb            |  86 +++++++++----
 lib/completions/endpoints/gemini.rb           |  81 ++++++++++--
 lib/completions/llm.rb                        |   2 +-
 lib/summarization/entry_point.rb              |   1 +
 spec/lib/completions/dialects/gemini_spec.rb  | 119 +++++++++---
 spec/lib/completions/endpoints/gemini_spec.rb | 103 +++++++++----
 9 files changed, 277 insertions(+), 119 deletions(-)

diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 38832723..fa598dc0 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -13,6 +13,7 @@ en:
           claude_2: Claude 2
           gemini_pro: Gemini Pro
           gemini_1_5_pro: Gemini 1.5 Pro
+          gemini_1_5_flash: Gemini 1.5 Flash
           claude_3_opus: Claude 3 Opus
           claude_3_sonnet: Claude 3 Sonnet
           claude_3_haiku: Claude 3 Haiku
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index b97a6abc..90526b57 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -178,7 +178,7 @@ module DiscourseAi
             "ollama:mistral"
           end
         when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
-          "google:gemini-pro"
+          "google:gemini-1.5-pro"
         when DiscourseAi::AiBot::EntryPoint::FAKE_ID
           "fake:fake"
         when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
diff --git a/lib/automation.rb b/lib/automation.rb
index ba0a36ed..5f775c67 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -9,6 +9,7 @@ module DiscourseAi
     { id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
     { id: "gemini-pro", name: "discourse_automation.ai_models.gemini_pro" },
     { id: "gemini-1.5-pro", name: "discourse_automation.ai_models.gemini_1_5_pro" },
+    { id: "gemini-1.5-flash", name: "discourse_automation.ai_models.gemini_1_5_flash" },
     { id: "claude-2", name: "discourse_automation.ai_models.claude_2" },
     { id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" },
     { id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" },
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index 6b053f1a..a6e04c23 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -6,7 +6,7 @@ module DiscourseAi
       class Gemini < Dialect
         class << self
           def can_translate?(model_name)
-            %w[gemini-pro gemini-1.5-pro].include?(model_name)
+            %w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name)
           end
         end
@@ -26,7 +26,13 @@ module DiscourseAi
           interleving_messages = []
           previous_message = nil

+          system_instruction = nil
+
           messages.each do |message|
+            if message[:role] == "system"
+              system_instruction = message[:content]
+              next
+            end
             if previous_message
               if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
                    message[:role] == "user"
@@ -37,7 +43,7 @@ module DiscourseAi
             previous_message = message
           end

-          interleving_messages
+          { messages: interleving_messages, system_instruction: system_instruction }
         end

         def tools
@@ -70,7 +76,7 @@ module DiscourseAi
         def max_prompt_tokens
           return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

-          if model_name == "gemini-1.5-pro"
+          if model_name.start_with?("gemini-1.5")
             # technically we support 1 million tokens, but we're being conservative
             800_000
           else
@@ -84,44 +90,80 @@ module DiscourseAi
           self.tokenizer.size(context[:content].to_s + context[:name].to_s)
         end

+        def beta_api?
+          @beta_api ||= model_name.start_with?("gemini-1.5")
+        end
+
         def system_msg(msg)
-          { role: "user", parts: { text: msg[:content] } }
+          if beta_api?
+            { role: "system", content: msg[:content] }
+          else
+            { role: "user", parts: { text: msg[:content] } }
+          end
         end

         def model_msg(msg)
-          { role: "model", parts: { text: msg[:content] } }
+          if beta_api?
+            { role: "model", parts: [{ text: msg[:content] }] }
+          else
+            { role: "model", parts: { text: msg[:content] } }
+          end
         end

         def user_msg(msg)
-          { role: "user", parts: { text: msg[:content] } }
+          if beta_api?
+            # support new format with multiple parts
+            result = { role: "user", parts: [{ text: msg[:content] }] }
+            upload_parts = uploaded_parts(msg)
+            result[:parts].concat(upload_parts) if upload_parts.present?
+            result
+          else
+            { role: "user", parts: { text: msg[:content] } }
+          end
+        end
+
+        def uploaded_parts(message)
+          encoded_uploads = prompt.encoded_uploads(message)
+          result = []
+          if encoded_uploads.present?
+            encoded_uploads.each do |details|
+              result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
+            end
+          end
+          result
         end

         def tool_call_msg(msg)
           call_details = JSON.parse(msg[:content], symbolize_names: true)
-
-          {
-            role: "model",
-            parts: {
-              functionCall: {
-                name: msg[:name] || call_details[:name],
-                args: call_details[:arguments],
-              },
+          part = {
+            functionCall: {
+              name: msg[:name] || call_details[:name],
+              args: call_details[:arguments],
             },
           }
+
+          if beta_api?
+            { role: "model", parts: [part] }
+          else
+            { role: "model", parts: part }
+          end
         end

         def tool_msg(msg)
-          {
-            role: "function",
-            parts: {
-              functionResponse: {
-                name: msg[:name] || msg[:id],
-                response: {
-                  content: msg[:content],
-                },
+          part = {
+            functionResponse: {
+              name: msg[:name] || msg[:id],
+              response: {
+                content: msg[:content],
               },
             },
           }
+
+          if beta_api?
+            { role: "function", parts: [part] }
+          else
+            { role: "function", parts: part }
+          end
         end
       end
     end
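Note on the dialect changes above: under the v1beta API each translated message carries `parts` as an array rather than a single hash, which is what lets one user turn mix text with inline image data. A minimal sketch of the translated shape (the base64 value is a hypothetical placeholder):

    user_message = {
      role: "user",
      parts: [
        { text: "describe this image" },
        # hypothetical base64 payload; real data comes from prompt.encoded_uploads
        { inlineData: { mimeType: "image/jpeg", data: "aGVsbG8=" } },
      ],
    }
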
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index c93b9bdf..39dd320e 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -54,13 +54,24 @@ module DiscourseAi
           if llm_model
            url = llm_model.url
          else
-            mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
+            mapped_model = model
+            if model == "gemini-1.5-pro"
+              mapped_model = "gemini-1.5-pro-latest"
+            elsif model == "gemini-1.5-flash"
+              mapped_model = "gemini-1.5-flash-latest"
+            elsif model == "gemini-1.0-pro"
+              mapped_model = "gemini-pro-latest"
+            end
             url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
           end
           key = llm_model&.api_key || SiteSetting.ai_gemini_api_key

-          url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"
+          if @streaming_mode
+            url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"
+          else
+            url = "#{url}:generateContent?key=#{key}"
+          end

           URI(url)
         end
@@ -68,12 +79,14 @@ module DiscourseAi
         def prepare_payload(prompt, model_params, dialect)
           tools = dialect.tools

-          default_options
-            .merge(contents: prompt)
-            .tap do |payload|
-              payload[:tools] = tools if tools.present?
-              payload[:generationConfig].merge!(model_params) if model_params.present?
-            end
+          payload = default_options.merge(contents: prompt[:messages])
+          payload[:systemInstruction] = {
+            role: "system",
+            parts: [{ text: prompt[:system_instruction].to_s }],
+          } if prompt[:system_instruction].present?
+          payload[:tools] = tools if tools.present?
+          payload[:generationConfig].merge!(model_params) if model_params.present?
+          payload
         end

         def prepare_request(payload)
@@ -96,11 +109,55 @@ module DiscourseAi
         end

         def partials_from(decoded_chunk)
-          begin
-            JSON.parse(decoded_chunk, symbolize_names: true)
-          rescue JSON::ParserError
-            []
-          end
+          decoded_chunk
+        end
+
+        def chunk_to_string(chunk)
+          chunk.to_s
+        end
+
+        class Decoder
+          def initialize
+            @buffer = +""
+          end
+
+          def decode(str)
+            @buffer << str
+
+            lines = @buffer.split(/\r?\n\r?\n/)
+
+            keep_last = false
+
+            decoded =
+              lines
+                .map do |line|
+                  if line.start_with?("data: {")
+                    begin
+                      JSON.parse(line[6..-1], symbolize_names: true)
+                    rescue JSON::ParserError
+                      keep_last = line
+                      nil
+                    end
+                  else
+                    keep_last = line
+                    nil
+                  end
+                end
+                .compact
+
+            if keep_last
+              @buffer = +(keep_last)
+            else
+              @buffer = +""
+            end

+            decoded
+          end
+        end
+
+        def decode(chunk)
+          @decoder ||= Decoder.new
+          @decoder.decode(chunk)
         end

         def extract_prompt_for_tokenizer(prompt)
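Note on the streaming change above: with `alt=sse` the endpoint receives server-sent events, and a network chunk can end mid-event, so the Decoder buffers input until a complete `data: {...}` event (terminated by a blank line) parses as JSON. A rough usage sketch of the class introduced above, with hypothetical payloads:

    decoder = DiscourseAi::Completions::Endpoints::Gemini::Decoder.new
    # an incomplete event is buffered and nothing is emitted yet
    decoder.decode("data: {\"a\":")             # => []
    # once the event completes, all parseable events are returned at once
    decoder.decode("1}\r\n\r\ndata: {\"b\":2}") # => [{ a: 1 }, { b: 2 }]
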
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 531dacfa..2c71adfe 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -56,7 +56,7 @@ module DiscourseAi
               gpt-4-vision-preview
               gpt-4o
             ],
-          google: %w[gemini-pro gemini-1.5-pro],
+          google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
         }.tap do |h|
           h[:ollama] = ["mistral"] if Rails.env.development?
           h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb
index 8f42c346..960a2b3e 100644
--- a/lib/summarization/entry_point.rb
+++ b/lib/summarization/entry_point.rb
@@ -13,6 +13,7 @@ module DiscourseAi
         Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
         Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
         Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
+        Models::Gemini.new("google:gemini-1.5-flash", max_tokens: 800_000),
       ]

       claude_prov = "anthropic"
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
index 338eaadd..6906efd2 100644
--- a/spec/lib/completions/dialects/gemini_spec.rb
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -3,16 +3,15 @@
 require_relative "dialect_context"

 RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
-  let(:model_name) { "gemini-pro" }
+  let(:model_name) { "gemini-1.5-pro" }
   let(:context) { DialectContext.new(described_class, model_name) }

   describe "#translate" do
     it "translates a prompt written in our generic format to the Gemini format" do
-      gemini_version = [
-        { role: "user", parts: { text: context.system_insts } },
-        { role: "model", parts: { text: "Ok." } },
-        { role: "user", parts: { text: context.simple_user_input } },
-      ]
+      gemini_version = {
+        messages: [{ role: "user", parts: [{ text: context.simple_user_input }] }],
+        system_instruction: context.system_insts,
+      }

       translated = context.system_user_scenario

@@ -21,73 +20,77 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do

     it "injects model after tool call" do
       expect(context.image_generation_scenario).to eq(
-        [
-          { role: "user", parts: { text: context.system_insts } },
-          { parts: { text: "Ok." }, role: "model" },
-          { parts: { text: "draw a cat" }, role: "user" },
-          { parts: { functionCall: { args: { picture: "Cat" }, name: "draw" } }, role: "model" },
-          {
-            parts: {
-              functionResponse: {
-                name: "tool_id",
-                response: {
-                  content: "\"I'm a tool result\"",
-                },
-              },
-            },
-            role: "function",
-          },
-          { parts: { text: "Ok." }, role: "model" },
-          { parts: { text: "draw another cat" }, role: "user" },
-        ],
+        {
+          messages: [
+            { role: "user", parts: [{ text: "draw a cat" }] },
+            {
+              role: "model",
+              parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
+            },
+            {
+              role: "function",
+              parts: [
+                {
+                  functionResponse: {
+                    name: "tool_id",
+                    response: {
+                      content: "\"I'm a tool result\"",
+                    },
+                  },
+                },
+              ],
+            },
+            { role: "model", parts: { text: "Ok." } },
+            { role: "user", parts: [{ text: "draw another cat" }] },
+          ],
+          system_instruction: context.system_insts,
+        },
       )
     end

     it "translates tool_call and tool messages" do
       expect(context.multi_turn_scenario).to eq(
-        [
-          { role: "user", parts: { text: context.system_insts } },
-          { role: "model", parts: { text: "Ok." } },
-          { role: "user", parts: { text: "This is a message by a user" } },
-          {
-            role: "model",
-            parts: {
-              text: "I'm a previous bot reply, that's why there's no user",
-            },
-          },
-          { role: "user", parts: { text: "This is a new message by a user" } },
-          {
-            role: "model",
-            parts: {
-              functionCall: {
-                name: "get_weather",
-                args: {
-                  location: "Sydney",
-                  unit: "c",
-                },
-              },
-            },
-          },
-          {
-            role: "function",
-            parts: {
-              functionResponse: {
-                name: "get_weather",
-                response: {
-                  content: "I'm a tool result".to_json,
-                },
-              },
-            },
-          },
-        ],
+        {
+          messages: [
+            { role: "user", parts: [{ text: "This is a message by a user" }] },
+            {
+              role: "model",
+              parts: [{ text: "I'm a previous bot reply, that's why there's no user" }],
+            },
+            { role: "user", parts: [{ text: "This is a new message by a user" }] },
+            {
+              role: "model",
+              parts: [
+                { functionCall: { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
+              ],
+            },
+            {
+              role: "function",
+              parts: [
+                {
+                  functionResponse: {
+                    name: "get_weather",
+                    response: {
+                      content: "\"I'm a tool result\"",
+                    },
+                  },
+                },
+              ],
+            },
+          ],
+          system_instruction:
+            "I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
+        },
       )
     end

     it "trims content if it's getting too long" do
+      # testing truncation at 800k tokens is slow, so use a model with a smaller limit
+      context = DialectContext.new(described_class, "gemini-pro")
       translated = context.long_user_input_scenario(length: 5_000)

-      expect(translated.last[:role]).to eq("user")
-      expect(translated.last.dig(:parts, :text).length).to be <
+      expect(translated[:messages].last[:role]).to eq("user")
+      expect(translated[:messages].last.dig(:parts, :text).length).to be <
         context.long_message_text(length: 5_000).length
     end
   end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index dbc5fe31..1fe0f5a6 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -132,39 +132,92 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
   fab!(:user)

+  let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
+  let(:upload100x100) do
+    UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
+  end
+
   let(:gemini_mock) { GeminiMock.new(endpoint) }

   let(:compliance) do
     EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
   end

-  describe "#perform_completion!" do
-    context "when using regular mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(gemini_mock)
-        end
-      end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(gemini_mock)
-        end
-      end
-    end
+  it "Supports Vision API" do
+    SiteSetting.ai_gemini_api_key = "ABC"
+
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are image bot",
+        messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+      )
+
+    encoded = prompt.encoded_uploads(prompt.messages.last)
+
+    response = gemini_mock.response("World").to_json
+
+    req_body = nil
+
+    llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-pro")
+    url =
+      "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=ABC"
+
+    stub_request(:post, url).with(
+      body:
+        proc do |_req_body|
+          req_body = _req_body
+          true
+        end,
+    ).to_return(status: 200, body: response)
+
+    response = llm.generate(prompt, user: user)
+
+    expect(response).to eq("World")
+
+    expected_prompt = {
+      "generationConfig" => {},
+      "contents" => [
+        {
+          "role" => "user",
+          "parts" => [
+            { "text" => "hello" },
+            { "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
+          ],
+        },
+      ],
+      "systemInstruction" => {
+        "role" => "system",
+        "parts" => [{ "text" => "You are image bot" }],
+      },
+    }
+
+    expect(JSON.parse(req_body)).to eq(expected_prompt)
+  end
+
+  it "Can correctly handle streamed responses even if they are chunked badly" do
+    SiteSetting.ai_gemini_api_key = "ABC"
+
+    data = +""
+    data << "da|ta: |"
+    data << gemini_mock.response("Hello").to_json
+    data << "\r\n\r\ndata: "
+    data << gemini_mock.response(" |World").to_json
+    data << "\r\n\r\ndata: "
+    data << gemini_mock.response(" Sam").to_json
+
+    split = data.split("|")
+
+    llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-flash")
+    url =
+      "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:streamGenerateContent?alt=sse&key=ABC"
+
+    output = +""
+    gemini_mock.with_chunk_array_support do
+      stub_request(:post, url).to_return(status: 200, body: split)
llm.generate("Hello", user: user) { |partial| output << partial } end - describe "when using streaming mode" do - context "with simple prompts" do - it "completes a trivial prompt and logs the response" do - compliance.streaming_mode_simple_prompt(gemini_mock) - end - end - - context "with tools" do - it "returns a function invocation" do - compliance.streaming_mode_tools(gemini_mock) - end - end - end + expect(output).to eq("Hello World Sam") end end