FEATURE: Track if a model can do vision in the llm_models table (#725)

* FEATURE: Track if a model can do vision in the llm_models table

* Data migration
This commit is contained in:
Roman Rizzi 2024-07-24 16:29:47 -03:00 committed by GitHub
parent 06e239321b
commit 5c196bca89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 289 additions and 263 deletions

View File

@ -106,6 +106,7 @@ module DiscourseAi
:max_prompt_tokens, :max_prompt_tokens,
:api_key, :api_key,
:enabled_chat_bot, :enabled_chat_bot,
:vision_enabled,
) )
provider = updating ? updating.provider : permitted[:provider] provider = updating ? updating.provider : permitted[:provider]

View File

@ -124,4 +124,6 @@ end
# api_key :string # api_key :string
# user_id :integer # user_id :integer
# enabled_chat_bot :boolean default(FALSE), not null # enabled_chat_bot :boolean default(FALSE), not null
# provider_params :jsonb
# vision_enabled :boolean default(FALSE), not null
# #

View File

@ -13,7 +13,8 @@ class LlmModelSerializer < ApplicationSerializer
:url, :url,
:enabled_chat_bot, :enabled_chat_bot,
:shadowed_by_srv, :shadowed_by_srv,
:provider_params :provider_params,
:vision_enabled
has_one :user, serializer: BasicUserSerializer, embed: :object has_one :user, serializer: BasicUserSerializer, embed: :object

View File

@ -13,7 +13,8 @@ export default class AiLlm extends RestModel {
"url", "url",
"api_key", "api_key",
"enabled_chat_bot", "enabled_chat_bot",
"provider_params" "provider_params",
"vision_enabled"
); );
} }

View File

@ -267,6 +267,14 @@ export default class AiLlmEditorForm extends Component {
@content={{I18n.t "discourse_ai.llms.hints.max_prompt_tokens"}} @content={{I18n.t "discourse_ai.llms.hints.max_prompt_tokens"}}
/> />
</div> </div>
<div class="control-group ai-llm-editor__vision-enabled">
<Input @type="checkbox" @checked={{@model.vision_enabled}} />
<label>{{I18n.t "discourse_ai.llms.vision_enabled"}}</label>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.llms.hints.vision_enabled"}}
/>
</div>
<div class="control-group"> <div class="control-group">
<DToggleSwitch <DToggleSwitch
class="ai-llm-editor__enabled-chat-bot" class="ai-llm-editor__enabled-chat-bot"

View File

@ -41,4 +41,9 @@
display: flex; display: flex;
align-items: center; align-items: center;
} }
&__vision-enabled {
display: flex;
align-items: flex-start;
}
} }

View File

@ -228,6 +228,7 @@ en:
url: "URL of the service hosting the model" url: "URL of the service hosting the model"
api_key: "API Key of the service hosting the model" api_key: "API Key of the service hosting the model"
enabled_chat_bot: "Allow AI Bot" enabled_chat_bot: "Allow AI Bot"
vision_enabled: "Vision enabled"
ai_bot_user: "AI Bot User" ai_bot_user: "AI Bot User"
save: "Save" save: "Save"
edit: "Edit" edit: "Edit"
@ -252,6 +253,7 @@ en:
hints: hints:
max_prompt_tokens: "Max numbers of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window." max_prompt_tokens: "Max numbers of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window."
name: "We include this in the API call to specify which model we'll use." name: "We include this in the API call to specify which model we'll use."
vision_enabled: "If enabled, the AI will attempt to understand images. It depends on the model being used supporting vision. Supported by latest models from Anthropic, Google, and OpenAI."
providers: providers:
aws_bedrock: "AWS Bedrock" aws_bedrock: "AWS Bedrock"

View File

@ -189,10 +189,13 @@ discourse_ai:
ai_vllm_api_key: "" ai_vllm_api_key: ""
ai_llava_endpoint: ai_llava_endpoint:
default: "" default: ""
hidden: true
ai_llava_endpoint_srv: ai_llava_endpoint_srv:
default: "" default: ""
hidden: true hidden: true
ai_llava_api_key: "" ai_llava_api_key:
default: ""
hidden: true
ai_strict_token_counting: ai_strict_token_counting:
default: false default: false
hidden: true hidden: true
@ -254,7 +257,7 @@ discourse_ai:
- "context_menu" - "context_menu"
- "image_caption" - "image_caption"
ai_helper_image_caption_model: ai_helper_image_caption_model:
default: "llava" default: ""
type: enum type: enum
enum: "DiscourseAi::Configuration::LlmVisionEnumerator" enum: "DiscourseAi::Configuration::LlmVisionEnumerator"
ai_auto_image_caption_allowed_groups: ai_auto_image_caption_allowed_groups:

View File

@ -0,0 +1,6 @@
# frozen_string_literal: true
# Schema migration: adds a per-model flag so the plugin can track whether an
# LLM configured in llm_models supports vision (image understanding).
class LlmModelVisionEnabled < ActiveRecord::Migration[7.1]
  def change
    # Defaults to false and is NOT NULL, so every existing row is treated as
    # non-vision until the follow-up data migration flips known vision models.
    add_column :llm_models, :vision_enabled, :boolean, default: false, null: false
  end
end

View File

@ -0,0 +1,44 @@
# frozen_string_literal: true
# Data migration that backfills the new llm_models.vision_enabled flag and
# repoints the ai_helper_image_caption_model site setting at the LlmModel row
# (new "custom:<id>" format) instead of the legacy "provider:model" value.
class MigrateVisionLlms < ActiveRecord::Migration[7.1]
  def up
    # Known vision-capable model names at the time of this migration
    # (Anthropic Claude 3 family, OpenAI GPT-4 vision variants, Gemini 1.5).
    vision_models = %w[
      claude-3-sonnet
      claude-3-opus
      claude-3-haiku
      gpt-4-vision-preview
      gpt-4-turbo
      gpt-4o
      gemini-1.5-pro
      gemini-1.5-flash
    ]

    DB.exec(<<~SQL, names: vision_models)
      UPDATE llm_models
      SET vision_enabled = true
      WHERE name IN (:names)
    SQL

    current_value =
      DB.query_single(
        "SELECT value FROM site_settings WHERE name = :setting_name",
        setting_name: "ai_helper_image_caption_model",
      ).first

    # "llava" was the legacy non-LLM backend; it has no llm_models row, so
    # only migrate settings that reference an actual model name.
    if current_value && current_value != "llava"
      llm_model =
        DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: current_value).first

      if llm_model
        # Rewrite the setting to the new "custom:<id>" format understood by
        # DiscourseAi::Configuration::LlmVisionEnumerator.
        DB.exec(<<~SQL, new: "custom:#{llm_model}")
          UPDATE site_settings
          SET value = :new
          WHERE name = 'ai_helper_image_caption_model'
        SQL
      end
    end
  end

  def down
    raise ActiveRecord::IrreversibleMigration
  end
end

View File

@ -143,47 +143,26 @@ module DiscourseAi
end end
def generate_image_caption(upload, user) def generate_image_caption(upload, user)
if SiteSetting.ai_helper_image_caption_model == "llava" prompt =
image_base64 = DiscourseAi::Completions::Prompt.new(
DiscourseAi::Completions::UploadEncoder.encode( "You are a bot specializing in image captioning.",
upload_ids: [upload.id], messages: [
max_pixels: 1_048_576, {
).first[ type: :user,
:base64 content:
] "Describe this image in a single sentence#{custom_locale_instructions(user)}",
parameters = { upload_ids: [upload.id],
input: { },
image: "data:image/#{upload.extension};base64, #{image_base64}", ],
top_p: 1, skip_validations: true,
max_tokens: 1024,
temperature: 0.2,
prompt: "Please describe this image in a single sentence",
},
}
::DiscourseAi::Inference::Llava.perform!(parameters).dig(:output).join
else
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a bot specializing in image captioning.",
messages: [
{
type: :user,
content:
"Describe this image in a single sentence#{custom_locale_instructions(user)}",
upload_ids: [upload.id],
},
],
skip_validations: true,
)
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(
prompt,
user: user,
max_tokens: 1024,
feature_name: "image_caption",
) )
end
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(
prompt,
user: user,
max_tokens: 1024,
feature_name: "image_caption",
)
end end
private private

View File

@ -78,33 +78,25 @@ module DiscourseAi
end end
end end
user_message[:content] = inline_images(user_message[:content], msg) user_message[:content] = inline_images(user_message[:content], msg) if vision_support?
user_message user_message
end end
def inline_images(content, message) def inline_images(content, message)
if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo" || encoded_uploads = prompt.encoded_uploads(message)
model_name == "gpt-4o" return content if encoded_uploads.blank?
content = message[:content]
encoded_uploads = prompt.encoded_uploads(message)
if encoded_uploads.present?
new_content = []
new_content.concat(
encoded_uploads.map do |details|
{
type: "image_url",
image_url: {
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
},
}
end,
)
new_content << { type: "text", text: content }
content = new_content
end
end
content content_w_imgs =
encoded_uploads.reduce([]) do |memo, details|
memo << {
type: "image_url",
image_url: {
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
},
}
end
content_w_imgs << { type: "text", text: message[:content] }
end end
def per_message_overhead def per_message_overhead

View File

@ -109,34 +109,28 @@ module DiscourseAi
content = +"" content = +""
content << "#{msg[:id]}: " if msg[:id] content << "#{msg[:id]}: " if msg[:id]
content << msg[:content] content << msg[:content]
content = inline_images(content, msg) content = inline_images(content, msg) if vision_support?
{ role: "user", content: content } { role: "user", content: content }
end end
def inline_images(content, message) def inline_images(content, message)
if model_name.include?("claude-3") encoded_uploads = prompt.encoded_uploads(message)
encoded_uploads = prompt.encoded_uploads(message) return content if encoded_uploads.blank?
if encoded_uploads.present?
new_content = []
new_content.concat(
encoded_uploads.map do |details|
{
source: {
type: "base64",
data: details[:base64],
media_type: details[:mime_type],
},
type: "image",
}
end,
)
new_content << { type: "text", text: content }
content = new_content
end
end
content content_w_imgs =
encoded_uploads.reduce([]) do |memo, details|
memo << {
source: {
type: "base64",
data: details[:base64],
media_type: details[:mime_type],
},
type: "image",
}
end
content_w_imgs << { type: "text", text: content }
end end
end end
end end

View File

@ -56,6 +56,10 @@ module DiscourseAi
false false
end end
def vision_support?
llm_model&.vision_enabled?
end
def tools def tools
@tools ||= tools_dialect.translated_tools @tools ||= tools_dialect.translated_tools
end end

View File

@ -114,6 +114,8 @@ module DiscourseAi
if beta_api? if beta_api?
# support new format with multiple parts # support new format with multiple parts
result = { role: "user", parts: [{ text: msg[:content] }] } result = { role: "user", parts: [{ text: msg[:content] }] }
return result unless vision_support?
upload_parts = uploaded_parts(msg) upload_parts = uploaded_parts(msg)
result[:parts].concat(upload_parts) if upload_parts.present? result[:parts].concat(upload_parts) if upload_parts.present?
result result

View File

