FEATURE: AI image caption (#470)
This PR adds a new feature that lets you generate captions for images in the composer using AI. Co-authored-by: Rafael Silva <xfalcox@gmail.com>
This commit is contained in:
parent
1f74a77e17
commit
a9b2d6a30a
|
@ -105,6 +105,22 @@ module DiscourseAi
|
|||
status: 502
|
||||
end
|
||||
|
||||
# POST caption_image
#
# Generates an AI caption for the image referenced by params[:image_url] and
# renders it as JSON ({ caption: "..." }) with status 200.
#
# Raises Discourse::InvalidParameters when image_url is missing or blank.
# Renders a 502 JSON error when the caption request fails with
# CompletionFailed or Net::HTTPBadResponse.
def caption_image
  image_url = params[:image_url]
  raise Discourse::InvalidParameters.new(:image_url) if image_url.blank?

  hijack do
    # Failures are handled inside the hijacked block because that is where the
    # slow model call runs, and therefore where these errors are raised.
    # NOTE(review): presumably `hijack` may run this block after the action
    # itself has returned, which a method-level `rescue` would not cover;
    # confirm against the Hijack implementation.
    begin
      caption =
        DiscourseAi::AiHelper::Assistant.new.generate_image_caption(image_url, current_user)
      render json: { caption: caption }, status: 200
    rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed, Net::HTTPBadResponse
      render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
                        status: 502
    end
  end
end
|
||||
|
||||
private
|
||||
|
||||
def get_text_param!
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
import Component from "@glimmer/component";
|
||||
import { tracked } from "@glimmer/tracking";
|
||||
import { fn } from "@ember/helper";
|
||||
import { on } from "@ember/modifier";
|
||||
import { action } from "@ember/object";
|
||||
import { inject as service } from "@ember/service";
|
||||
import ConditionalLoadingSpinner from "discourse/components/conditional-loading-spinner";
|
||||
import DButton from "discourse/components/d-button";
|
||||
import autoFocus from "discourse/modifiers/auto-focus";
|
||||
import icon from "discourse-common/helpers/d-icon";
|
||||
import i18n from "discourse-common/helpers/i18n";
|
||||
import { IMAGE_MARKDOWN_REGEX } from "../../lib/utilities";
|
||||
|
||||
export default class AiImageCaptionContainer extends Component {
|
||||
@service imageCaptionPopup;
|
||||
@service appEvents;
|
||||
@service composer;
|
||||
@tracked newCaption = this.imageCaptionPopup.newCaption || "";
|
||||
|
||||
@action
|
||||
updateCaption(event) {
|
||||
event.preventDefault();
|
||||
this.newCaption = event.target.value;
|
||||
}
|
||||
|
||||
@action
|
||||
saveCaption() {
|
||||
const index = this.imageCaptionPopup.imageIndex;
|
||||
const matchingPlaceholder =
|
||||
this.composer.model.reply.match(IMAGE_MARKDOWN_REGEX);
|
||||
const match = matchingPlaceholder[index];
|
||||
const replacement = match.replace(
|
||||
IMAGE_MARKDOWN_REGEX,
|
||||
`![${this.newCaption}|$2$3$4]($5)`
|
||||
);
|
||||
this.appEvents.trigger("composer:replace-text", match, replacement);
|
||||
this.imageCaptionPopup.showPopup = false;
|
||||
}
|
||||
|
||||
<template>
|
||||
{{#if this.imageCaptionPopup.showPopup}}
|
||||
<div class="composer-popup education-message ai-caption-popup">
|
||||
<DButton
|
||||
@class="btn-transparent close"
|
||||
@title="close"
|
||||
@action={{fn (mut this.imageCaptionPopup.showPopup) false}}
|
||||
@icon="times"
|
||||
/>
|
||||
|
||||
<ConditionalLoadingSpinner
|
||||
@condition={{this.imageCaptionPopup.loading}}
|
||||
>
|
||||
<textarea
|
||||
{{on "input" this.updateCaption}}
|
||||
{{autoFocus}}
|
||||
>{{this.newCaption}}</textarea>
|
||||
</ConditionalLoadingSpinner>
|
||||
|
||||
<div class="actions">
|
||||
<DButton
|
||||
class="btn-primary"
|
||||
@label="discourse_ai.ai_helper.image_caption.save_caption"
|
||||
@icon="check"
|
||||
@action={{this.saveCaption}}
|
||||
/>
|
||||
<DButton
|
||||
class="btn-flat"
|
||||
@label="cancel"
|
||||
@action={{fn (mut this.imageCaptionPopup.showPopup) false}}
|
||||
/>
|
||||
|
||||
<span class="credits">
|
||||
{{icon "discourse-sparkles"}}
|
||||
{{i18n "discourse_ai.ai_helper.image_caption.credits"}}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
{{/if}}
|
||||
</template>
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
// Matches a composer image in Markdown form, e.g.
// `![alt|690x388, 50%](upload://abc.png)`, globally across the input.
// Capture groups: 1 = alt text, 2 = dimensions (`WxH`), 3 = optional `, NN%`
// scale, 4 = any remaining text before `]`, 5 = the `upload://` URL.
// The trailing negative lookahead rejects an image that is followed by a
// backtick later on the same line (presumably to leave inline-code samples
// untouched; confirm).
export const IMAGE_MARKDOWN_REGEX =
|
||||
/!\[(.*?)\|(\d{1,4}x\d{1,4})(,\s*\d{1,3}%)?(.*?)\]\((upload:\/\/.*?)\)(?!(.*`))/g;
|
|
@ -0,0 +1,10 @@
|
|||
import { tracked } from "@glimmer/tracking";
|
||||
import Service from "@ember/service";
|
||||
|
||||
// Shared, tracked state for the AI image caption popup. The composer button
// handler writes these fields and the caption popup component reads them.
export default class ImageCaptionPopup extends Service {
|
||||
// Whether the caption popup is rendered.
@tracked showPopup = false;
|
||||
// Position of the captioned image among the image-markdown matches in the
// composer reply (taken from the button wrapper's `data-image-index`).
@tracked imageIndex = null;
|
||||
// `src` attribute of the image the caption was requested for.
@tracked imageSrc = null;
|
||||
// Caption text returned by the caption endpoint.
@tracked newCaption = null;
|
||||
// True while a caption request is in flight; drives the loading spinner.
@tracked loading = false;
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
import { ajax } from "discourse/lib/ajax";
|
||||
import { popupAjaxError } from "discourse/lib/ajax-error";
|
||||
import { apiInitializer } from "discourse/lib/api";
|
||||
import I18n from "discourse-i18n";
|
||||
|
||||
// Registers the "Caption with AI" button on composer preview images. Clicking
// it opens the caption popup in its loading state, requests a caption for the
// image from the server, and stores the result on the `imageCaptionPopup`
// service.
export default apiInitializer("1.25.0", (api) => {
  const buttonAttrs = {
    label: I18n.t("discourse_ai.ai_helper.image_caption.button_label"),
    icon: "discourse-sparkles",
    class: "generate-caption",
  };
  const imageCaptionPopup = api.container.lookup("service:imageCaptionPopup");

  api.addComposerImageWrapperButton(
    buttonAttrs.label,
    buttonAttrs.class,
    buttonAttrs.icon,
    (event) => {
      // `closest` also resolves clicks that land on a child of the button
      // (such as its icon), not only on the button element itself.
      const button = event.target.closest(".generate-caption");
      if (!button) {
        return;
      }

      const buttonWrapper = button.closest(".button-wrapper");
      const imageIndex = parseInt(
        buttonWrapper.getAttribute("data-image-index"),
        10
      );
      const imageSrc = button
        .closest(".image-wrapper")
        .querySelector("img")
        .getAttribute("src");

      imageCaptionPopup.loading = true;
      // Always open the popup. Toggling would hide an already-open popup
      // while a new caption is being requested.
      imageCaptionPopup.showPopup = true;
      // Disable the button for the duration of the request so it cannot be
      // triggered again until the request settles.
      button.classList.add("disabled");

      ajax(`/discourse-ai/ai-helper/caption_image`, {
        method: "POST",
        data: {
          image_url: imageSrc,
        },
      })
        .then(({ caption }) => {
          imageCaptionPopup.imageSrc = imageSrc;
          imageCaptionPopup.imageIndex = imageIndex;
          imageCaptionPopup.newCaption = caption;
        })
        .catch((error) => {
          // Nothing useful can be shown or saved after a failure, so close
          // the popup before reporting the error.
          imageCaptionPopup.showPopup = false;
          popupAjaxError(error);
        })
        .finally(() => {
          imageCaptionPopup.loading = false;
          button.classList.remove("disabled");
        });
    }
  );
});
|
|
@ -457,3 +457,68 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AI Image Caption Feature:
|
||||
.image-wrapper .button-wrapper {
|
||||
.generate-caption {
|
||||
background: var(--tertiary-100);
|
||||
color: var(--tertiary-900);
|
||||
font-weight: 600;
|
||||
position: absolute;
|
||||
top: -3.5rem;
|
||||
left: 1rem;
|
||||
padding: 0.5em;
|
||||
transition: background 0.25s ease;
|
||||
|
||||
.d-icon {
|
||||
margin-right: 0.25rem;
|
||||
}
|
||||
|
||||
&:hover,
|
||||
&:focus {
|
||||
background: var(--tertiary-hover);
|
||||
color: white;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
&.disabled {
|
||||
pointer-events: none;
|
||||
cursor: not-allowed;
|
||||
opacity: 0.8;
|
||||
&:hover,
|
||||
&:focus {
|
||||
background: var(--tertiary-100);
|
||||
color: var(--tertiary-900);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.ai-caption-popup {
|
||||
width: 400px;
|
||||
right: unset;
|
||||
bottom: 1.5rem;
|
||||
top: unset;
|
||||
|
||||
textarea {
|
||||
width: 95%;
|
||||
height: 100px;
|
||||
}
|
||||
|
||||
.actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
|
||||
.credits {
|
||||
font-size: var(--font-down-1);
|
||||
margin-left: auto;
|
||||
color: var(--tertiary);
|
||||
}
|
||||
}
|
||||
|
||||
.spinner {
|
||||
border-color: var(--tertiary-600);
|
||||
border-right-color: var(--tertiary);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -186,6 +186,12 @@ en:
|
|||
title: "Suggested Thumbnails"
|
||||
select: "Select"
|
||||
selected: "Selected"
|
||||
image_caption:
|
||||
button_label: "Caption with AI"
|
||||
generating: "Generating caption..."
|
||||
credits: "Captioned by AI"
|
||||
save_caption: "Save"
|
||||
|
||||
reviewables:
|
||||
model_used: "Model used:"
|
||||
accuracy: "Accuracy:"
|
||||
|
|
|
@ -59,6 +59,8 @@ en:
|
|||
ai_gemini_api_key: API key for Google Gemini API
|
||||
ai_vllm_endpoint: URL where the API is running for vLLM
|
||||
ai_vllm_api_key: API key for vLLM API
|
||||
ai_llava_endpoint: URL where the API is running for llava
|
||||
ai_llava_api_key: API key for llava API
|
||||
|
||||
composer_ai_helper_enabled: "Enable the Composer's AI helper."
|
||||
ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
||||
|
@ -70,6 +72,7 @@ en:
|
|||
ai_helper_illustrate_post_model: "Model to use for the composer AI helper's illustrate post feature"
|
||||
ai_helper_enabled_features: "Select the features to enable in the AI helper."
|
||||
post_ai_helper_allowed_groups: "User groups allowed to access AI Helper features in posts"
|
||||
ai_helper_image_caption_model: "Select the model to use for generating image captions"
|
||||
|
||||
ai_embeddings_enabled: "Enable the embeddings module."
|
||||
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
||||
|
|
|
@ -7,6 +7,7 @@ DiscourseAi::Engine.routes.draw do
|
|||
post "suggest_category" => "assistant#suggest_category"
|
||||
post "suggest_tags" => "assistant#suggest_tags"
|
||||
post "explain" => "assistant#explain"
|
||||
post "caption_image" => "assistant#caption_image"
|
||||
end
|
||||
|
||||
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
|
||||
|
|
|
@ -165,6 +165,12 @@ discourse_ai:
|
|||
default: ""
|
||||
hidden: true
|
||||
ai_vllm_api_key: ""
|
||||
ai_llava_endpoint:
|
||||
default: ""
|
||||
ai_llava_endpoint_srv:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_llava_api_key: ""
|
||||
|
||||
composer_ai_helper_enabled:
|
||||
default: false
|
||||
|
@ -218,6 +224,13 @@ discourse_ai:
|
|||
choices:
|
||||
- "suggestions"
|
||||
- "context_menu"
|
||||
ai_helper_image_caption_model:
|
||||
default: "llava"
|
||||
type: enum
|
||||
choices:
|
||||
- "llava"
|
||||
- "open_ai:gpt-4-vision-preview"
|
||||
|
||||
|
||||
ai_embeddings_enabled:
|
||||
default: false
|
||||
|
|
|
@ -91,6 +91,42 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
# Generates a one-sentence caption for the image at +image_url+ using the model
# selected by SiteSetting.ai_helper_image_caption_model.
#
# @param image_url [String] URL of the image to describe
# @param user [Object, nil] user the completion is requested for; used by the
#   non-llava branch, falling back to Discourse.system_user when nil
# @return the llava output fragments joined into a single string when the
#   setting is "llava"; otherwise whatever the Llm proxy's +generate+ returns
def generate_image_caption(image_url, user)
  caption_model = SiteSetting.ai_helper_image_caption_model

  if caption_model == "llava"
    parameters = {
      input: {
        image: image_url,
        top_p: 1,
        max_tokens: 1024,
        temperature: 0.2,
        prompt: "Please describe this image in a single sentence",
      },
    }

    ::DiscourseAi::Inference::Llava.perform!(parameters).dig(:output).join
  else
    prompt =
      DiscourseAi::Completions::Prompt.new(
        messages: [
          {
            type: :user,
            content: [
              { type: "text", text: "Describe this image in a single sentence" },
              { type: "image_url", image_url: image_url },
            ],
          },
        ],
        # The message content is an array rather than plain text, so the
        # prompt's own validations are skipped.
        skip_validations: true,
      )

    DiscourseAi::Completions::Llm.proxy(caption_model).generate(
      prompt,
      # Pass the requesting user through instead of always using the system
      # user; the argument was previously accepted but ignored.
      user: user || Discourse.system_user,
      max_tokens: 1024,
    )
  end
end
|
||||
|
||||
private
|
||||
|
||||
SANITIZE_REGEX_STR =
|
||||
|
|
|
@ -13,6 +13,7 @@ module DiscourseAi
|
|||
gpt-4-32k
|
||||
gpt-4-0125-preview
|
||||
gpt-4-turbo
|
||||
gpt-4-vision-preview
|
||||
].include?(model_name)
|
||||
end
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ module DiscourseAi
|
|||
gpt-4-32k
|
||||
gpt-4-0125-preview
|
||||
gpt-4-turbo
|
||||
gpt-4-vision-preview
|
||||
].include?(model_name)
|
||||
end
|
||||
|
||||
|
|
|
@ -41,7 +41,14 @@ module DiscourseAi
|
|||
Llama2-*-chat-hf
|
||||
Llama2-chat-hf
|
||||
],
|
||||
open_ai: %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo],
|
||||
open_ai: %w[
|
||||
gpt-3.5-turbo
|
||||
gpt-4
|
||||
gpt-3.5-turbo-16k
|
||||
gpt-4-32k
|
||||
gpt-4-turbo
|
||||
gpt-4-vision-preview
|
||||
],
|
||||
google: %w[gemini-pro],
|
||||
}.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
|
||||
end
|
||||
|
|
|
@ -8,11 +8,12 @@ module DiscourseAi
|
|||
attr_reader :messages
|
||||
attr_accessor :tools
|
||||
|
||||
def initialize(system_message_text = nil, messages: [], tools: [])
|
||||
def initialize(system_message_text = nil, messages: [], tools: [], skip_validations: false)
|
||||
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
|
||||
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
|
||||
|
||||
@messages = []
|
||||
@skip_validations = skip_validations
|
||||
|
||||
if system_message_text
|
||||
system_message = { type: :system, content: system_message_text }
|
||||
|
@ -41,6 +42,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def validate_message(message)
|
||||
return if @skip_validations
|
||||
valid_types = %i[system user model tool tool_call]
|
||||
if !valid_types.include?(message[:type])
|
||||
raise ArgumentError, "message type must be one of #{valid_types}"
|
||||
|
@ -55,6 +57,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def validate_turn(last_turn, new_turn)
|
||||
return if @skip_validations
|
||||
valid_types = %i[tool tool_call model user]
|
||||
raise INVALID_TURN if !valid_types.include?(new_turn[:type])
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
  module Inference
    # Minimal HTTP client for a llava prediction service.
    class Llava
      # Sends +content+ as a JSON body to "<endpoint>/predictions".
      #
      # The endpoint comes from the SRV record named by
      # SiteSetting.ai_llava_endpoint_srv when that setting is present,
      # otherwise from SiteSetting.ai_llava_endpoint. An "X-API-KEY" header is
      # added when SiteSetting.ai_llava_api_key is present.
      #
      # @param content [#to_json] request payload
      # @return [Hash, Array] parsed JSON response with symbolized keys
      # @raise [Net::HTTPBadResponse] when the response status is not 200
      def self.perform!(content)
        request_headers = {
          "Referer" => Discourse.base_url,
          "Content-Type" => "application/json",
        }
        payload = content.to_json

        srv_name = SiteSetting.ai_llava_endpoint_srv
        base_url =
          if srv_name.present?
            record = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_llava_endpoint_srv)
            "https://#{record.target}:#{record.port}"
          else
            SiteSetting.ai_llava_endpoint
          end

        if SiteSetting.ai_llava_api_key.present?
          request_headers["X-API-KEY"] = SiteSetting.ai_llava_api_key
        end

        response = Faraday.post("#{base_url}/predictions", payload, request_headers)
        raise Net::HTTPBadResponse unless response.status == 200

        JSON.parse(response.body, symbolize_names: true)
      end

      # True when either a direct endpoint or an SRV endpoint is configured.
      def self.configured?
        return true if SiteSetting.ai_llava_endpoint.present?

        SiteSetting.ai_llava_endpoint_srv.present?
      end
    end
  end
end
|
|
@ -107,4 +107,46 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "#caption_image" do
|
||||
let(:image_url) { "https://example.com/image.jpg" }
|
||||
let(:caption) { "A picture of a cat sitting on a table" }
|
||||
|
||||
context "when logged in as an allowed user" do
|
||||
fab!(:user)
|
||||
|
||||
before do
|
||||
sign_in(user)
|
||||
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
|
||||
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
|
||||
SiteSetting.ai_llava_endpoint = "https://example.com"
|
||||
|
||||
stub_request(:post, "https://example.com/predictions").to_return(
|
||||
status: 200,
|
||||
body: { output: caption.gsub(" ", " |").split("|") }.to_json,
|
||||
)
|
||||
end
|
||||
|
||||
it "returns the suggested caption for the image" do
|
||||
post "/discourse-ai/ai-helper/caption_image", params: { image_url: image_url }
|
||||
|
||||
expect(response.status).to eq(200)
|
||||
expect(response.parsed_body["caption"]).to eq(caption)
|
||||
end
|
||||
|
||||
it "returns a 502 error when the completion call fails" do
|
||||
stub_request(:post, "https://example.com/predictions").to_return(status: 502)
|
||||
|
||||
post "/discourse-ai/ai-helper/caption_image", params: { image_url: image_url }
|
||||
|
||||
expect(response.status).to eq(502)
|
||||
end
|
||||
|
||||
it "returns a 400 error when the image_url is blank" do
|
||||
post "/discourse-ai/ai-helper/caption_image"
|
||||
|
||||
expect(response.status).to eq(400)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue