FEATURE: Support for Gemini in AiHelper / Search / Summarization (#358)
This commit is contained in:
parent
831559662e
commit
83744bf192
|
@ -5,6 +5,7 @@ class AiApiAuditLog < ActiveRecord::Base
|
|||
OpenAI = 1
|
||||
Anthropic = 2
|
||||
HuggingFaceTextGeneration = 3
|
||||
Gemini = 4
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -147,6 +147,9 @@ discourse_ai:
|
|||
ai_cloudflare_workers_api_token:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_gemini_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
|
||||
composer_ai_helper_enabled:
|
||||
default: false
|
||||
|
@ -170,6 +173,7 @@ discourse_ai:
|
|||
- claude-2
|
||||
- stable-beluga-2
|
||||
- Llama2-chat-hf
|
||||
- gemini-pro
|
||||
ai_helper_custom_prompts_allowed_groups:
|
||||
client: true
|
||||
type: group_list
|
||||
|
@ -233,6 +237,7 @@ discourse_ai:
|
|||
- gpt-4
|
||||
- StableBeluga2
|
||||
- Upstage-Llama-2-*-instruct-v2
|
||||
- gemini-pro
|
||||
|
||||
ai_summarization_discourse_service_api_endpoint: ""
|
||||
ai_summarization_discourse_service_api_key:
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Gemini
        # Dialects are picked by model name; this one handles Google's gemini-pro.
        def self.can_translate?(model_name)
          %w[gemini-pro].include?(model_name)
        end

        # Converts our generic prompt hash into Gemini's "contents" format:
        # an ordered list of alternating user/model turns, each shaped as
        # { role:, parts: { text: } }. Gemini has no system role, so the
        # instructions become a user turn acknowledged by a canned "Ok." reply.
        def translate(generic_prompt)
          messages = [
            user_turn([generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")),
            model_turn("Ok."),
          ]

          # Few-shot examples become additional user/model exchange pairs.
          (generic_prompt[:examples] || []).each do |(example_input, example_output)|
            messages << user_turn(example_input)
            messages << model_turn(example_output)
          end

          messages << user_turn(generic_prompt[:input])
        end

        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
        end

        private

        # Builds a single user turn in Gemini's message shape.
        def user_turn(text)
          { role: "user", parts: { text: text } }
        end

        # Builds a single model turn in Gemini's message shape.
        def model_turn(text)
          { role: "model", parts: { text: text } }
        end
      end
    end
  end
end
|
|
@ -15,6 +15,7 @@ module DiscourseAi
|
|||
DiscourseAi::Completions::Endpoints::Anthropic,
|
||||
DiscourseAi::Completions::Endpoints::OpenAi,
|
||||
DiscourseAi::Completions::Endpoints::HuggingFace,
|
||||
DiscourseAi::Completions::Endpoints::Gemini,
|
||||
].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
|
||||
ek.can_contact?(model_name)
|
||||
end
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class Gemini < Base
        def self.can_contact?(model_name)
          %w[gemini-pro].include?(model_name)
        end

        # No provider-specific defaults yet; model_params pass through untouched.
        def default_options
          {}
        end

        def provider_id
          AiApiAuditLog::Provider::Gemini
        end

        private

        # Gemini selects streaming vs. one-shot via the RPC name in the path,
        # and authenticates with the API key as a query param (not a header).
        def model_uri
          url =
            "https://generativelanguage.googleapis.com/v1beta/models/#{model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"

          URI(url)
        end

        def prepare_payload(prompt, model_params)
          default_options.merge(model_params).merge(contents: prompt)
        end

        def prepare_request(payload)
          headers = { "Content-Type" => "application/json" }

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        # In streaming mode `response_raw` is already a Hash (partials_from
        # parsed the chunk); otherwise it's the raw JSON response body.
        # Fix: drop the useless `completion =` local — the expression value
        # was already the method's return value and the local was never read.
        def extract_completion_from(response_raw)
          parsed =
            if @streaming_mode
              response_raw
            else
              JSON.parse(response_raw, symbolize_names: true)
            end

          dig_text(parsed).to_s
        end

        # A streamed body decodes to a JSON array of chunk objects.
        def partials_from(decoded_chunk)
          JSON.parse(decoded_chunk, symbolize_names: true)
        end

        def extract_prompt_for_tokenizer(prompt)
          prompt.to_s
        end

        # Pulls the generated text out of Gemini's nested response envelope.
        def dig_text(response)
          response.dig(:candidates, 0, :content, :parts, 0, :text)
        end
      end
    end
  end
end
|
|
@ -29,6 +29,7 @@ module DiscourseAi
|
|||
DiscourseAi::Completions::Dialects::Llama2Classic,
|
||||
DiscourseAi::Completions::Dialects::ChatGpt,
|
||||
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||
DiscourseAi::Completions::Dialects::Gemini,
|
||||
]
|
||||
|
||||
dialect =
|
||||
|
|
|
@ -16,6 +16,7 @@ module DiscourseAi
|
|||
"StableBeluga2",
|
||||
max_tokens: SiteSetting.ai_hugging_face_token_limit,
|
||||
),
|
||||
Models::Gemini.new("gemini-pro", max_tokens: 32_768),
|
||||
]
|
||||
|
||||
foldable_models.each do |model|
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# frozen_string_literal: true

module DiscourseAi
  module Summarization
    module Models
      class Gemini < Base
        # Human-readable label shown in the summarization strategy picker.
        def display_name
          format("Google Gemini %s", model)
        end

        # The model is usable as soon as an API key has been configured.
        def correctly_configured?
          SiteSetting.ai_gemini_api_key.present?
        end

        # Admin-facing hint telling which site setting still needs a value.
        def configuration_hint
          I18n.t("discourse_ai.summarization.configuration_hint", count: 1, setting: "ai_gemini_api_key")
        end
      end
    end
  end
end
|
|
@ -151,9 +151,8 @@ module DiscourseAi
|
|||
For example, a link to the 3rd post in the topic would be [post 3](#{opts[:resource_path]}/3)
|
||||
TEXT
|
||||
|
||||
insts += "The discussion title is: #{opts[:content_title]}.\n" if opts[:content_title]
|
||||
|
||||
prompt = { insts: insts, input: <<~TEXT }
|
||||
#{opts[:content_title].present? ? "The discussion title is: " + opts[:content_title] + ".\n" : ""}
|
||||
Here are the posts, inside <input></input> XML tags:
|
||||
|
||||
<input>
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
  subject(:dialect) { described_class.new }

  # Generic-format prompt fixture shared by both examples.
  let(:prompt) do
    {
      insts: <<~TEXT,
        I want you to act as a title generator for written pieces. I will provide you with a text,
        and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
        and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
      TEXT
      input: <<~TEXT,
        Here is the text, inside <input></input> XML tags:
        <input>
        To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
        discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
        defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.

        Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
        a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
        slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
        dies so that a scene may be repeated.
        </input>
      TEXT
      post_insts:
        "Please put the translation between <ai></ai> tags and separate each title with a comma.",
    }
  end

  describe "#translate" do
    it "translates a prompt written in our generic format to the Gemini format" do
      gemini_version = [
        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
        { role: "model", parts: { text: "Ok." } },
        { role: "user", parts: { text: prompt[:input] } },
      ]

      translated = dialect.translate(prompt)

      expect(translated).to eq(gemini_version)
    end

    it "include examples in the Gemini version" do
      prompt[:examples] = [
        [
          "<input>In the labyrinth of time, a solitary horse, etched in gold by the setting sun, embarked on an infinite journey.</input>",
          "<ai>The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.</ai>",
        ],
      ]

      gemini_version = [
        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
        { role: "model", parts: { text: "Ok." } },
        { role: "user", parts: { text: prompt[:examples][0][0] } },
        { role: "model", parts: { text: prompt[:examples][0][1] } },
        { role: "user", parts: { text: prompt[:input] } },
      ]

      translated = dialect.translate(prompt)

      # Turn order is semantically significant in a chat transcript and
      # #translate builds it deterministically, so assert exact order with
      # `eq` instead of the order-insensitive `contain_exactly`.
      expect(translated).to eq(gemini_version)
    end
  end
end
|
|
@ -0,0 +1,101 @@
|
|||
# frozen_string_literal: true

require_relative "endpoint_examples"

RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }

  let(:model_name) { "gemini-pro" }
  let(:prompt) do
    [
      { role: "system", content: "You are a helpful bot." },
      { role: "user", content: "Write 3 words" },
    ]
  end

  # Gemini carries the conversation under a top-level `contents` key; the
  # streaming and non-streaming request bodies are identical (the mode is
  # selected by the URL's RPC name, not the payload).
  let(:request_body) { model.default_options.merge(contents: prompt).to_json }
  let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }

  # Builds a canned non-streaming generateContent response wrapping `content`
  # in Gemini's candidates/safetyRatings envelope.
  def response(content)
    {
      candidates: [
        {
          content: {
            parts: [{ text: content }],
            role: "model",
          },
          finishReason: "STOP",
          index: 0,
          safetyRatings: [
            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
          ],
        },
      ],
      promptFeedback: {
        safetyRatings: [
          { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
          { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
          { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
          { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
        ],
      },
    }
  end

  # Stubs the one-shot generateContent endpoint. Note the API key travels as
  # a query parameter, matching Endpoints::Gemini#model_uri.
  def stub_response(prompt, response_text)
    WebMock
      .stub_request(
        :post,
        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
      )
      .with(body: { contents: prompt })
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end

  # One streamed chunk carrying `delta`; only the final chunk has a
  # finishReason ("STOP").
  def stream_line(delta, finish_reason: nil)
    {
      candidates: [
        {
          content: {
            parts: [{ text: delta }],
            role: "model",
          },
          finishReason: finish_reason,
          index: 0,
          safetyRatings: [
            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
          ],
        },
      ],
    }.to_json
  end

  # Stubs the streaming endpoint. Gemini streams a single JSON array of chunk
  # objects (not SSE "data:" lines), so the chunks are joined into one
  # bracketed, comma-separated body here.
  def stub_streamed_response(prompt, deltas)
    chunks =
      deltas.each_with_index.map do |_, index|
        if index == (deltas.length - 1)
          stream_line(deltas[index], finish_reason: "STOP")
        else
          stream_line(deltas[index])
        end
      end

    chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")

    WebMock
      .stub_request(
        :post,
        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
      )
      .with(body: model.default_options.merge(contents: prompt).to_json)
      .to_return(status: 200, body: chunks)
  end

  it_behaves_like "an endpoint that can communicate with a completion service"
end
|
Loading…
Reference in New Issue