diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
index 76a3b3b2..d8cb3074 100644
--- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
+++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
@@ -105,6 +105,22 @@ module DiscourseAi
                           status: 502
       end
 
+      def caption_image
+        image_url = params[:image_url]
+        raise Discourse::InvalidParameters.new(:image_url) if !image_url
+
+        image = Upload.where(url: params[:image_url])
+
+        hijack do
+          caption =
+            DiscourseAi::AiHelper::Assistant.new.generate_image_caption(image_url, current_user)
+          render json: { caption: caption }, status: 200
+        end
+      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed, Net::HTTPBadResponse
+        render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
+                          status: 502
+      end
+
       private
 
       def get_text_param!
diff --git a/assets/javascripts/discourse/connectors/editor-preview/ai-image-caption-container.gjs b/assets/javascripts/discourse/connectors/editor-preview/ai-image-caption-container.gjs
new file mode 100644
index 00000000..9d14061a
--- /dev/null
+++ b/assets/javascripts/discourse/connectors/editor-preview/ai-image-caption-container.gjs
@@ -0,0 +1,80 @@
+import Component from "@glimmer/component";
+import { tracked } from "@glimmer/tracking";
+import { fn } from "@ember/helper";
+import { on } from "@ember/modifier";
+import { action } from "@ember/object";
+import { inject as service } from "@ember/service";
+import ConditionalLoadingSpinner from "discourse/components/conditional-loading-spinner";
+import DButton from "discourse/components/d-button";
+import autoFocus from "discourse/modifiers/auto-focus";
+import icon from "discourse-common/helpers/d-icon";
+import i18n from "discourse-common/helpers/i18n";
+import { IMAGE_MARKDOWN_REGEX } from "../../lib/utilities";
+
+export default class AiImageCaptionContainer extends Component {
+  @service imageCaptionPopup;
+  @service appEvents;
+  @service composer;
+  @tracked newCaption = this.imageCaptionPopup.newCaption || "";
+
+  @action
+  updateCaption(event) {
+    event.preventDefault();
+    this.newCaption = event.target.value;
+  }
+
+  @action
+  saveCaption() {
+    const index = this.imageCaptionPopup.imageIndex;
+    const matchingPlaceholder =
+      this.composer.model.reply.match(IMAGE_MARKDOWN_REGEX);
+    const match = matchingPlaceholder[index];
+    const replacement = match.replace(
+      IMAGE_MARKDOWN_REGEX,
+      `![${this.newCaption}|$2$3$4]($5)`
+    );
+    this.appEvents.trigger("composer:replace-text", match, replacement);
+    this.imageCaptionPopup.showPopup = false;
+  }
+
+
+}
diff --git a/assets/javascripts/discourse/lib/utilities.js b/assets/javascripts/discourse/lib/utilities.js
new file mode 100644
index 00000000..bd407327
--- /dev/null
+++ b/assets/javascripts/discourse/lib/utilities.js
@@ -0,0 +1,2 @@
+export const IMAGE_MARKDOWN_REGEX =
+  /!\[(.*?)\|(\d{1,4}x\d{1,4})(,\s*\d{1,3}%)?(.*?)\]\((upload:\/\/.*?)\)(?!(.*`))/g;
diff --git a/assets/javascripts/discourse/services/image-caption-popup.js b/assets/javascripts/discourse/services/image-caption-popup.js
new file mode 100644
index 00000000..35b0efc4
--- /dev/null
+++ b/assets/javascripts/discourse/services/image-caption-popup.js
@@ -0,0 +1,10 @@
+import { tracked } from "@glimmer/tracking";
+import Service from "@ember/service";
+
+export default class ImageCaptionPopup extends Service {
+  @tracked showPopup = false;
+  @tracked imageIndex = null;
+  @tracked imageSrc = null;
+  @tracked newCaption = null;
+  @tracked loading = false;
+}
diff --git a/assets/javascripts/initializers/ai-image-caption.js b/assets/javascripts/initializers/ai-image-caption.js
new file mode 100644
index 00000000..a99884c3
--- /dev/null
+++ b/assets/javascripts/initializers/ai-image-caption.js
@@ -0,0 +1,53 @@
+import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
+import { apiInitializer } from "discourse/lib/api";
+import I18n from "discourse-i18n";
+
+export default apiInitializer("1.25.0", (api) => {
+  const buttonAttrs = {
+    label: I18n.t("discourse_ai.ai_helper.image_caption.button_label"),
+    icon: "discourse-sparkles",
+    class: "generate-caption",
+  };
+  const imageCaptionPopup = api.container.lookup("service:imageCaptionPopup");
+
+  api.addComposerImageWrapperButton(
+    buttonAttrs.label,
+    buttonAttrs.class,
+    buttonAttrs.icon,
+    (event) => {
+      if (event.target.classList.contains("generate-caption")) {
+        const buttonWrapper = event.target.closest(".button-wrapper");
+        const imageIndex = parseInt(
+          buttonWrapper.getAttribute("data-image-index"),
+          10
+        );
+        const imageSrc = event.target
+          .closest(".image-wrapper")
+          .querySelector("img")
+          .getAttribute("src");
+
+        imageCaptionPopup.loading = true;
+        imageCaptionPopup.showPopup = !imageCaptionPopup.showPopup;
+
+        ajax(`/discourse-ai/ai-helper/caption_image`, {
+          method: "POST",
+          data: {
+            image_url: imageSrc,
+          },
+        })
+          .then(({ caption }) => {
+            event.target.classList.add("disabled");
+            imageCaptionPopup.imageSrc = imageSrc;
+            imageCaptionPopup.imageIndex = imageIndex;
+            imageCaptionPopup.newCaption = caption;
+          })
+          .catch(popupAjaxError)
+          .finally(() => {
+            imageCaptionPopup.loading = false;
+            event.target.classList.remove("disabled");
+          });
+      }
+    }
+  );
+});
diff --git a/assets/stylesheets/modules/ai-helper/common/ai-helper.scss b/assets/stylesheets/modules/ai-helper/common/ai-helper.scss
index 93f219ea..90c4b0f5 100644
--- a/assets/stylesheets/modules/ai-helper/common/ai-helper.scss
+++ b/assets/stylesheets/modules/ai-helper/common/ai-helper.scss
@@ -457,3 +457,68 @@
     }
   }
 }
+
+// AI Image Caption Feature:
+.image-wrapper .button-wrapper {
+  .generate-caption {
+    background: var(--tertiary-100);
+    color: var(--tertiary-900);
+    font-weight: 600;
+    position: absolute;
+    top: -3.5rem;
+    left: 1rem;
+    padding: 0.5em;
+    transition: background 0.25s ease;
+
+    .d-icon {
+      margin-right: 0.25rem;
+    }
+
+    &:hover,
+    &:focus {
+      background: var(--tertiary-hover);
+      color: white;
+      cursor: pointer;
+    }
+
+    &.disabled {
+      pointer-events: none;
+      cursor: not-allowed;
+      opacity: 0.8;
+      &:hover,
+      &:focus {
+        background: var(--tertiary-100);
+        color: var(--tertiary-900);
+      }
+    }
+  }
+}
+
+.ai-caption-popup {
+  width: 400px;
+  right: unset;
+  bottom: 1.5rem;
+  top: unset;
+
+  textarea {
+    width: 95%;
+    height: 100px;
+  }
+
+  .actions {
+    display: flex;
+    align-items: center;
+    gap: 0.5rem;
+
+    .credits {
+      font-size: var(--font-down-1);
+      margin-left: auto;
+      color: var(--tertiary);
+    }
+  }
+
+  .spinner {
+    border-color: var(--tertiary-600);
+    border-right-color: var(--tertiary);
+  }
+}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index a2f1eb9d..c6bab5b4 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -186,6 +186,12 @@ en:
           title: "Suggested Thumbnails"
           select: "Select"
           selected: "Selected"
+        image_caption:
+          button_label: "Caption with AI"
+          generating: "Generating caption..."
+          credits: "Captioned by AI"
+          save_caption: "Save"
+
       reviewables:
         model_used: "Model used:"
         accuracy: "Accuracy:"
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 86288151..5eafe23a 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -59,6 +59,8 @@ en:
     ai_gemini_api_key: API key for Google Gemini API
     ai_vllm_endpoint: URL where the API is running for vLLM
     ai_vllm_api_key: API key for vLLM API
+    ai_llava_endpoint: URL where the API is running for llava
+    ai_llava_api_key: API key for llava API
 
     composer_ai_helper_enabled: "Enable the Composer's AI helper."
     ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
@@ -70,6 +72,7 @@ en:
     ai_helper_illustrate_post_model: "Model to use for the composer AI helper's illustrate post feature"
     ai_helper_enabled_features: "Select the features to enable in the AI helper."
     post_ai_helper_allowed_groups: "User groups allowed to access AI Helper features in posts"
+    ai_helper_image_caption_model: "Select the model to use for generating image captions"
 
     ai_embeddings_enabled: "Enable the embeddings module."
    ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
diff --git a/config/routes.rb b/config/routes.rb
index d91bc4d3..a360cd69 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -7,6 +7,7 @@ DiscourseAi::Engine.routes.draw do
     post "suggest_category" => "assistant#suggest_category"
     post "suggest_tags" => "assistant#suggest_tags"
     post "explain" => "assistant#explain"
+    post "caption_image" => "assistant#caption_image"
   end
 
   scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
diff --git a/config/settings.yml b/config/settings.yml
index 5465ae46..44e8bf0e 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -165,6 +165,12 @@ discourse_ai:
     default: ""
     hidden: true
   ai_vllm_api_key: ""
+  ai_llava_endpoint:
+    default: ""
+  ai_llava_endpoint_srv:
+    default: ""
+    hidden: true
+  ai_llava_api_key: ""
 
   composer_ai_helper_enabled:
     default: false
@@ -218,6 +224,13 @@ discourse_ai:
     choices:
       - "suggestions"
       - "context_menu"
+  ai_helper_image_caption_model:
+    default: "llava"
+    type: enum
+    choices:
+      - "llava"
+      - "open_ai:gpt-4-vision-preview"
+
 
   ai_embeddings_enabled:
     default: false
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index 8664709f..fc508cec 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -91,6 +91,42 @@ module DiscourseAi
         end
       end
 
+      def generate_image_caption(image_url, user)
+        if SiteSetting.ai_helper_image_caption_model == "llava"
+          parameters = {
+            input: {
+              image: image_url,
+              top_p: 1,
+              max_tokens: 1024,
+              temperature: 0.2,
+              prompt: "Please describe this image in a single sentence",
+            },
+          }
+
+          ::DiscourseAi::Inference::Llava.perform!(parameters).dig(:output).join
+        else
+          prompt =
+            DiscourseAi::Completions::Prompt.new(
+              messages: [
+                {
+                  type: :user,
+                  content: [
+                    { type: "text", text: "Describe this image in a single sentence" },
+                    { type: "image_url", image_url: image_url },
+                  ],
+                },
+              ],
+              skip_validations: true,
+            )
+
+          DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(
+            prompt,
+            user: Discourse.system_user,
+            max_tokens: 1024,
+          )
+        end
+      end
+
       private
 
       SANITIZE_REGEX_STR =
diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index c13f7166..b98c9461 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -13,6 +13,7 @@ module DiscourseAi
             gpt-4-32k
             gpt-4-0125-preview
             gpt-4-turbo
+            gpt-4-vision-preview
           ].include?(model_name)
         end
 
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 33382b0e..d0c2972d 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -15,6 +15,7 @@ module DiscourseAi
             gpt-4-32k
             gpt-4-0125-preview
             gpt-4-turbo
+            gpt-4-vision-preview
           ].include?(model_name)
         end
 
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 5bae7500..07b5d9bd 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -41,7 +41,14 @@ module DiscourseAi
             Llama2-*-chat-hf
             Llama2-chat-hf
           ],
-          open_ai: %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo],
+          open_ai: %w[
+            gpt-3.5-turbo
+            gpt-4
+            gpt-3.5-turbo-16k
+            gpt-4-32k
+            gpt-4-turbo
+            gpt-4-vision-preview
+          ],
           google: %w[gemini-pro],
         }.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
       end
diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb
index 5819055f..72f6b919 100644
--- a/lib/completions/prompt.rb
+++ b/lib/completions/prompt.rb
@@ -8,11 +8,12 @@ module DiscourseAi
     attr_reader :messages
     attr_accessor :tools
 
-    def initialize(system_message_text = nil, messages: [], tools: [])
+    def initialize(system_message_text = nil, messages: [], tools: [], skip_validations: false)
       raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
       raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
 
       @messages = []
+      @skip_validations = skip_validations
 
       if system_message_text
         system_message = { type: :system, content: system_message_text }
@@ -41,6 +42,7 @@ module DiscourseAi
     private
 
     def validate_message(message)
+      return if @skip_validations
       valid_types = %i[system user model tool tool_call]
       if !valid_types.include?(message[:type])
         raise ArgumentError, "message type must be one of #{valid_types}"
@@ -55,6 +57,7 @@ module DiscourseAi
     end
 
     def validate_turn(last_turn, new_turn)
+      return if @skip_validations
       valid_types = %i[tool tool_call model user]
       raise INVALID_TURN if !valid_types.include?(new_turn[:type])
 
diff --git a/lib/inference/llava.rb b/lib/inference/llava.rb
new file mode 100644
index 00000000..3ca8a341
--- /dev/null
+++ b/lib/inference/llava.rb
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+module ::DiscourseAi
+  module Inference
+    class Llava
+      def self.perform!(content)
+        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+        body = content.to_json
+
+        if SiteSetting.ai_llava_endpoint_srv.present?
+          service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_llava_endpoint_srv)
+          api_endpoint = "https://#{service.target}:#{service.port}"
+        else
+          api_endpoint = SiteSetting.ai_llava_endpoint
+        end
+
+        headers["X-API-KEY"] = SiteSetting.ai_llava_api_key if SiteSetting.ai_llava_api_key.present?
+
+        response = Faraday.post("#{api_endpoint}/predictions", body, headers)
+
+        raise Net::HTTPBadResponse if ![200].include?(response.status)
+
+        JSON.parse(response.body, symbolize_names: true)
+      end
+
+      def self.configured?
+        SiteSetting.ai_llava_endpoint.present? || SiteSetting.ai_llava_endpoint_srv.present?
+      end
+    end
+  end
+end
diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb
index 8b262856..0e28715c 100644
--- a/spec/requests/ai_helper/assistant_controller_spec.rb
+++ b/spec/requests/ai_helper/assistant_controller_spec.rb
@@ -107,4 +107,46 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
       end
     end
   end
+
+  describe "#caption_image" do
+    let(:image_url) { "https://example.com/image.jpg" }
+    let(:caption) { "A picture of a cat sitting on a table" }
+
+    context "when logged in as an allowed user" do
+      fab!(:user)
+
+      before do
+        sign_in(user)
+        user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
+        SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
+        SiteSetting.ai_llava_endpoint = "https://example.com"
+
+        stub_request(:post, "https://example.com/predictions").to_return(
+          status: 200,
+          body: { output: caption.gsub(" ", " |").split("|") }.to_json,
+        )
+      end
+
+      it "returns the suggested caption for the image" do
+        post "/discourse-ai/ai-helper/caption_image", params: { image_url: image_url }
+
+        expect(response.status).to eq(200)
+        expect(response.parsed_body["caption"]).to eq(caption)
+      end
+
+      it "returns a 502 error when the completion call fails" do
+        stub_request(:post, "https://example.com/predictions").to_return(status: 502)
+
+        post "/discourse-ai/ai-helper/caption_image", params: { image_url: image_url }
+
+        expect(response.status).to eq(502)
+      end
+
+      it "returns a 400 error when the image_url is blank" do
+        post "/discourse-ai/ai-helper/caption_image"
+
+        expect(response.status).to eq(400)
+      end
+    end
+  end
 end
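Reviewer note: a minimal sketch of how the new caption flow can be exercised end to end from a Rails console. It only uses names introduced in this diff; the endpoint and image URLs are placeholders (mirroring the request spec), and the console session itself is illustrative rather than part of the change.

# Select the llava backend and point it at a reachable service (placeholder URL).
SiteSetting.ai_helper_image_caption_model = "llava"
SiteSetting.ai_llava_endpoint = "https://example.com"

# Assistant#generate_image_caption builds the { input: { image:, prompt:, ... } } payload,
# POSTs it to "#{endpoint}/predictions" via DiscourseAi::Inference::Llava.perform!,
# and joins the :output array of the JSON response into a single caption string.
caption =
  DiscourseAi::AiHelper::Assistant.new.generate_image_caption(
    "https://example.com/image.jpg", # placeholder image URL
    Discourse.system_user,
  )
puts caption

With ai_helper_image_caption_model set to "open_ai:gpt-4-vision-preview", the same call instead goes through DiscourseAi::Completions::Llm.proxy with a prompt built using skip_validations: true, which is why Prompt#initialize gains that keyword in this diff.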