@ -47,7 +47,28 @@ module DiscourseAi
content << "#{msg[:id]}: " if msg[:id] content << "#{msg[:id]}: " if msg[:id]
content << msg[:content] content << msg[:content]
{ role: "user", content: content } message = { role: "user", content: content }
message[:content] = inline_images(message[:content], msg) if vision_support?
message
end
def inline_images(content, message)
encoded_uploads = prompt.encoded_uploads(message)
return content if encoded_uploads.blank?
content_w_imgs =
encoded_uploads.reduce([]) do |memo, details|
memo << {
type: "image_url",
image_url: {
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
},
}
end
content_w_imgs << { type: "text", text: message[:content] }
end end
end end
end end

View File

@ -35,6 +35,8 @@ module DiscourseAi
"The number of completions you requested exceed the number of canned responses" "The number of completions you requested exceed the number of canned responses"
end end
raise response if response.is_a?(StandardError)
@completions += 1 @completions += 1
if block_given? if block_given?
cancelled = false cancelled = false

View File

@ -89,15 +89,6 @@ module DiscourseAi
DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name) DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
end end
def vision_models_by_provider
@vision_models_by_provider ||= {
aws_bedrock: %w[claude-3-sonnet claude-3-opus claude-3-haiku],
anthropic: %w[claude-3-sonnet claude-3-opus claude-3-haiku],
open_ai: %w[gpt-4-vision-preview gpt-4-turbo gpt-4o],
google: %w[gemini-1.5-pro gemini-1.5-flash],
}
end
def models_by_provider def models_by_provider
# ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure. # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
# However, since they use the same URL/key settings, there's no reason to duplicate them. # However, since they use the same URL/key settings, there's no reason to duplicate them.

View File

@ -10,24 +10,15 @@ module DiscourseAi
end end
def self.values def self.values
begin values = DB.query_hash(<<~SQL).map(&:symbolize_keys)
result = SELECT display_name AS name, id AS value
DiscourseAi::Completions::Llm.vision_models_by_provider.flat_map do |provider, models| FROM llm_models
endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s) WHERE vision_enabled
SQL
models.map do |model_name| values.each { |value_h| value_h[:value] = "custom:#{value_h[:value]}" }
{ name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
end
end
result << { name: "Llava", value: "llava" } values
result
# TODO add support for LlmModel as well
# LlmModel.all.each do |model|
# llm_models << { name: model.display_name, value: "custom:#{model.id}" }
# end
end
end end
end end
end end

View File

@ -1,31 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAi
  module Inference
    # Thin client for the (legacy) self-hosted LLaVA captioning service.
    # The endpoint is resolved from site settings, either directly or via a
    # DNS SRV lookup when ai_llava_endpoint_srv is configured.
    class Llava
      # POSTs +content+ (serialized to JSON) to the service's /predictions
      # route and returns the parsed, symbolized response body.
      # Raises Net::HTTPBadResponse on any non-200 status.
      def self.perform!(content)
        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
        headers["X-API-KEY"] = SiteSetting.ai_llava_api_key if SiteSetting.ai_llava_api_key.present?

        response = Faraday.post("#{resolved_endpoint}/predictions", content.to_json, headers)
        raise Net::HTTPBadResponse unless response.status == 200

        JSON.parse(response.body, symbolize_names: true)
      end

      # True when either endpoint setting is filled in.
      def self.configured?
        SiteSetting.ai_llava_endpoint.present? || SiteSetting.ai_llava_endpoint_srv.present?
      end

      # Prefers SRV-based discovery over the plain endpoint setting.
      def self.resolved_endpoint
        return SiteSetting.ai_llava_endpoint if SiteSetting.ai_llava_endpoint_srv.blank?

        service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_llava_endpoint_srv)
        "https://#{service.target}:#{service.port}"
      end
      private_class_method :resolved_endpoint
    end
  end
end

View File

