FIX: correct gemini streaming implementation (#632)
This also implements image support and gemini-flash support
This commit is contained in:
parent
06137ac706
commit
d5c23f01ff
|
@ -13,6 +13,7 @@ en:
|
||||||
claude_2: Claude 2
|
claude_2: Claude 2
|
||||||
gemini_pro: Gemini Pro
|
gemini_pro: Gemini Pro
|
||||||
gemini_1_5_pro: Gemini 1.5 Pro
|
gemini_1_5_pro: Gemini 1.5 Pro
|
||||||
|
gemini_1_5_flash: Gemini 1.5 Flash
|
||||||
claude_3_opus: Claude 3 Opus
|
claude_3_opus: Claude 3 Opus
|
||||||
claude_3_sonnet: Claude 3 Sonnet
|
claude_3_sonnet: Claude 3 Sonnet
|
||||||
claude_3_haiku: Claude 3 Haiku
|
claude_3_haiku: Claude 3 Haiku
|
||||||
|
|
|
@ -178,7 +178,7 @@ module DiscourseAi
|
||||||
"ollama:mistral"
|
"ollama:mistral"
|
||||||
end
|
end
|
||||||
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
|
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
|
||||||
"google:gemini-pro"
|
"google:gemini-1.5-pro"
|
||||||
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
|
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
|
||||||
"fake:fake"
|
"fake:fake"
|
||||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
|
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
|
||||||
|
|
|
@ -9,6 +9,7 @@ module DiscourseAi
|
||||||
{ id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
|
{ id: "gpt-3.5-turbo", name: "discourse_automation.ai_models.gpt_3_5_turbo" },
|
||||||
{ id: "gemini-pro", name: "discourse_automation.ai_models.gemini_pro" },
|
{ id: "gemini-pro", name: "discourse_automation.ai_models.gemini_pro" },
|
||||||
{ id: "gemini-1.5-pro", name: "discourse_automation.ai_models.gemini_1_5_pro" },
|
{ id: "gemini-1.5-pro", name: "discourse_automation.ai_models.gemini_1_5_pro" },
|
||||||
|
{ id: "gemini-1.5-flash", name: "discourse_automation.ai_models.gemini_1_5_flash" },
|
||||||
{ id: "claude-2", name: "discourse_automation.ai_models.claude_2" },
|
{ id: "claude-2", name: "discourse_automation.ai_models.claude_2" },
|
||||||
{ id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" },
|
{ id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" },
|
||||||
{ id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" },
|
{ id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" },
|
||||||
|
|
|
@ -6,7 +6,7 @@ module DiscourseAi
|
||||||
class Gemini < Dialect
|
class Gemini < Dialect
|
||||||
class << self
|
class << self
|
||||||
def can_translate?(model_name)
|
def can_translate?(model_name)
|
||||||
%w[gemini-pro gemini-1.5-pro].include?(model_name)
|
%w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -26,7 +26,13 @@ module DiscourseAi
|
||||||
interleving_messages = []
|
interleving_messages = []
|
||||||
previous_message = nil
|
previous_message = nil
|
||||||
|
|
||||||
|
system_instruction = nil
|
||||||
|
|
||||||
messages.each do |message|
|
messages.each do |message|
|
||||||
|
if message[:role] == "system"
|
||||||
|
system_instruction = message[:content]
|
||||||
|
next
|
||||||
|
end
|
||||||
if previous_message
|
if previous_message
|
||||||
if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
|
if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
|
||||||
message[:role] == "user"
|
message[:role] == "user"
|
||||||
|
@ -37,7 +43,7 @@ module DiscourseAi
|
||||||
previous_message = message
|
previous_message = message
|
||||||
end
|
end
|
||||||
|
|
||||||
interleving_messages
|
{ messages: interleving_messages, system_instruction: system_instruction }
|
||||||
end
|
end
|
||||||
|
|
||||||
def tools
|
def tools
|
||||||
|
@ -70,7 +76,7 @@ module DiscourseAi
|
||||||
def max_prompt_tokens
|
def max_prompt_tokens
|
||||||
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
|
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
|
||||||
|
|
||||||
if model_name == "gemini-1.5-pro"
|
if model_name.start_with?("gemini-1.5")
|
||||||
# technically we support 1 million tokens, but we're being conservative
|
# technically we support 1 million tokens, but we're being conservative
|
||||||
800_000
|
800_000
|
||||||
else
|
else
|
||||||
|
@ -84,44 +90,80 @@ module DiscourseAi
|
||||||
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def beta_api?
|
||||||
|
@beta_api ||= model_name.start_with?("gemini-1.5")
|
||||||
|
end
|
||||||
|
|
||||||
def system_msg(msg)
|
def system_msg(msg)
|
||||||
|
if beta_api?
|
||||||
|
{ role: "system", content: msg[:content] }
|
||||||
|
else
|
||||||
{ role: "user", parts: { text: msg[:content] } }
|
{ role: "user", parts: { text: msg[:content] } }
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def model_msg(msg)
|
def model_msg(msg)
|
||||||
|
if beta_api?
|
||||||
|
{ role: "model", parts: [{ text: msg[:content] }] }
|
||||||
|
else
|
||||||
{ role: "model", parts: { text: msg[:content] } }
|
{ role: "model", parts: { text: msg[:content] } }
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def user_msg(msg)
|
def user_msg(msg)
|
||||||
|
if beta_api?
|
||||||
|
# support new format with multiple parts
|
||||||
|
result = { role: "user", parts: [{ text: msg[:content] }] }
|
||||||
|
upload_parts = uploaded_parts(msg)
|
||||||
|
result[:parts].concat(upload_parts) if upload_parts.present?
|
||||||
|
result
|
||||||
|
else
|
||||||
{ role: "user", parts: { text: msg[:content] } }
|
{ role: "user", parts: { text: msg[:content] } }
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def uploaded_parts(message)
|
||||||
|
encoded_uploads = prompt.encoded_uploads(message)
|
||||||
|
result = []
|
||||||
|
if encoded_uploads.present?
|
||||||
|
encoded_uploads.each do |details|
|
||||||
|
result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
result
|
||||||
|
end
|
||||||
|
|
||||||
def tool_call_msg(msg)
|
def tool_call_msg(msg)
|
||||||
call_details = JSON.parse(msg[:content], symbolize_names: true)
|
call_details = JSON.parse(msg[:content], symbolize_names: true)
|
||||||
|
part = {
|
||||||
{
|
|
||||||
role: "model",
|
|
||||||
parts: {
|
|
||||||
functionCall: {
|
functionCall: {
|
||||||
name: msg[:name] || call_details[:name],
|
name: msg[:name] || call_details[:name],
|
||||||
args: call_details[:arguments],
|
args: call_details[:arguments],
|
||||||
},
|
},
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if beta_api?
|
||||||
|
{ role: "model", parts: [part] }
|
||||||
|
else
|
||||||
|
{ role: "model", parts: part }
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def tool_msg(msg)
|
def tool_msg(msg)
|
||||||
{
|
part = {
|
||||||
role: "function",
|
|
||||||
parts: {
|
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: msg[:name] || msg[:id],
|
name: msg[:name] || msg[:id],
|
||||||
response: {
|
response: {
|
||||||
content: msg[:content],
|
content: msg[:content],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if beta_api?
|
||||||
|
{ role: "function", parts: [part] }
|
||||||
|
else
|
||||||
|
{ role: "function", parts: part }
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -54,13 +54,24 @@ module DiscourseAi
|
||||||
if llm_model
|
if llm_model
|
||||||
url = llm_model.url
|
url = llm_model.url
|
||||||
else
|
else
|
||||||
mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
|
mapped_model = model
|
||||||
|
if model == "gemini-1.5-pro"
|
||||||
|
mapped_model = "gemini-1.5-pro-latest"
|
||||||
|
elsif model == "gemini-1.5-flash"
|
||||||
|
mapped_model = "gemini-1.5-flash-latest"
|
||||||
|
elsif model == "gemini-1.0-pro"
|
||||||
|
mapped_model = "gemini-pro-latest"
|
||||||
|
end
|
||||||
url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
|
url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
|
||||||
end
|
end
|
||||||
|
|
||||||
key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
|
key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
|
||||||
|
|
||||||
url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"
|
if @streaming_mode
|
||||||
|
url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"
|
||||||
|
else
|
||||||
|
url = "#{url}:generateContent?key=#{key}"
|
||||||
|
end
|
||||||
|
|
||||||
URI(url)
|
URI(url)
|
||||||
end
|
end
|
||||||
|
@ -68,12 +79,14 @@ module DiscourseAi
|
||||||
def prepare_payload(prompt, model_params, dialect)
|
def prepare_payload(prompt, model_params, dialect)
|
||||||
tools = dialect.tools
|
tools = dialect.tools
|
||||||
|
|
||||||
default_options
|
payload = default_options.merge(contents: prompt[:messages])
|
||||||
.merge(contents: prompt)
|
payload[:systemInstruction] = {
|
||||||
.tap do |payload|
|
role: "system",
|
||||||
|
parts: [{ text: prompt[:system_instruction].to_s }],
|
||||||
|
} if prompt[:system_instruction].present?
|
||||||
payload[:tools] = tools if tools.present?
|
payload[:tools] = tools if tools.present?
|
||||||
payload[:generationConfig].merge!(model_params) if model_params.present?
|
payload[:generationConfig].merge!(model_params) if model_params.present?
|
||||||
end
|
payload
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_request(payload)
|
def prepare_request(payload)
|
||||||
|
@ -96,11 +109,55 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def partials_from(decoded_chunk)
|
def partials_from(decoded_chunk)
|
||||||
begin
|
decoded_chunk
|
||||||
JSON.parse(decoded_chunk, symbolize_names: true)
|
|
||||||
rescue JSON::ParserError
|
|
||||||
[]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def chunk_to_string(chunk)
|
||||||
|
chunk.to_s
|
||||||
|
end
|
||||||
|
|
||||||
|
class Decoder
|
||||||
|
def initialize
|
||||||
|
@buffer = +""
|
||||||
|
end
|
||||||
|
|
||||||
|
def decode(str)
|
||||||
|
@buffer << str
|
||||||
|
|
||||||
|
lines = @buffer.split(/\r?\n\r?\n/)
|
||||||
|
|
||||||
|
keep_last = false
|
||||||
|
|
||||||
|
decoded =
|
||||||
|
lines
|
||||||
|
.map do |line|
|
||||||
|
if line.start_with?("data: {")
|
||||||
|
begin
|
||||||
|
JSON.parse(line[6..-1], symbolize_names: true)
|
||||||
|
rescue JSON::ParserError
|
||||||
|
keep_last = line
|
||||||
|
nil
|
||||||
|
end
|
||||||
|
else
|
||||||
|
keep_last = line
|
||||||
|
nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
.compact
|
||||||
|
|
||||||
|
if keep_last
|
||||||
|
@buffer = +(keep_last)
|
||||||
|
else
|
||||||
|
@buffer = +""
|
||||||
|
end
|
||||||
|
|
||||||
|
decoded
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def decode(chunk)
|
||||||
|
@decoder ||= Decoder.new
|
||||||
|
@decoder.decode(chunk)
|
||||||
end
|
end
|
||||||
|
|
||||||
def extract_prompt_for_tokenizer(prompt)
|
def extract_prompt_for_tokenizer(prompt)
|
||||||
|
|
|
@ -56,7 +56,7 @@ module DiscourseAi
|
||||||
gpt-4-vision-preview
|
gpt-4-vision-preview
|
||||||
gpt-4o
|
gpt-4o
|
||||||
],
|
],
|
||||||
google: %w[gemini-pro gemini-1.5-pro],
|
google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
|
||||||
}.tap do |h|
|
}.tap do |h|
|
||||||
h[:ollama] = ["mistral"] if Rails.env.development?
|
h[:ollama] = ["mistral"] if Rails.env.development?
|
||||||
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
|
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
|
||||||
|
|
|
@ -13,6 +13,7 @@ module DiscourseAi
|
||||||
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
|
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
|
||||||
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
|
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
|
||||||
Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
|
Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
|
||||||
|
Models::Gemini.new("google:gemini-1.5-flash", max_tokens: 800_000),
|
||||||
]
|
]
|
||||||
|
|
||||||
claude_prov = "anthropic"
|
claude_prov = "anthropic"
|
||||||
|
|
|
@ -3,16 +3,15 @@
|
||||||
require_relative "dialect_context"
|
require_relative "dialect_context"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
let(:model_name) { "gemini-pro" }
|
let(:model_name) { "gemini-1.5-pro" }
|
||||||
let(:context) { DialectContext.new(described_class, model_name) }
|
let(:context) { DialectContext.new(described_class, model_name) }
|
||||||
|
|
||||||
describe "#translate" do
|
describe "#translate" do
|
||||||
it "translates a prompt written in our generic format to the Gemini format" do
|
it "translates a prompt written in our generic format to the Gemini format" do
|
||||||
gemini_version = [
|
gemini_version = {
|
||||||
{ role: "user", parts: { text: context.system_insts } },
|
messages: [{ role: "user", parts: [{ text: context.simple_user_input }] }],
|
||||||
{ role: "model", parts: { text: "Ok." } },
|
system_instruction: context.system_insts,
|
||||||
{ role: "user", parts: { text: context.simple_user_input } },
|
}
|
||||||
]
|
|
||||||
|
|
||||||
translated = context.system_user_scenario
|
translated = context.system_user_scenario
|
||||||
|
|
||||||
|
@ -21,13 +20,17 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
|
|
||||||
it "injects model after tool call" do
|
it "injects model after tool call" do
|
||||||
expect(context.image_generation_scenario).to eq(
|
expect(context.image_generation_scenario).to eq(
|
||||||
[
|
|
||||||
{ role: "user", parts: { text: context.system_insts } },
|
|
||||||
{ parts: { text: "Ok." }, role: "model" },
|
|
||||||
{ parts: { text: "draw a cat" }, role: "user" },
|
|
||||||
{ parts: { functionCall: { args: { picture: "Cat" }, name: "draw" } }, role: "model" },
|
|
||||||
{
|
{
|
||||||
parts: {
|
messages: [
|
||||||
|
{ role: "user", parts: [{ text: "draw a cat" }] },
|
||||||
|
{
|
||||||
|
role: "model",
|
||||||
|
parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "function",
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: "tool_id",
|
name: "tool_id",
|
||||||
response: {
|
response: {
|
||||||
|
@ -35,59 +38,59 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
role: "function",
|
|
||||||
},
|
|
||||||
{ parts: { text: "Ok." }, role: "model" },
|
|
||||||
{ parts: { text: "draw another cat" }, role: "user" },
|
|
||||||
],
|
],
|
||||||
|
},
|
||||||
|
{ role: "model", parts: { text: "Ok." } },
|
||||||
|
{ role: "user", parts: [{ text: "draw another cat" }] },
|
||||||
|
],
|
||||||
|
system_instruction: context.system_insts,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "translates tool_call and tool messages" do
|
it "translates tool_call and tool messages" do
|
||||||
expect(context.multi_turn_scenario).to eq(
|
expect(context.multi_turn_scenario).to eq(
|
||||||
[
|
{
|
||||||
{ role: "user", parts: { text: context.system_insts } },
|
messages: [
|
||||||
{ role: "model", parts: { text: "Ok." } },
|
{ role: "user", parts: [{ text: "This is a message by a user" }] },
|
||||||
{ role: "user", parts: { text: "This is a message by a user" } },
|
|
||||||
{
|
{
|
||||||
role: "model",
|
role: "model",
|
||||||
parts: {
|
parts: [{ text: "I'm a previous bot reply, that's why there's no user" }],
|
||||||
text: "I'm a previous bot reply, that's why there's no user",
|
|
||||||
},
|
},
|
||||||
},
|
{ role: "user", parts: [{ text: "This is a new message by a user" }] },
|
||||||
{ role: "user", parts: { text: "This is a new message by a user" } },
|
|
||||||
{
|
{
|
||||||
role: "model",
|
role: "model",
|
||||||
parts: {
|
parts: [
|
||||||
functionCall: {
|
{ functionCall: { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
|
||||||
name: "get_weather",
|
],
|
||||||
args: {
|
|
||||||
location: "Sydney",
|
|
||||||
unit: "c",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: "function",
|
role: "function",
|
||||||
parts: {
|
parts: [
|
||||||
|
{
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: "get_weather",
|
name: "get_weather",
|
||||||
response: {
|
response: {
|
||||||
content: "I'm a tool result".to_json,
|
content: "\"I'm a tool result\"",
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
system_instruction:
|
||||||
|
"I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "trims content if it's getting too long" do
|
it "trims content if it's getting too long" do
|
||||||
|
# testing truncation on 800k tokens is slow use model with less
|
||||||
|
context = DialectContext.new(described_class, "gemini-pro")
|
||||||
translated = context.long_user_input_scenario(length: 5_000)
|
translated = context.long_user_input_scenario(length: 5_000)
|
||||||
|
|
||||||
expect(translated.last[:role]).to eq("user")
|
expect(translated[:messages].last[:role]).to eq("user")
|
||||||
expect(translated.last.dig(:parts, :text).length).to be <
|
expect(translated[:messages].last.dig(:parts, :text).length).to be <
|
||||||
context.long_message_text(length: 5_000).length
|
context.long_message_text(length: 5_000).length
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -132,39 +132,92 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
|
||||||
|
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
|
||||||
|
let(:upload100x100) do
|
||||||
|
UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
|
||||||
|
end
|
||||||
|
|
||||||
let(:gemini_mock) { GeminiMock.new(endpoint) }
|
let(:gemini_mock) { GeminiMock.new(endpoint) }
|
||||||
|
|
||||||
let(:compliance) do
|
let(:compliance) do
|
||||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
|
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Gemini, user)
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#perform_completion!" do
|
it "Supports Vision API" do
|
||||||
context "when using regular mode" do
|
SiteSetting.ai_gemini_api_key = "ABC"
|
||||||
context "with simple prompts" do
|
|
||||||
it "completes a trivial prompt and logs the response" do
|
prompt =
|
||||||
compliance.regular_mode_simple_prompt(gemini_mock)
|
DiscourseAi::Completions::Prompt.new(
|
||||||
end
|
"You are image bot",
|
||||||
|
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||||
|
|
||||||
|
response = gemini_mock.response("World").to_json
|
||||||
|
|
||||||
|
req_body = nil
|
||||||
|
|
||||||
|
llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-pro")
|
||||||
|
url =
|
||||||
|
"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=ABC"
|
||||||
|
|
||||||
|
stub_request(:post, url).with(
|
||||||
|
body:
|
||||||
|
proc do |_req_body|
|
||||||
|
req_body = _req_body
|
||||||
|
true
|
||||||
|
end,
|
||||||
|
).to_return(status: 200, body: response)
|
||||||
|
|
||||||
|
response = llm.generate(prompt, user: user)
|
||||||
|
|
||||||
|
expect(response).to eq("World")
|
||||||
|
|
||||||
|
expected_prompt = {
|
||||||
|
"generationConfig" => {
|
||||||
|
},
|
||||||
|
"contents" => [
|
||||||
|
{
|
||||||
|
"role" => "user",
|
||||||
|
"parts" => [
|
||||||
|
{ "text" => "hello" },
|
||||||
|
{ "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"systemInstruction" => {
|
||||||
|
"role" => "system",
|
||||||
|
"parts" => [{ "text" => "You are image bot" }],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(JSON.parse(req_body)).to eq(expected_prompt)
|
||||||
end
|
end
|
||||||
|
|
||||||
context "with tools" do
|
it "Can correctly handle streamed responses even if they are chunked badly" do
|
||||||
it "returns a function invocation" do
|
SiteSetting.ai_gemini_api_key = "ABC"
|
||||||
compliance.regular_mode_tools(gemini_mock)
|
|
||||||
end
|
data = +""
|
||||||
end
|
data << "da|ta: |"
|
||||||
|
data << gemini_mock.response("Hello").to_json
|
||||||
|
data << "\r\n\r\ndata: "
|
||||||
|
data << gemini_mock.response(" |World").to_json
|
||||||
|
data << "\r\n\r\ndata: "
|
||||||
|
data << gemini_mock.response(" Sam").to_json
|
||||||
|
|
||||||
|
split = data.split("|")
|
||||||
|
|
||||||
|
llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-flash")
|
||||||
|
url =
|
||||||
|
"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:streamGenerateContent?alt=sse&key=ABC"
|
||||||
|
|
||||||
|
output = +""
|
||||||
|
gemini_mock.with_chunk_array_support do
|
||||||
|
stub_request(:post, url).to_return(status: 200, body: split)
|
||||||
|
llm.generate("Hello", user: user) { |partial| output << partial }
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "when using streaming mode" do
|
expect(output).to eq("Hello World Sam")
|
||||||
context "with simple prompts" do
|
|
||||||
it "completes a trivial prompt and logs the response" do
|
|
||||||
compliance.streaming_mode_simple_prompt(gemini_mock)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "with tools" do
|
|
||||||
it "returns a function invocation" do
|
|
||||||
compliance.streaming_mode_tools(gemini_mock)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue