FEATURE: Support for Gemini in AiHelper / Search / Summarization (#358)

parent 831559662e
commit 83744bf192
@@ -5,6 +5,7 @@ class AiApiAuditLog < ActiveRecord::Base
     OpenAI = 1
     Anthropic = 2
     HuggingFaceTextGeneration = 3
+    Gemini = 4
   end
 end
@@ -147,6 +147,9 @@ discourse_ai:
   ai_cloudflare_workers_api_token:
     default: ""
     secret: true
+  ai_gemini_api_key:
+    default: ""
+    secret: true
   composer_ai_helper_enabled:
     default: false
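Note: once declared here, the key is read server-side as SiteSetting.ai_gemini_api_key (the new Gemini endpoint below interpolates it into the request URL), and secret: true marks it as a sensitive setting whose value is kept out of client-visible payloads.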
@@ -170,6 +173,7 @@ discourse_ai:
       - claude-2
       - stable-beluga-2
       - Llama2-chat-hf
+      - gemini-pro
   ai_helper_custom_prompts_allowed_groups:
     client: true
     type: group_list
@@ -233,6 +237,7 @@ discourse_ai:
       - gpt-4
       - StableBeluga2
       - Upstage-Llama-2-*-instruct-v2
+      - gemini-pro

   ai_summarization_discourse_service_api_endpoint: ""
   ai_summarization_discourse_service_api_key:
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Gemini
+        def self.can_translate?(model_name)
+          %w[gemini-pro].include?(model_name)
+        end
+
+        def translate(generic_prompt)
+          gemini_prompt = [
+            {
+              role: "user",
+              parts: {
+                text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
+              },
+            },
+            { role: "model", parts: { text: "Ok." } },
+          ]
+
+          if generic_prompt[:examples]
+            generic_prompt[:examples].each do |example_pair|
+              gemini_prompt << { role: "user", parts: { text: example_pair.first } }
+              gemini_prompt << { role: "model", parts: { text: example_pair.second } }
+            end
+          end
+
+          gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
+        end
+      end
+    end
+  end
+end
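For context, a minimal sketch (console usage assumed, values hypothetical) of what #translate produces for a simple generic prompt; this mirrors the structure asserted in the dialect spec further down:

    dialect = DiscourseAi::Completions::Dialects::Gemini.new
    dialect.translate(insts: "You are a helpful bot.", post_insts: "Reply briefly.", input: "Write 3 words")
    # => [
    #      { role: "user", parts: { text: "You are a helpful bot.\nReply briefly." } },
    #      { role: "model", parts: { text: "Ok." } },
    #      { role: "user", parts: { text: "Write 3 words" } },
    #    ]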
@@ -15,6 +15,7 @@ module DiscourseAi
         DiscourseAi::Completions::Endpoints::Anthropic,
         DiscourseAi::Completions::Endpoints::OpenAi,
         DiscourseAi::Completions::Endpoints::HuggingFace,
+        DiscourseAi::Completions::Endpoints::Gemini,
       ].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
         ek.can_contact?(model_name)
       end
@@ -0,0 +1,62 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Endpoints
+      class Gemini < Base
+        def self.can_contact?(model_name)
+          %w[gemini-pro].include?(model_name)
+        end
+
+        def default_options
+          {}
+        end
+
+        def provider_id
+          AiApiAuditLog::Provider::Gemini
+        end
+
+        private
+
+        def model_uri
+          url =
+            "https://generativelanguage.googleapis.com/v1beta/models/#{model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"
+
+          URI(url)
+        end
+
+        def prepare_payload(prompt, model_params)
+          default_options.merge(model_params).merge(contents: prompt)
+        end
+
+        def prepare_request(payload)
+          headers = { "Content-Type" => "application/json" }
+
+          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
+        end
+
+        def extract_completion_from(response_raw)
+          if @streaming_mode
+            parsed = response_raw
+          else
+            parsed = JSON.parse(response_raw, symbolize_names: true)
+          end
+
+          completion = dig_text(parsed).to_s
+        end
+
+        def partials_from(decoded_chunk)
+          JSON.parse(decoded_chunk, symbolize_names: true)
+        end
+
+        def extract_prompt_for_tokenizer(prompt)
+          prompt.to_s
+        end
+
+        def dig_text(response)
+          response.dig(:candidates, 0, :content, :parts, 0, :text)
+        end
+      end
+    end
+  end
+end
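For reference, with @streaming_mode off and model set to "gemini-pro", model_uri resolves to the non-streaming generateContent route and prepare_payload nests the translated prompt under the contents key, so the request amounts to (API key elided):

    # POST https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=...
    # body: { "contents": [{ "role": "user", "parts": { "text": "..." } }, ...] }

In streaming mode the same payload goes to the :streamGenerateContent route instead, and partials_from parses each decoded chunk as JSON.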
@@ -29,6 +29,7 @@ module DiscourseAi
         DiscourseAi::Completions::Dialects::Llama2Classic,
         DiscourseAi::Completions::Dialects::ChatGpt,
         DiscourseAi::Completions::Dialects::OrcaStyle,
+        DiscourseAi::Completions::Dialects::Gemini,
       ]

       dialect =
@@ -16,6 +16,7 @@ module DiscourseAi
           "StableBeluga2",
           max_tokens: SiteSetting.ai_hugging_face_token_limit,
         ),
+        Models::Gemini.new("gemini-pro", max_tokens: 32_768),
       ]

       foldable_models.each do |model|
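The max_tokens: 32_768 figure presumably mirrors Gemini Pro's 32k-token context window, matching how the other foldable models declare their limits.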
@@ -0,0 +1,25 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Summarization
+    module Models
+      class Gemini < Base
+        def display_name
+          "Google Gemini #{model}"
+        end
+
+        def correctly_configured?
+          SiteSetting.ai_gemini_api_key.present?
+        end
+
+        def configuration_hint
+          I18n.t(
+            "discourse_ai.summarization.configuration_hint",
+            count: 1,
+            setting: "ai_gemini_api_key",
+          )
+        end
+      end
+    end
+  end
+end
@@ -151,9 +151,8 @@ module DiscourseAi
           For example, a link to the 3rd post in the topic would be [post 3](#{opts[:resource_path]}/3)
         TEXT

-        insts += "The discussion title is: #{opts[:content_title]}.\n" if opts[:content_title]
-
         prompt = { insts: insts, input: <<~TEXT }
+          #{opts[:content_title].present? ? "The discussion title is: " + opts[:content_title] + ".\n" : ""}
           Here are the posts, inside <input></input> XML tags:

           <input>
@@ -0,0 +1,65 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
+  subject(:dialect) { described_class.new }
+
+  let(:prompt) do
+    {
+      insts: <<~TEXT,
+        I want you to act as a title generator for written pieces. I will provide you with a text,
+        and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
+        and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
+      TEXT
+      input: <<~TEXT,
+        Here is the text, inside <input></input> XML tags:
+        <input>
+        To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
+        discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
+        defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.

+        Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
+        a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
+        slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
+        dies so that a scene may be repeated.
+        </input>
+      TEXT
+      post_insts:
+        "Please put the translation between <ai></ai> tags and separate each title with a comma.",
+    }
+  end
+
+  describe "#translate" do
+    it "translates a prompt written in our generic format to the Gemini format" do
+      gemini_version = [
+        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
+        { role: "model", parts: { text: "Ok." } },
+        { role: "user", parts: { text: prompt[:input] } },
+      ]
+
+      translated = dialect.translate(prompt)
+
+      expect(translated).to eq(gemini_version)
+    end
+
+    it "includes examples in the Gemini version" do
+      prompt[:examples] = [
+        [
+          "<input>In the labyrinth of time, a solitary horse, etched in gold by the setting sun, embarked on an infinite journey.</input>",
+          "<ai>The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.</ai>",
+        ],
+      ]
+
+      gemini_version = [
+        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
+        { role: "model", parts: { text: "Ok." } },
+        { role: "user", parts: { text: prompt[:examples][0][0] } },
+        { role: "model", parts: { text: prompt[:examples][0][1] } },
+        { role: "user", parts: { text: prompt[:input] } },
+      ]
+
+      translated = dialect.translate(prompt)
+
+      expect(translated).to contain_exactly(*gemini_version)
+    end
+  end
+end
@@ -0,0 +1,101 @@
+# frozen_string_literal: true
+
+require_relative "endpoint_examples"
+
+RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
+  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
+
+  let(:model_name) { "gemini-pro" }
+  let(:prompt) do
+    [
+      { role: "system", content: "You are a helpful bot." },
+      { role: "user", content: "Write 3 words" },
+    ]
+  end
+
+  let(:request_body) { model.default_options.merge(contents: prompt).to_json }
+  let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
+
+  def response(content)
+    {
+      candidates: [
+        {
+          content: {
+            parts: [{ text: content }],
+            role: "model",
+          },
+          finishReason: "STOP",
+          index: 0,
+          safetyRatings: [
+            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+          ],
+        },
+      ],
+      promptFeedback: {
+        safetyRatings: [
+          { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+          { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+          { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+          { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+        ],
+      },
+    }
+  end
+
+  def stub_response(prompt, response_text)
+    WebMock
+      .stub_request(
+        :post,
+        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
+      )
+      .with(body: { contents: prompt })
+      .to_return(status: 200, body: JSON.dump(response(response_text)))
+  end
+
+  def stream_line(delta, finish_reason: nil)
+    {
+      candidates: [
+        {
+          content: {
+            parts: [{ text: delta }],
+            role: "model",
+          },
+          finishReason: finish_reason,
+          index: 0,
+          safetyRatings: [
+            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+          ],
+        },
+      ],
+    }.to_json
+  end
+
+  def stub_streamed_response(prompt, deltas)
+    chunks =
+      deltas.each_with_index.map do |_, index|
+        if index == (deltas.length - 1)
+          stream_line(deltas[index], finish_reason: "STOP")
+        else
+          stream_line(deltas[index])
+        end
+      end
+
+    chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
+
+    WebMock
+      .stub_request(
+        :post,
+        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
+      )
+      .with(body: model.default_options.merge(contents: prompt).to_json)
+      .to_return(status: 200, body: chunks)
+  end
+
+  it_behaves_like "an endpoint that can communicate with a completion service"
+end