@ -2,7 +2,18 @@
require_relative "endpoint_compliance" require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") } let(:url) { "https://api.anthropic.com/v1/messages" }
fab!(:model) do
Fabricate(
:llm_model,
url: "https://api.anthropic.com/v1/messages",
name: "claude-3-opus",
provider: "anthropic",
api_key: "123",
vision_enabled: true,
)
end
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") } let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
let(:upload100x100) do let(:upload100x100) do
UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id) UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
@ -45,8 +56,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
prompt_with_tools prompt_with_tools
end end
before { SiteSetting.ai_anthropic_api_key = "123" }
it "does not eat spaces with tool calls" do it "does not eat spaces with tool calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus" SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
body = <<~STRING body = <<~STRING
@ -108,10 +117,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
result = +"" result = +""
body = body.scan(/.*\n/) body = body.scan(/.*\n/)
EndpointMock.with_chunk_array_support do EndpointMock.with_chunk_array_support do
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return( stub_request(:post, url).to_return(status: 200, body: body)
status: 200,
body: body,
)
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial| llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
result << partial result << partial
@ -161,7 +167,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
parsed_body = nil parsed_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with( stub_request(:post, url).with(
body: body:
proc do |req_body| proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true) parsed_body = JSON.parse(req_body, symbolize_names: true)
@ -244,7 +250,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}, },
}.to_json }.to_json
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(body: body) stub_request(:post, url).to_return(body: body)
result = proxy.generate(prompt, user: Discourse.system_user) result = proxy.generate(prompt, user: Discourse.system_user)
@ -314,7 +320,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
STRING STRING
requested_body = nil requested_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with( stub_request(:post, url).with(
body: body:
proc do |req_body| proc do |req_body|
requested_body = JSON.parse(req_body, symbolize_names: true) requested_body = JSON.parse(req_body, symbolize_names: true)
@ -351,7 +357,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
STRING STRING
parsed_body = nil parsed_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with( stub_request(:post, url).with(
body: body:
proc do |req_body| proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true) parsed_body = JSON.parse(req_body, symbolize_names: true)

View File

@ -130,6 +130,17 @@ end
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) } subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
fab!(:model) do
Fabricate(
:llm_model,
url: "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest",
name: "gemini-1.5-pro",
provider: "google",
api_key: "ABC",
vision_enabled: true,
)
end
fab!(:user) fab!(:user)
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") } let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
@ -144,8 +155,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
end end
it "Supports Vision API" do it "Supports Vision API" do
SiteSetting.ai_gemini_api_key = "ABC"
prompt = prompt =
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(
"You are image bot", "You are image bot",
@ -158,9 +167,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
req_body = nil req_body = nil
llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-pro") llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = url = "#{model.url}:generateContent?key=ABC"
"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=ABC"
stub_request(:post, url).with( stub_request(:post, url).with(
body: body:
@ -202,8 +210,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
end end
it "Can correctly handle streamed responses even if they are chunked badly" do it "Can correctly handle streamed responses even if they are chunked badly" do
SiteSetting.ai_gemini_api_key = "ABC"
data = +"" data = +""
data << "da|ta: |" data << "da|ta: |"
data << gemini_mock.response("Hello").to_json data << gemini_mock.response("Hello").to_json
@ -214,9 +220,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
split = data.split("|") split = data.split("|")
llm = DiscourseAi::Completions::Llm.proxy("google:gemini-1.5-flash") llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = url = "#{model.url}:streamGenerateContent?alt=sse&key=ABC"
"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:streamGenerateContent?alt=sse&key=ABC"
output = +"" output = +""
gemini_mock.with_chunk_array_support do gemini_mock.with_chunk_array_support do

View File

@ -258,7 +258,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
describe "image support" do describe "image support" do
it "can handle images" do it "can handle images" do
llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo") model = Fabricate(:llm_model, provider: "open_ai", vision_enabled: true)
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
prompt = prompt =
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(
"You are image bot", "You are image bot",

View File

@ -112,43 +112,40 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
"A picture of a cat sitting on a table (#{I18n.t("discourse_ai.ai_helper.image_caption.attribution")})" "A picture of a cat sitting on a table (#{I18n.t("discourse_ai.ai_helper.image_caption.attribution")})"
end end
before { assign_fake_provider_to(:ai_helper_image_caption_model) }
def request_caption(params)
DiscourseAi::Completions::Llm.with_prepared_responses([caption]) do
post "/discourse-ai/ai-helper/caption_image", params: params
yield(response)
end
end
context "when logged in as an allowed user" do context "when logged in as an allowed user" do
fab!(:user) { Fabricate(:user, refresh_auto_groups: true) } fab!(:user) { Fabricate(:user, refresh_auto_groups: true) }
before do before do
sign_in(user) sign_in(user)
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
SiteSetting.ai_llava_endpoint = "https://example.com"
stub_request(:post, "https://example.com/predictions").to_return( SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
status: 200,
body: { output: caption.gsub(" ", " |").split("|") }.to_json,
)
end end
it "returns the suggested caption for the image" do it "returns the suggested caption for the image" do
post "/discourse-ai/ai-helper/caption_image", request_caption({ image_url: image_url, image_url_type: "long_url" }) do |r|
params: { expect(r.status).to eq(200)
image_url: image_url, expect(r.parsed_body["caption"]).to eq(caption_with_attrs)
image_url_type: "long_url", end
}
expect(response.status).to eq(200)
expect(response.parsed_body["caption"]).to eq(caption_with_attrs)
end end
context "when the image_url is a short_url" do context "when the image_url is a short_url" do
let(:image_url) { upload.short_url } let(:image_url) { upload.short_url }
it "returns the suggested caption for the image" do it "returns the suggested caption for the image" do
post "/discourse-ai/ai-helper/caption_image", request_caption({ image_url: image_url, image_url_type: "short_url" }) do |r|
params: { expect(r.status).to eq(200)
image_url: image_url, expect(r.parsed_body["caption"]).to eq(caption_with_attrs)
image_url_type: "short_url", end
}
expect(response.status).to eq(200)
expect(response.parsed_body["caption"]).to eq(caption_with_attrs)
end end
end end
@ -156,27 +153,25 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
let(:image_url) { "#{Discourse.base_url}#{upload.short_path}" } let(:image_url) { "#{Discourse.base_url}#{upload.short_path}" }
it "returns the suggested caption for the image" do it "returns the suggested caption for the image" do
post "/discourse-ai/ai-helper/caption_image", request_caption({ image_url: image_url, image_url_type: "short_path" }) do |r|
params: { expect(r.status).to eq(200)
image_url: image_url, expect(r.parsed_body["caption"]).to eq(caption_with_attrs)
image_url_type: "short_path", end
}
expect(response.status).to eq(200)
expect(response.parsed_body["caption"]).to eq(caption_with_attrs)
end end
end end
it "returns a 502 error when the completion call fails" do it "returns a 502 error when the completion call fails" do
stub_request(:post, "https://example.com/predictions").to_return(status: 502) DiscourseAi::Completions::Llm.with_prepared_responses(
[DiscourseAi::Completions::Endpoints::Base::CompletionFailed.new],
) do
post "/discourse-ai/ai-helper/caption_image",
params: {
image_url: image_url,
image_url_type: "long_url",
}
post "/discourse-ai/ai-helper/caption_image", expect(response.status).to eq(502)
params: { end
image_url: image_url,
image_url_type: "long_url",
}
expect(response.status).to eq(502)
end end
it "returns a 400 error when the image_url is blank" do it "returns a 400 error when the image_url is blank" do
@ -211,9 +206,10 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
SiteSetting.provider = SiteSettings::DbProvider.new(SiteSetting) SiteSetting.provider = SiteSettings::DbProvider.new(SiteSetting)
setup_s3 setup_s3
stub_s3_store stub_s3_store
assign_fake_provider_to(:ai_helper_image_caption_model)
SiteSetting.secure_uploads = true SiteSetting.secure_uploads = true
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1] SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
SiteSetting.ai_llava_endpoint = "https://example.com"
Group.find(SiteSetting.ai_helper_allowed_groups_map.first).add(user) Group.find(SiteSetting.ai_helper_allowed_groups_map.first).add(user)
user.reload user.reload
@ -242,14 +238,11 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
it "returns a 200 message and caption if user can access the secure upload" do it "returns a 200 message and caption if user can access the secure upload" do
group.add(user) group.add(user)
post "/discourse-ai/ai-helper/caption_image",
params: {
image_url: image_url,
image_url_type: "long_url",
}
expect(response.status).to eq(200) request_caption({ image_url: image_url, image_url_type: "long_url" }) do |r|
expect(response.parsed_body["caption"]).to eq(caption_with_attrs) expect(r.status).to eq(200)
expect(r.parsed_body["caption"]).to eq(caption_with_attrs)
end
end end
context "if the input URL is for a secure upload but not on the secure-uploads path" do context "if the input URL is for a secure upload but not on the secure-uploads path" do
@ -257,13 +250,11 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
it "creates a signed URL properly and makes the caption" do it "creates a signed URL properly and makes the caption" do
group.add(user) group.add(user)
post "/discourse-ai/ai-helper/caption_image",
params: { request_caption({ image_url: image_url, image_url_type: "long_url" }) do |r|
image_url: image_url, expect(r.status).to eq(200)
image_url_type: "long_url", expect(r.parsed_body["caption"]).to eq(caption_with_attrs)
} end
expect(response.status).to eq(200)
expect(response.parsed_body["caption"]).to eq(caption_with_attrs)
end end
end end
end end

View File

@ -21,14 +21,9 @@ RSpec.describe "AI image caption", type: :system, js: true do
before do before do
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user) Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
assign_fake_provider_to(:ai_helper_model) assign_fake_provider_to(:ai_helper_model)
SiteSetting.ai_llava_endpoint = "https://example.com" assign_fake_provider_to(:ai_helper_image_caption_model)
SiteSetting.ai_helper_enabled_features = "image_caption" SiteSetting.ai_helper_enabled_features = "image_caption"
sign_in(user) sign_in(user)
stub_request(:post, "https://example.com/predictions").to_return(
status: 200,
body: { output: caption.gsub(" ", " |").split("|") }.to_json,
)
end end
shared_examples "shows no image caption button" do shared_examples "shows no image caption button" do
@ -53,35 +48,41 @@ RSpec.describe "AI image caption", type: :system, js: true do
context "when triggering caption with AI on desktop" do context "when triggering caption with AI on desktop" do
it "should show an image caption in an input field" do it "should show an image caption in an input field" do
visit("/latest") DiscourseAi::Completions::Llm.with_prepared_responses([caption]) do
page.find("#create-topic").click visit("/latest")
attach_file([file_path]) { composer.click_toolbar_button("upload") } page.find("#create-topic").click
popup.click_generate_caption attach_file([file_path]) { composer.click_toolbar_button("upload") }
expect(popup.has_caption_popup_value?(caption_with_attrs)).to eq(true) popup.click_generate_caption
popup.save_caption expect(popup.has_caption_popup_value?(caption_with_attrs)).to eq(true)
wait_for { page.find(".image-wrapper img")["alt"] == caption_with_attrs } popup.save_caption
expect(page.find(".image-wrapper img")["alt"]).to eq(caption_with_attrs) wait_for { page.find(".image-wrapper img")["alt"] == caption_with_attrs }
expect(page.find(".image-wrapper img")["alt"]).to eq(caption_with_attrs)
end
end end
it "should allow you to cancel a caption request" do it "should allow you to cancel a caption request" do
visit("/latest") DiscourseAi::Completions::Llm.with_prepared_responses([caption]) do
page.find("#create-topic").click visit("/latest")
attach_file([file_path]) { composer.click_toolbar_button("upload") } page.find("#create-topic").click
popup.click_generate_caption attach_file([file_path]) { composer.click_toolbar_button("upload") }
popup.cancel_caption popup.click_generate_caption
expect(popup).to have_no_disabled_generate_button popup.cancel_caption
expect(popup).to have_no_disabled_generate_button
end
end end
end end
context "when triggering caption with AI on mobile", mobile: true do context "when triggering caption with AI on mobile", mobile: true do
it "should show update the image alt text with the caption" do it "should show update the image alt text with the caption" do
visit("/latest") DiscourseAi::Completions::Llm.with_prepared_responses([caption]) do
page.find("#create-topic").click visit("/latest")
attach_file([file_path]) { page.find(".mobile-file-upload").click } page.find("#create-topic").click
page.find(".mobile-preview").click attach_file([file_path]) { page.find(".mobile-file-upload").click }
popup.click_generate_caption page.find(".mobile-preview").click
wait_for { page.find(".image-wrapper img")["alt"] == caption_with_attrs } popup.click_generate_caption
expect(page.find(".image-wrapper img")["alt"]).to eq(caption_with_attrs) wait_for { page.find(".image-wrapper img")["alt"] == caption_with_attrs }
expect(page.find(".image-wrapper img")["alt"]).to eq(caption_with_attrs)
end
end end
end end
@ -125,15 +126,17 @@ RSpec.describe "AI image caption", type: :system, js: true do
end end
it "should auto caption the existing images and update the preference when dialog is accepted" do it "should auto caption the existing images and update the preference when dialog is accepted" do
visit("/latest") DiscourseAi::Completions::Llm.with_prepared_responses([caption]) do
page.find("#create-topic").click visit("/latest")
attach_file([file_path]) { composer.click_toolbar_button("upload") } page.find("#create-topic").click
wait_for { composer.has_no_in_progress_uploads? } attach_file([file_path]) { composer.click_toolbar_button("upload") }
composer.fill_title("I love using Discourse! It is my favorite forum software") wait_for { composer.has_no_in_progress_uploads? }
composer.create composer.fill_title("I love using Discourse! It is my favorite forum software")
dialog.click_yes composer.create
wait_for(timeout: 100) { page.find("#post_1 .cooked img")["alt"] == caption_with_attrs } dialog.click_yes
expect(page.find("#post_1 .cooked img")["alt"]).to eq(caption_with_attrs) wait_for(timeout: 100) { page.find("#post_1 .cooked img")["alt"] == caption_with_attrs }
expect(page.find("#post_1 .cooked img")["alt"]).to eq(caption_with_attrs)
end
end end
end end
@ -142,14 +145,16 @@ RSpec.describe "AI image caption", type: :system, js: true do
skip "TODO: Fix auto_image_caption user option not present in testing environment?" do skip "TODO: Fix auto_image_caption user option not present in testing environment?" do
it "should auto caption the image after uploading" do it "should auto caption the image after uploading" do
visit("/latest") DiscourseAi::Completions::Llm.with_prepared_responses([caption]) do
page.find("#create-topic").click visit("/latest")
attach_file([Rails.root.join("spec/fixtures/images/logo.jpg")]) do page.find("#create-topic").click
composer.click_toolbar_button("upload") attach_file([Rails.root.join("spec/fixtures/images/logo.jpg")]) do
composer.click_toolbar_button("upload")
end
wait_for { composer.has_no_in_progress_uploads? }
wait_for { page.find(".image-wrapper img")["alt"] == caption_with_attrs }
expect(page.find(".image-wrapper img")["alt"]).to eq(caption_with_attrs)
end end
wait_for { composer.has_no_in_progress_uploads? }
wait_for { page.find(".image-wrapper img")["alt"] == caption_with_attrs }
expect(page.find(".image-wrapper img")["alt"]).to eq(caption_with_attrs)
end end
end end
end end