FEATURE: AI image caption (#470)

This PR adds a new feature where you can generate captions for images in the composer using AI.

---------

Co-authored-by: Rafael Silva <xfalcox@gmail.com>
This commit is contained in:
Keegan George 2024-02-19 09:56:28 -08:00 committed by GitHub
parent 1f74a77e17
commit a9b2d6a30a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 372 additions and 2 deletions

View File

@ -105,6 +105,22 @@ module DiscourseAi
status: 502
end
def caption_image
image_url = params[:image_url]
raise Discourse::InvalidParameters.new(:image_url) if !image_url
image = Upload.where(url: params[:image_url])
hijack do
caption =
DiscourseAi::AiHelper::Assistant.new.generate_image_caption(image_url, current_user)
render json: { caption: caption }, status: 200
end
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed, Net::HTTPBadResponse
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
end
private
def get_text_param!

View File

@ -0,0 +1,80 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { fn } from "@ember/helper";
import { on } from "@ember/modifier";
import { action } from "@ember/object";
import { inject as service } from "@ember/service";
import ConditionalLoadingSpinner from "discourse/components/conditional-loading-spinner";
import DButton from "discourse/components/d-button";
import autoFocus from "discourse/modifiers/auto-focus";
import icon from "discourse-common/helpers/d-icon";
import i18n from "discourse-common/helpers/i18n";
import { IMAGE_MARKDOWN_REGEX } from "../../lib/utilities";
export default class AiImageCaptionContainer extends Component {
@service imageCaptionPopup;
@service appEvents;
@service composer;
@tracked newCaption = this.imageCaptionPopup.newCaption || "";
@action
updateCaption(event) {
event.preventDefault();
this.newCaption = event.target.value;
}
@action
saveCaption() {
const index = this.imageCaptionPopup.imageIndex;
const matchingPlaceholder =
this.composer.model.reply.match(IMAGE_MARKDOWN_REGEX);
const match = matchingPlaceholder[index];
const replacement = match.replace(
IMAGE_MARKDOWN_REGEX,
`![${this.newCaption}|$2$3$4]($5)`
);
this.appEvents.trigger("composer:replace-text", match, replacement);
this.imageCaptionPopup.showPopup = false;
}
<template>
{{#if this.imageCaptionPopup.showPopup}}
<div class="composer-popup education-message ai-caption-popup">
<DButton
@class="btn-transparent close"
@title="close"
@action={{fn (mut this.imageCaptionPopup.showPopup) false}}
@icon="times"
/>
<ConditionalLoadingSpinner
@condition={{this.imageCaptionPopup.loading}}
>
<textarea
{{on "input" this.updateCaption}}
{{autoFocus}}
>{{this.newCaption}}</textarea>
</ConditionalLoadingSpinner>
<div class="actions">
<DButton
class="btn-primary"
@label="discourse_ai.ai_helper.image_caption.save_caption"
@icon="check"
@action={{this.saveCaption}}
/>
<DButton
class="btn-flat"
@label="cancel"
@action={{fn (mut this.imageCaptionPopup.showPopup) false}}
/>
<span class="credits">
{{icon "discourse-sparkles"}}
{{i18n "discourse_ai.ai_helper.image_caption.credits"}}
</span>
</div>
</div>
{{/if}}
</template>
}

View File

@ -0,0 +1,2 @@
export const IMAGE_MARKDOWN_REGEX =
/!\[(.*?)\|(\d{1,4}x\d{1,4})(,\s*\d{1,3}%)?(.*?)\]\((upload:\/\/.*?)\)(?!(.*`))/g;

View File

@ -0,0 +1,10 @@
import { tracked } from "@glimmer/tracking";
import Service from "@ember/service";
export default class ImageCaptionPopup extends Service {
@tracked showPopup = false;
@tracked imageIndex = null;
@tracked imageSrc = null;
@tracked newCaption = null;
@tracked loading = false;
}

View File

@ -0,0 +1,53 @@
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { apiInitializer } from "discourse/lib/api";
import I18n from "discourse-i18n";
export default apiInitializer("1.25.0", (api) => {
const buttonAttrs = {
label: I18n.t("discourse_ai.ai_helper.image_caption.button_label"),
icon: "discourse-sparkles",
class: "generate-caption",
};
const imageCaptionPopup = api.container.lookup("service:imageCaptionPopup");
api.addComposerImageWrapperButton(
buttonAttrs.label,
buttonAttrs.class,
buttonAttrs.icon,
(event) => {
if (event.target.classList.contains("generate-caption")) {
const buttonWrapper = event.target.closest(".button-wrapper");
const imageIndex = parseInt(
buttonWrapper.getAttribute("data-image-index"),
10
);
const imageSrc = event.target
.closest(".image-wrapper")
.querySelector("img")
.getAttribute("src");
imageCaptionPopup.loading = true;
imageCaptionPopup.showPopup = !imageCaptionPopup.showPopup;
ajax(`/discourse-ai/ai-helper/caption_image`, {
method: "POST",
data: {
image_url: imageSrc,
},
})
.then(({ caption }) => {
event.target.classList.add("disabled");
imageCaptionPopup.imageSrc = imageSrc;
imageCaptionPopup.imageIndex = imageIndex;
imageCaptionPopup.newCaption = caption;
})
.catch(popupAjaxError)
.finally(() => {
imageCaptionPopup.loading = false;
event.target.classList.remove("disabled");
});
}
}
);
});

View File

@ -457,3 +457,68 @@
}
}
}
// AI Image Caption Feature:
.image-wrapper .button-wrapper {
.generate-caption {
background: var(--tertiary-100);
color: var(--tertiary-900);
font-weight: 600;
position: absolute;
top: -3.5rem;
left: 1rem;
padding: 0.5em;
transition: background 0.25s ease;
.d-icon {
margin-right: 0.25rem;
}
&:hover,
&:focus {
background: var(--tertiary-hover);
color: white;
cursor: pointer;
}
&.disabled {
pointer-events: none;
cursor: not-allowed;
opacity: 0.8;
&:hover,
&:focus {
background: var(--tertiary-100);
color: var(--tertiary-900);
}
}
}
}
.ai-caption-popup {
width: 400px;
right: unset;
bottom: 1.5rem;
top: unset;
textarea {
width: 95%;
height: 100px;
}
.actions {
display: flex;
align-items: center;
gap: 0.5rem;
.credits {
font-size: var(--font-down-1);
margin-left: auto;
color: var(--tertiary);
}
}
.spinner {
border-color: var(--tertiary-600);
border-right-color: var(--tertiary);
}
}

View File

@ -186,6 +186,12 @@ en:
title: "Suggested Thumbnails"
select: "Select"
selected: "Selected"
image_caption:
button_label: "Caption with AI"
generating: "Generating caption..."
credits: "Captioned by AI"
save_caption: "Save"
reviewables:
model_used: "Model used:"
accuracy: "Accuracy:"

View File

@ -59,6 +59,8 @@ en:
ai_gemini_api_key: API key for Google Gemini API
ai_vllm_endpoint: URL where the API is running for vLLM
ai_vllm_api_key: API key for vLLM API
ai_llava_endpoint: URL where the API is running for llava
ai_llava_api_key: API key for llava API
composer_ai_helper_enabled: "Enable the Composer's AI helper."
ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
@ -70,6 +72,7 @@ en:
ai_helper_illustrate_post_model: "Model to use for the composer AI helper's illustrate post feature"
ai_helper_enabled_features: "Select the features to enable in the AI helper."
post_ai_helper_allowed_groups: "User groups allowed to access AI Helper features in posts"
ai_helper_image_caption_model: "Select the model to use for generating image captions"
ai_embeddings_enabled: "Enable the embeddings module."
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"

View File

@ -7,6 +7,7 @@ DiscourseAi::Engine.routes.draw do
post "suggest_category" => "assistant#suggest_category"
post "suggest_tags" => "assistant#suggest_tags"
post "explain" => "assistant#explain"
post "caption_image" => "assistant#caption_image"
end
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do

View File

@ -165,6 +165,12 @@ discourse_ai:
default: ""
hidden: true
ai_vllm_api_key: ""
ai_llava_endpoint:
default: ""
ai_llava_endpoint_srv:
default: ""
hidden: true
ai_llava_api_key: ""
composer_ai_helper_enabled:
default: false
@ -218,6 +224,13 @@ discourse_ai:
choices:
- "suggestions"
- "context_menu"
ai_helper_image_caption_model:
default: "llava"
type: enum
choices:
- "llava"
- "open_ai:gpt-4-vision-preview"
ai_embeddings_enabled:
default: false

View File

@ -91,6 +91,42 @@ module DiscourseAi
end
end
def generate_image_caption(image_url, user)
if SiteSetting.ai_helper_image_caption_model == "llava"
parameters = {
input: {
image: image_url,
top_p: 1,
max_tokens: 1024,
temperature: 0.2,
prompt: "Please describe this image in a single sentence",
},
}
::DiscourseAi::Inference::Llava.perform!(parameters).dig(:output).join
else
prompt =
DiscourseAi::Completions::Prompt.new(
messages: [
{
type: :user,
content: [
{ type: "text", text: "Describe this image in a single sentence" },
{ type: "image_url", image_url: image_url },
],
},
],
skip_validations: true,
)
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(
prompt,
user: Discourse.system_user,
max_tokens: 1024,
)
end
end
private
SANITIZE_REGEX_STR =

View File

@ -13,6 +13,7 @@ module DiscourseAi
gpt-4-32k
gpt-4-0125-preview
gpt-4-turbo
gpt-4-vision-preview
].include?(model_name)
end

View File

@ -15,6 +15,7 @@ module DiscourseAi
gpt-4-32k
gpt-4-0125-preview
gpt-4-turbo
gpt-4-vision-preview
].include?(model_name)
end

View File

@ -41,7 +41,14 @@ module DiscourseAi
Llama2-*-chat-hf
Llama2-chat-hf
],
open_ai: %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo],
open_ai: %w[
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-turbo
gpt-4-vision-preview
],
google: %w[gemini-pro],
}.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
end

View File

@ -8,11 +8,12 @@ module DiscourseAi
attr_reader :messages
attr_accessor :tools
def initialize(system_message_text = nil, messages: [], tools: [])
def initialize(system_message_text = nil, messages: [], tools: [], skip_validations: false)
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
@messages = []
@skip_validations = skip_validations
if system_message_text
system_message = { type: :system, content: system_message_text }
@ -41,6 +42,7 @@ module DiscourseAi
private
def validate_message(message)
return if @skip_validations
valid_types = %i[system user model tool tool_call]
if !valid_types.include?(message[:type])
raise ArgumentError, "message type must be one of #{valid_types}"
@ -55,6 +57,7 @@ module DiscourseAi
end
def validate_turn(last_turn, new_turn)
return if @skip_validations
valid_types = %i[tool tool_call model user]
raise INVALID_TURN if !valid_types.include?(new_turn[:type])

31
lib/inference/llava.rb Normal file
View File

@ -0,0 +1,31 @@
# frozen_string_literal: true
module ::DiscourseAi
module Inference
class Llava
def self.perform!(content)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
body = content.to_json
if SiteSetting.ai_llava_endpoint_srv.present?
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_llava_endpoint_srv)
api_endpoint = "https://#{service.target}:#{service.port}"
else
api_endpoint = SiteSetting.ai_llava_endpoint
end
headers["X-API-KEY"] = SiteSetting.ai_llava_api_key if SiteSetting.ai_llava_api_key.present?
response = Faraday.post("#{api_endpoint}/predictions", body, headers)
raise Net::HTTPBadResponse if ![200].include?(response.status)
JSON.parse(response.body, symbolize_names: true)
end
def self.configured?
SiteSetting.ai_llava_endpoint.present? || SiteSetting.ai_llava_endpoint_srv.present?
end
end
end
end

View File

@ -107,4 +107,46 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
end
end
end
describe "#caption_image" do
let(:image_url) { "https://example.com/image.jpg" }
let(:caption) { "A picture of a cat sitting on a table" }
context "when logged in as an allowed user" do
fab!(:user)
before do
sign_in(user)
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
SiteSetting.ai_llava_endpoint = "https://example.com"
stub_request(:post, "https://example.com/predictions").to_return(
status: 200,
body: { output: caption.gsub(" ", " |").split("|") }.to_json,
)
end
it "returns the suggested caption for the image" do
post "/discourse-ai/ai-helper/caption_image", params: { image_url: image_url }
expect(response.status).to eq(200)
expect(response.parsed_body["caption"]).to eq(caption)
end
it "returns a 502 error when the completion call fails" do
stub_request(:post, "https://example.com/predictions").to_return(status: 502)
post "/discourse-ai/ai-helper/caption_image", params: { image_url: image_url }
expect(response.status).to eq(502)
end
it "returns a 400 error when the image_url is blank" do
post "/discourse-ai/ai-helper/caption_image"
expect(response.status).to eq(400)
end
end
end
end