FEATURE: Support for Gemini in AiHelper / Search / Summarization (#358)

Rafael dos Santos Silva 2023-12-15 14:32:01 -03:00 committed by GitHub
parent 831559662e
commit 83744bf192
11 changed files with 301 additions and 2 deletions


@@ -5,6 +5,7 @@ class AiApiAuditLog < ActiveRecord::Base
    OpenAI = 1
    Anthropic = 2
    HuggingFaceTextGeneration = 3
    Gemini = 4
  end
end
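
With the new enum value, requests served by Gemini get their own provider id in the AI API audit log. A minimal sketch of what that records (treating provider_id as a plain column on the model, which the endpoint code below relies on):

log = AiApiAuditLog.new(provider_id: AiApiAuditLog::Provider::Gemini)
log.provider_id # => 4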


@@ -147,6 +147,9 @@ discourse_ai:
  ai_cloudflare_workers_api_token:
    default: ""
    secret: true
  ai_gemini_api_key:
    default: ""
    secret: true
  composer_ai_helper_enabled:
    default: false
@@ -170,6 +173,7 @@ discourse_ai:
      - claude-2
      - stable-beluga-2
      - Llama2-chat-hf
      - gemini-pro
  ai_helper_custom_prompts_allowed_groups:
    client: true
    type: group_list
@@ -233,6 +237,7 @@ discourse_ai:
      - gpt-4
      - StableBeluga2
      - Upstage-Llama-2-*-instruct-v2
      - gemini-pro
  ai_summarization_discourse_service_api_endpoint: ""
  ai_summarization_discourse_service_api_key:
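
ai_gemini_api_key follows the same pattern as the other provider keys: a secret site setting, empty by default. For local testing it can be set from a Rails console (the value below is a placeholder, not a real key):

SiteSetting.ai_gemini_api_key = "my-gemini-api-key" # placeholder
SiteSetting.ai_gemini_api_key.present? # => true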


@@ -0,0 +1,38 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Gemini
        def self.can_translate?(model_name)
          %w[gemini-pro].include?(model_name)
        end

        def translate(generic_prompt)
          gemini_prompt = [
            {
              role: "user",
              parts: {
                text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
              },
            },
            { role: "model", parts: { text: "Ok." } },
          ]

          if generic_prompt[:examples]
            generic_prompt[:examples].each do |example_pair|
              gemini_prompt << { role: "user", parts: { text: example_pair.first } }
              gemini_prompt << { role: "model", parts: { text: example_pair.second } }
            end
          end

          gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
        end

        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
        end
      end
    end
  end
end
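
Worked example of the dialect above: for a minimal generic prompt, translate emits the alternating user/model turns Gemini expects, seeding the exchange with the instructions and a canned "Ok." acknowledgement from the model:

dialect = DiscourseAi::Completions::Dialects::Gemini.new
dialect.translate(insts: "You are a helpful bot.", input: "Write 3 words")
# => [
#      { role: "user", parts: { text: "You are a helpful bot.\n" } },
#      { role: "model", parts: { text: "Ok." } },
#      { role: "user", parts: { text: "Write 3 words" } },
#    ]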


@@ -15,6 +15,7 @@ module DiscourseAi
          DiscourseAi::Completions::Endpoints::Anthropic,
          DiscourseAi::Completions::Endpoints::OpenAi,
          DiscourseAi::Completions::Endpoints::HuggingFace,
          DiscourseAi::Completions::Endpoints::Gemini,
        ].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
          ek.can_contact?(model_name)
        end
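
Endpoint routing stays driven by can_contact?, so "gemini-pro" now resolves to the new endpoint while unrecognized model names keep raising UNKNOWN_MODEL:

DiscourseAi::Completions::Endpoints::Gemini.can_contact?("gemini-pro") # => true
DiscourseAi::Completions::Endpoints::Gemini.can_contact?("claude-2")   # => false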


@@ -0,0 +1,62 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class Gemini < Base
        def self.can_contact?(model_name)
          %w[gemini-pro].include?(model_name)
        end

        def default_options
          {}
        end

        def provider_id
          AiApiAuditLog::Provider::Gemini
        end

        private

        def model_uri
          url =
            "https://generativelanguage.googleapis.com/v1beta/models/#{model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"

          URI(url)
        end

        def prepare_payload(prompt, model_params)
          default_options.merge(model_params).merge(contents: prompt)
        end

        def prepare_request(payload)
          headers = { "Content-Type" => "application/json" }

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def extract_completion_from(response_raw)
          if @streaming_mode
            parsed = response_raw
          else
            parsed = JSON.parse(response_raw, symbolize_names: true)
          end

          dig_text(parsed).to_s
        end

        def partials_from(decoded_chunk)
          JSON.parse(decoded_chunk, symbolize_names: true)
        end

        def extract_prompt_for_tokenizer(prompt)
          prompt.to_s
        end

        def dig_text(response)
          response.dig(:candidates, 0, :content, :parts, 0, :text)
        end
      end
    end
  end
end
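
Putting the endpoint together: a non-streaming call POSTs the dialect output to the generateContent action, with the API key passed as a query parameter rather than a header. A sketch of the request it builds (the prompt array is illustrative):

prompt = [{ role: "user", parts: { text: "Write 3 words" } }]
payload = {}.merge(contents: prompt).to_json # default_options is empty for now
# POST https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=<ai_gemini_api_key>
# Content-Type: application/json, body = payload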


@@ -29,6 +29,7 @@ module DiscourseAi
        DiscourseAi::Completions::Dialects::Llama2Classic,
        DiscourseAi::Completions::Dialects::ChatGpt,
        DiscourseAi::Completions::Dialects::OrcaStyle,
        DiscourseAi::Completions::Dialects::Gemini,
      ]
      dialect =
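
Dialect lookup mirrors endpoint lookup, keyed off each dialect's can_translate?:

DiscourseAi::Completions::Dialects::Gemini.can_translate?("gemini-pro") # => true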


@@ -16,6 +16,7 @@ module DiscourseAi
          "StableBeluga2",
          max_tokens: SiteSetting.ai_hugging_face_token_limit,
        ),
        Models::Gemini.new("gemini-pro", max_tokens: 32_768),
      ]
      foldable_models.each do |model|


@@ -0,0 +1,25 @@
# frozen_string_literal: true

module DiscourseAi
  module Summarization
    module Models
      class Gemini < Base
        def display_name
          "Google Gemini #{model}"
        end

        def correctly_configured?
          SiteSetting.ai_gemini_api_key.present?
        end

        def configuration_hint
          I18n.t(
            "discourse_ai.summarization.configuration_hint",
            count: 1,
            setting: "ai_gemini_api_key",
          )
        end
      end
    end
  end
end
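
A short sketch of the model's behavior (assuming Models::Base exposes the constructor's model name via model, which display_name interpolates):

gemini = DiscourseAi::Summarization::Models::Gemini.new("gemini-pro", max_tokens: 32_768)
gemini.display_name          # => "Google Gemini gemini-pro"
gemini.correctly_configured? # => false until ai_gemini_api_key is set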


@@ -151,9 +151,8 @@ module DiscourseAi
          For example, a link to the 3rd post in the topic would be [post 3](#{opts[:resource_path]}/3)
        TEXT
        insts += "The discussion title is: #{opts[:content_title]}.\n" if opts[:content_title]
        prompt = { insts: insts, input: <<~TEXT }
          #{opts[:content_title].present? ? "The discussion title is: " + opts[:content_title] + ".\n" : ""}
          Here are the posts, inside <input></input> XML tags:
          <input>
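
The topic title now lands in the instructions instead of being interpolated into the input heredoc. With an assumed title, the added line behaves like this:

opts = { content_title: "Gemini support" }
insts = "Summarize the topic.\n" # abbreviated stand-in for the real instructions
insts += "The discussion title is: #{opts[:content_title]}.\n" if opts[:content_title]
insts # => "Summarize the topic.\nThe discussion title is: Gemini support.\n"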


@@ -0,0 +1,65 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
  subject(:dialect) { described_class.new }

  let(:prompt) do
    {
      insts: <<~TEXT,
        I want you to act as a title generator for written pieces. I will provide you with a text,
        and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
        and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
      TEXT
      input: <<~TEXT,
        Here is the text, inside <input></input> XML tags:
        <input>
        To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
        discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
        defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.

        Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
        a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
        slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
        dies so that a scene may be repeated.
        </input>
      TEXT
      post_insts:
        "Please put the translation between <ai></ai> tags and separate each title with a comma.",
    }
  end

  describe "#translate" do
    it "translates a prompt written in our generic format to the Gemini format" do
      gemini_version = [
        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
        { role: "model", parts: { text: "Ok." } },
        { role: "user", parts: { text: prompt[:input] } },
      ]

      translated = dialect.translate(prompt)

      expect(translated).to eq(gemini_version)
    end

    it "includes examples in the Gemini version" do
      prompt[:examples] = [
        [
          "<input>In the labyrinth of time, a solitary horse, etched in gold by the setting sun, embarked on an infinite journey.</input>",
          "<ai>The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.</ai>",
        ],
      ]

      gemini_version = [
        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
        { role: "model", parts: { text: "Ok." } },
        { role: "user", parts: { text: prompt[:examples][0][0] } },
        { role: "model", parts: { text: prompt[:examples][0][1] } },
        { role: "user", parts: { text: prompt[:input] } },
      ]

      translated = dialect.translate(prompt)

      expect(translated).to contain_exactly(*gemini_version)
    end
  end
end


@@ -0,0 +1,101 @@
# frozen_string_literal: true

require_relative "endpoint_examples"

RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }

  let(:model_name) { "gemini-pro" }
  let(:prompt) do
    [
      { role: "system", content: "You are a helpful bot." },
      { role: "user", content: "Write 3 words" },
    ]
  end

  let(:request_body) { model.default_options.merge(contents: prompt).to_json }
  let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }

  def response(content)
    {
      candidates: [
        {
          content: {
            parts: [{ text: content }],
            role: "model",
          },
          finishReason: "STOP",
          index: 0,
          safetyRatings: [
            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
          ],
        },
      ],
      promptFeedback: {
        safetyRatings: [
          { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
          { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
          { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
          { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
        ],
      },
    }
  end

  def stub_response(prompt, response_text)
    WebMock
      .stub_request(
        :post,
        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
      )
      .with(body: { contents: prompt })
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end

  def stream_line(delta, finish_reason: nil)
    {
      candidates: [
        {
          content: {
            parts: [{ text: delta }],
            role: "model",
          },
          finishReason: finish_reason,
          index: 0,
          safetyRatings: [
            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
          ],
        },
      ],
    }.to_json
  end

  def stub_streamed_response(prompt, deltas)
    chunks =
      deltas.each_with_index.map do |_, index|
        if index == (deltas.length - 1)
          stream_line(deltas[index], finish_reason: "STOP")
        else
          stream_line(deltas[index])
        end
      end

    chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")

    WebMock
      .stub_request(
        :post,
        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
      )
      .with(body: model.default_options.merge(contents: prompt).to_json)
      .to_return(status: 200, body: chunks)
  end

  it_behaves_like "an endpoint that can communicate with a completion service"
end
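
One detail worth calling out: streamGenerateContent answers with a single JSON array whose elements arrive incrementally, which is why the stub joins chunks with commas and wraps them in brackets. A self-contained sketch of that wire format:

require "json"

chunks =
  ["Hello", " world"].map do |text|
    { candidates: [{ content: { parts: [{ text: text }] } }] }.to_json
  end
body = chunks.join("\n,\n").prepend("[\n").concat("\n]")
JSON.parse(body).size # => 2, the body parses back into the individual chunks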