diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index f7cabefc..78e8567f 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -5,6 +5,7 @@ class AiApiAuditLog < ActiveRecord::Base
     OpenAI = 1
     Anthropic = 2
     HuggingFaceTextGeneration = 3
+    Gemini = 4
   end
 end
 
diff --git a/config/settings.yml b/config/settings.yml
index ce62a1de..757f1871 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -147,6 +147,9 @@ discourse_ai:
   ai_cloudflare_workers_api_token:
     default: ""
     secret: true
+  ai_gemini_api_key:
+    default: ""
+    secret: true
 
   composer_ai_helper_enabled:
     default: false
@@ -170,6 +173,7 @@ discourse_ai:
       - claude-2
       - stable-beluga-2
       - Llama2-chat-hf
+      - gemini-pro
   ai_helper_custom_prompts_allowed_groups:
     client: true
     type: group_list
@@ -233,6 +237,7 @@ discourse_ai:
       - gpt-4
       - StableBeluga2
       - Upstage-Llama-2-*-instruct-v2
+      - gemini-pro
 
   ai_summarization_discourse_service_api_endpoint: ""
   ai_summarization_discourse_service_api_key:
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
new file mode 100644
index 00000000..49fa716e
--- /dev/null
+++ b/lib/completions/dialects/gemini.rb
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Gemini
+        def self.can_translate?(model_name)
+          %w[gemini-pro].include?(model_name)
+        end
+
+        def translate(generic_prompt)
+          gemini_prompt = [
+            {
+              role: "user",
+              parts: {
+                text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
+              },
+            },
+            { role: "model", parts: { text: "Ok." } },
+          ]
+
+          if generic_prompt[:examples]
+            generic_prompt[:examples].each do |example_pair|
+              gemini_prompt << { role: "user", parts: { text: example_pair.first } }
+              gemini_prompt << { role: "model", parts: { text: example_pair.second } }
+            end
+          end
+
+          gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
+        end
+
+        def tokenizer
+          DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
+        end
+      end
+    end
+  end
+end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 43468c15..61846e23 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -15,6 +15,7 @@ module DiscourseAi
           DiscourseAi::Completions::Endpoints::Anthropic,
           DiscourseAi::Completions::Endpoints::OpenAi,
           DiscourseAi::Completions::Endpoints::HuggingFace,
+          DiscourseAi::Completions::Endpoints::Gemini,
         ].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
           ek.can_contact?(model_name)
         end
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
new file mode 100644
index 00000000..7738085c
--- /dev/null
+++ b/lib/completions/endpoints/gemini.rb
@@ -0,0 +1,62 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Endpoints
+      class Gemini < Base
+        def self.can_contact?(model_name)
+          %w[gemini-pro].include?(model_name)
+        end
+
+        def default_options
+          {}
+        end
+
+        def provider_id
+          AiApiAuditLog::Provider::Gemini
+        end
+
+        private
+
+        def model_uri
+          url =
+            "https://generativelanguage.googleapis.com/v1beta/models/#{model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"
+
+          URI(url)
+        end
+
+        def prepare_payload(prompt, model_params)
+          default_options.merge(model_params).merge(contents: prompt)
+        end
+
+        def prepare_request(payload)
+          headers = { "Content-Type" => "application/json" }
+
+          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
+        end
+
+        def extract_completion_from(response_raw)
+          if @streaming_mode
+            parsed = response_raw
+          else
+            parsed = JSON.parse(response_raw, symbolize_names: true)
+          end
+
+          dig_text(parsed).to_s
+        end
+
+        def partials_from(decoded_chunk)
+          JSON.parse(decoded_chunk, symbolize_names: true)
+        end
+
+        def extract_prompt_for_tokenizer(prompt)
+          prompt.to_s
+        end
+
+        def dig_text(response)
+          response.dig(:candidates, 0, :content, :parts, 0, :text)
+        end
+      end
+    end
+  end
+end
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index e6afd9da..f2cbc0e2 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -29,6 +29,7 @@ module DiscourseAi
           DiscourseAi::Completions::Dialects::Llama2Classic,
           DiscourseAi::Completions::Dialects::ChatGpt,
           DiscourseAi::Completions::Dialects::OrcaStyle,
+          DiscourseAi::Completions::Dialects::Gemini,
         ]
 
         dialect =
diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb
index 66d359b3..bb553366 100644
--- a/lib/summarization/entry_point.rb
+++ b/lib/summarization/entry_point.rb
@@ -16,6 +16,7 @@ module DiscourseAi
             "StableBeluga2",
             max_tokens: SiteSetting.ai_hugging_face_token_limit,
           ),
+          Models::Gemini.new("gemini-pro", max_tokens: 32_768),
         ]
 
         foldable_models.each do |model|
diff --git a/lib/summarization/models/gemini.rb b/lib/summarization/models/gemini.rb
new file mode 100644
index 00000000..2f0550ac
--- /dev/null
+++ b/lib/summarization/models/gemini.rb
@@ -0,0 +1,25 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Summarization
+    module Models
+      class Gemini < Base
+        def display_name
+          "Google Gemini #{model}"
+        end
+
+        def correctly_configured?
+          SiteSetting.ai_gemini_api_key.present?
+        end
+
+        def configuration_hint
+          I18n.t(
+            "discourse_ai.summarization.configuration_hint",
+            count: 1,
+            setting: "ai_gemini_api_key",
+          )
+        end
+      end
+    end
+  end
+end
diff --git a/lib/summarization/strategies/fold_content.rb b/lib/summarization/strategies/fold_content.rb
index 16a15c05..dedc4209 100644
--- a/lib/summarization/strategies/fold_content.rb
+++ b/lib/summarization/strategies/fold_content.rb
@@ -151,9 +151,8 @@ module DiscourseAi
           For example, a link to the 3rd post in the topic would be [post 3](#{opts[:resource_path]}/3)
         TEXT
 
-        insts += "The discussion title is: #{opts[:content_title]}.\n" if opts[:content_title]
-
         prompt = { insts: insts, input: <<~TEXT }
+          #{opts[:content_title].present? ? "The discussion title is: " + opts[:content_title] + ".\n" : ""}
           Here are the posts, inside <input></input> XML tags:

          <input>
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
new file mode 100644
index 00000000..84aec2d3
--- /dev/null
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -0,0 +1,65 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
+  subject(:dialect) { described_class.new }
+
+  let(:prompt) do
+    {
+      insts: <<~TEXT,
+        I want you to act as a title generator for written pieces. I will provide you with a text,
+        and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
+        and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
+      TEXT
+      input: <<~TEXT,
+        Here is the text, inside <input></input> XML tags:
+        <input>
+        To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
+        discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
+        defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.
+
+        Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
+        a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
+        slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
+        dies so that a scene may be repeated.
+        </input>
+      TEXT
+      post_insts:
+        "Please put the translation between <ai></ai> tags and separate each title with a comma.",
+    }
+  end
+
+  describe "#translate" do
+    it "translates a prompt written in our generic format to the Gemini format" do
+      gemini_version = [
+        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
+        { role: "model", parts: { text: "Ok." } },
+        { role: "user", parts: { text: prompt[:input] } },
+      ]
+
+      translated = dialect.translate(prompt)
+
+      expect(translated).to eq(gemini_version)
+    end
+
+    it "include examples in the Gemini version" do
+      prompt[:examples] = [
+        [
+          "In the labyrinth of time, a solitary horse, etched in gold by the setting sun, embarked on an infinite journey.",
+          "The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.",
+        ],
+      ]
+
+      gemini_version = [
+        { role: "user", parts: { text: [prompt[:insts], prompt[:post_insts]].join("\n") } },
+        { role: "model", parts: { text: "Ok." } },
+        { role: "user", parts: { text: prompt[:examples][0][0] } },
+        { role: "model", parts: { text: prompt[:examples][0][1] } },
+        { role: "user", parts: { text: prompt[:input] } },
+      ]
+
+      translated = dialect.translate(prompt)
+
+      expect(translated).to contain_exactly(*gemini_version)
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
new file mode 100644
index 00000000..a9431dde
--- /dev/null
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -0,0 +1,101 @@
+# frozen_string_literal: true
+
+require_relative "endpoint_examples"
+
+RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
+  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
+
+  let(:model_name) { "gemini-pro" }
+  let(:prompt) do
+    [
+      { role: "system", content: "You are a helpful bot." },
+      { role: "user", content: "Write 3 words" },
+    ]
+  end
+
+  let(:request_body) { model.default_options.merge(contents: prompt).to_json }
+  let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
+
+  def response(content)
+    {
+      candidates: [
+        {
+          content: {
+            parts: [{ text: content }],
+            role: "model",
+          },
+          finishReason: "STOP",
+          index: 0,
+          safetyRatings: [
+            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+          ],
+        },
+      ],
+      promptFeedback: {
+        safetyRatings: [
+          { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+          { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+          { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+          { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+        ],
+      },
+    }
+  end
+
+  def stub_response(prompt, response_text)
+    WebMock
+      .stub_request(
+        :post,
+        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
+      )
+      .with(body: { contents: prompt })
+      .to_return(status: 200, body: JSON.dump(response(response_text)))
+  end
+
+  def stream_line(delta, finish_reason: nil)
+    {
+      candidates: [
+        {
+          content: {
+            parts: [{ text: delta }],
+            role: "model",
+          },
+          finishReason: finish_reason,
+          index: 0,
+          safetyRatings: [
+            { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+            { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+          ],
+        },
+      ],
+    }.to_json
+  end
+
+  def stub_streamed_response(prompt, deltas)
+    chunks =
+      deltas.each_with_index.map do |_, index|
+        if index == (deltas.length - 1)
+          stream_line(deltas[index], finish_reason: "STOP")
+        else
+          stream_line(deltas[index])
+        end
+      end
+
+    chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
+
+    WebMock
+      .stub_request(
+        :post,
+        "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
+      )
+      .with(body: model.default_options.merge(contents: prompt).to_json)
+      .to_return(status: 200, body: chunks)
+  end
+
+  it_behaves_like "an endpoint that can communicate with a completion service"
+end