diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index f7cabefc..78e8567f 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -5,6 +5,7 @@ class AiApiAuditLog < ActiveRecord::Base
OpenAI = 1
Anthropic = 2
HuggingFaceTextGeneration = 3
+ Gemini = 4
end
end
diff --git a/config/settings.yml b/config/settings.yml
index ce62a1de..757f1871 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -147,6 +147,9 @@ discourse_ai:
ai_cloudflare_workers_api_token:
default: ""
secret: true
+ ai_gemini_api_key:
+ default: ""
+ secret: true
composer_ai_helper_enabled:
default: false
@@ -170,6 +173,7 @@ discourse_ai:
- claude-2
- stable-beluga-2
- Llama2-chat-hf
+ - gemini-pro
ai_helper_custom_prompts_allowed_groups:
client: true
type: group_list
@@ -233,6 +237,7 @@ discourse_ai:
- gpt-4
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
+ - gemini-pro
ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_key:
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
new file mode 100644
index 00000000..49fa716e
--- /dev/null
+++ b/lib/completions/dialects/gemini.rb
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Dialects
+ class Gemini
+ def self.can_translate?(model_name)
+ %w[gemini-pro].include?(model_name)
+ end
+
+ def translate(generic_prompt)
+ gemini_prompt = [
+ {
+ role: "user",
+ parts: {
+ text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
+ },
+ },
+ { role: "model", parts: { text: "Ok." } },
+ ]
+
+ if generic_prompt[:examples]
+ generic_prompt[:examples].each do |example_pair|
+ gemini_prompt << { role: "user", parts: { text: example_pair.first } }
+ gemini_prompt << { role: "model", parts: { text: example_pair.second } }
+ end
+ end
+
+ gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
+ end
+ end
+ end
+ end
+end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 43468c15..61846e23 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -15,6 +15,7 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace,
+ DiscourseAi::Completions::Endpoints::Gemini,
].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
ek.can_contact?(model_name)
end
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
new file mode 100644
index 00000000..7738085c
--- /dev/null
+++ b/lib/completions/endpoints/gemini.rb
@@ -0,0 +1,62 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Endpoints
+ class Gemini < Base
+ def self.can_contact?(model_name)
+ %w[gemini-pro].include?(model_name)
+ end
+
+ def default_options
+ {}
+ end
+
+ def provider_id
+ AiApiAuditLog::Provider::Gemini
+ end
+
+ private
+
+ def model_uri
+ url =
+ "https://generativelanguage.googleapis.com/v1beta/models/#{model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"
+
+ URI(url)
+ end
+
+ def prepare_payload(prompt, model_params)
+ default_options.merge(model_params).merge(contents: prompt)
+ end
+
+ def prepare_request(payload)
+ headers = { "Content-Type" => "application/json" }
+
+ Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
+ end
+
+ def extract_completion_from(response_raw)
+ if @streaming_mode
+ parsed = response_raw
+ else
+ parsed = JSON.parse(response_raw, symbolize_names: true)
+ end
+
+        dig_text(parsed).to_s
+ end
+
+ def partials_from(decoded_chunk)
+ JSON.parse(decoded_chunk, symbolize_names: true)
+ end
+
+ def extract_prompt_for_tokenizer(prompt)
+ prompt.to_s
+ end
+
+ def dig_text(response)
+ response.dig(:candidates, 0, :content, :parts, 0, :text)
+ end
+ end
+ end
+ end
+end
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index e6afd9da..f2cbc0e2 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -29,6 +29,7 @@ module DiscourseAi
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
+ DiscourseAi::Completions::Dialects::Gemini,
]
dialect =
diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb
index 66d359b3..bb553366 100644
--- a/lib/summarization/entry_point.rb
+++ b/lib/summarization/entry_point.rb
@@ -16,6 +16,7 @@ module DiscourseAi
"StableBeluga2",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
+ Models::Gemini.new("gemini-pro", max_tokens: 32_768),
]
foldable_models.each do |model|
diff --git a/lib/summarization/models/gemini.rb b/lib/summarization/models/gemini.rb
new file mode 100644
index 00000000..2f0550ac
--- /dev/null
+++ b/lib/summarization/models/gemini.rb
@@ -0,0 +1,25 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Summarization
+ module Models
+ class Gemini < Base
+ def display_name
+ "Google Gemini #{model}"
+ end
+
+ def correctly_configured?
+ SiteSetting.ai_gemini_api_key.present?
+ end
+
+ def configuration_hint
+ I18n.t(
+ "discourse_ai.summarization.configuration_hint",
+ count: 1,
+ setting: "ai_gemini_api_key",
+ )
+ end
+ end
+ end
+ end
+end
diff --git a/lib/summarization/strategies/fold_content.rb b/lib/summarization/strategies/fold_content.rb
index 16a15c05..dedc4209 100644
--- a/lib/summarization/strategies/fold_content.rb
+++ b/lib/summarization/strategies/fold_content.rb
@@ -151,9 +151,8 @@ module DiscourseAi
For example, a link to the 3rd post in the topic would be [post 3](#{opts[:resource_path]}/3)
TEXT
- insts += "The discussion title is: #{opts[:content_title]}.\n" if opts[:content_title]
-
prompt = { insts: insts, input: <<~TEXT }
+          #{opts[:content_title].present? ? "The discussion title is: #{opts[:content_title]}.\n" : ""}
Here are the posts, inside XML tags:
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
new file mode 100644
index 00000000..84aec2d3
--- /dev/null
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -0,0 +1,65 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
+ subject(:dialect) { described_class.new }
+
+ let(:prompt) do
+ {
+ insts: <<~TEXT,
+ I want you to act as a title generator for written pieces. I will provide you with a text,
+ and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
+ and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
+ TEXT
+ input: <<~TEXT,
+ Here is the text, inside XML tags:
+
+ To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
+ discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
+ defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.
+
+ Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
+ a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
+ slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
+ dies so that a scene may be repeated.
+
+ TEXT
+ post_insts:
+ "Please put the translation between tags and separate each title with a comma.",
+ }
+ end
+
+ describe "#translate" do
+ it "translates a prompt written in our generic format to the Gemini format" do
+ gemini_version = [
+ { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
+ { role: "model", parts: { text: "Ok." } },
+ { role: "user", parts: { text: prompt[:input] } },
+ ]
+
+ translated = dialect.translate(prompt)
+
+ expect(translated).to eq(gemini_version)
+ end
+
+ it "include examples in the Gemini version" do
+ prompt[:examples] = [
+ [
+ "In the labyrinth of time, a solitary horse, etched in gold by the setting sun, embarked on an infinite journey.",
+ "The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.",
+ ],
+ ]
+
+ gemini_version = [
+ { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
+ { role: "model", parts: { text: "Ok." } },
+ { role: "user", parts: { text: prompt[:examples][0][0] } },
+ { role: "model", parts: { text: prompt[:examples][0][1] } },
+ { role: "user", parts: { text: prompt[:input] } },
+ ]
+
+ translated = dialect.translate(prompt)
+
+      expect(translated).to eq(gemini_version)
+ end
+ end
+end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
new file mode 100644
index 00000000..a9431dde
--- /dev/null
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -0,0 +1,101 @@
+# frozen_string_literal: true
+
+require_relative "endpoint_examples"
+
+RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
+ subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
+
+ let(:model_name) { "gemini-pro" }
+ let(:prompt) do
+ [
+ { role: "system", content: "You are a helpful bot." },
+ { role: "user", content: "Write 3 words" },
+ ]
+ end
+
+ let(:request_body) { model.default_options.merge(contents: prompt).to_json }
+ let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
+
+ def response(content)
+ {
+ candidates: [
+ {
+ content: {
+ parts: [{ text: content }],
+ role: "model",
+ },
+ finishReason: "STOP",
+ index: 0,
+ safetyRatings: [
+ { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+ ],
+ },
+ ],
+ promptFeedback: {
+ safetyRatings: [
+ { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+ ],
+ },
+ }
+ end
+
+ def stub_response(prompt, response_text)
+ WebMock
+ .stub_request(
+ :post,
+ "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
+ )
+ .with(body: { contents: prompt })
+ .to_return(status: 200, body: JSON.dump(response(response_text)))
+ end
+
+ def stream_line(delta, finish_reason: nil)
+ {
+ candidates: [
+ {
+ content: {
+ parts: [{ text: delta }],
+ role: "model",
+ },
+ finishReason: finish_reason,
+ index: 0,
+ safetyRatings: [
+ { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+ { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+ ],
+ },
+ ],
+ }.to_json
+ end
+
+ def stub_streamed_response(prompt, deltas)
+ chunks =
+ deltas.each_with_index.map do |_, index|
+ if index == (deltas.length - 1)
+ stream_line(deltas[index], finish_reason: "STOP")
+ else
+ stream_line(deltas[index])
+ end
+ end
+
+ chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
+
+ WebMock
+ .stub_request(
+ :post,
+ "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
+ )
+ .with(body: model.default_options.merge(contents: prompt).to_json)
+ .to_return(status: 200, body: chunks)
+ end
+
+ it_behaves_like "an endpoint that can communicate with a completion service"
+end