FEATURE: Add DallE support to AI helper's illustrate post (#404)
This commit is contained in:
parent
23b2809638
commit
7201d482d5
|
@ -59,6 +59,7 @@ en:
|
|||
ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
|
||||
ai_helper_automatic_chat_thread_title_delay: "Delay in minutes before the AI helper automatically sets the chat thread title."
|
||||
ai_helper_automatic_chat_thread_title: "Automatically set the chat thread titles based on thread contents."
|
||||
ai_helper_illustrate_post_model: "Model to use for the composer AI helper's illustrate post feature"
|
||||
|
||||
ai_embeddings_enabled: "Enable the embeddings module."
|
||||
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
||||
|
@ -129,7 +130,9 @@ en:
|
|||
explain: "Explain"
|
||||
illustrate_post: "Illustrate Post"
|
||||
painter:
|
||||
attribution: "Image by Stable Diffusion XL"
|
||||
attribution:
|
||||
stable_diffusion_xl: "Image by Stable Diffusion XL"
|
||||
dall_e_3: "Image by DALL-E 3"
|
||||
|
||||
ai_bot:
|
||||
placeholder_reply: "I will reply shortly..."
|
||||
|
|
|
@ -199,6 +199,13 @@ discourse_ai:
|
|||
default: false
|
||||
ai_helper_automatic_chat_thread_title_delay:
|
||||
default: 5
|
||||
ai_helper_illustrate_post_model:
|
||||
default: disabled
|
||||
type: enum
|
||||
choices:
|
||||
- stable_diffusion_xl
|
||||
- dall_e_3
|
||||
- disabled
|
||||
|
||||
ai_embeddings_enabled:
|
||||
default: false
|
||||
|
|
|
@ -11,9 +11,11 @@ module DiscourseAi
|
|||
prompts = [cp.enabled_by_name(name_filter)]
|
||||
else
|
||||
prompts = cp.where(enabled: true)
|
||||
# Only show the illustrate_post prompt if the API key is present
|
||||
# Hide illustrate_post if disabled
|
||||
prompts =
|
||||
prompts.where.not(name: "illustrate_post") if !SiteSetting.ai_stability_api_key.present?
|
||||
prompts.where.not(
|
||||
name: "illustrate_post",
|
||||
) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
|
||||
end
|
||||
|
||||
prompts.map do |prompt|
|
||||
|
|
|
@ -3,35 +3,59 @@
|
|||
module DiscourseAi
|
||||
module AiHelper
|
||||
class Painter
|
||||
def commission_thumbnails(theme, user)
|
||||
stable_diffusion_prompt = difussion_prompt(theme, user)
|
||||
def commission_thumbnails(input, user)
|
||||
return [] if input.blank?
|
||||
|
||||
model = SiteSetting.ai_helper_illustrate_post_model
|
||||
attribution = "discourse_ai.ai_helper.painter.attribution.#{model}"
|
||||
|
||||
if model == "stable_diffusion_xl"
|
||||
stable_diffusion_prompt = difussion_prompt(input, user)
|
||||
return [] if stable_diffusion_prompt.blank?
|
||||
|
||||
base64_artifacts =
|
||||
artifacts =
|
||||
DiscourseAi::Inference::StabilityGenerator
|
||||
.perform!(stable_diffusion_prompt)
|
||||
.dig(:artifacts)
|
||||
.to_a
|
||||
.map { |art| art[:base64] }
|
||||
|
||||
base64_artifacts.each_with_index.map do |artifact, i|
|
||||
base64_to_image(artifacts, user.id)
|
||||
elsif model == "dall_e_3"
|
||||
api_key = SiteSetting.ai_openai_api_key
|
||||
api_url = SiteSetting.ai_openai_dall_e_3_url
|
||||
|
||||
artifacts =
|
||||
DiscourseAi::Inference::OpenAiImageGenerator
|
||||
.perform!(input, api_key: api_key, api_url: api_url)
|
||||
.dig(:data)
|
||||
.to_a
|
||||
.map { |art| art[:b64_json] }
|
||||
|
||||
base64_to_image(artifacts, user.id)
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def base64_to_image(artifacts, user_id)
|
||||
attribution =
|
||||
I18n.t(
|
||||
"discourse_ai.ai_helper.painter.attribution.#{SiteSetting.ai_helper_illustrate_post_model}",
|
||||
)
|
||||
|
||||
artifacts.each_with_index.map do |art, i|
|
||||
f = Tempfile.new("v1_txt2img_#{i}.png")
|
||||
f.binmode
|
||||
f.write(Base64.decode64(artifact))
|
||||
f.write(Base64.decode64(art))
|
||||
f.rewind
|
||||
upload =
|
||||
UploadCreator.new(f, I18n.t("discourse_ai.ai_helper.painter.attribution")).create_for(
|
||||
user.id,
|
||||
)
|
||||
upload = UploadCreator.new(f, attribution).create_for(user_id)
|
||||
f.unlink
|
||||
|
||||
UploadSerializer.new(upload, root: false)
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def difussion_prompt(text, user)
|
||||
prompt = { insts: <<~TEXT, input: text }
|
||||
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
|
||||
|
|
|
@ -36,8 +36,8 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
|
|||
end
|
||||
end
|
||||
|
||||
context "when stability API key is present" do
|
||||
before { SiteSetting.ai_stability_api_key = "foo" }
|
||||
context "when illustrate post model is enabled" do
|
||||
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
|
||||
|
||||
it "returns the illustrate_post prompt in the list of all prompts" do
|
||||
prompts = subject.available_prompts
|
||||
|
|
|
@ -8,9 +8,14 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
|
|||
before do
|
||||
SiteSetting.ai_stability_api_url = "https://api.stability.dev"
|
||||
SiteSetting.ai_stability_api_key = "abc"
|
||||
SiteSetting.ai_openai_api_key = "abc"
|
||||
SiteSetting.ai_openai_dall_e_3_url = "https://api.openai.com/v1/images/generations"
|
||||
end
|
||||
|
||||
describe "#commission_thumbnails" do
|
||||
context "when illustrate post model is stable_diffusion_xl" do
|
||||
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
|
||||
|
||||
let(:artifacts) do
|
||||
%w[
|
||||
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
|
||||
|
@ -40,9 +45,44 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
|
|||
|
||||
thumbnail_urls = Upload.last(4).map(&:short_url)
|
||||
|
||||
expect(thumbnails.map { |upload_serializer| upload_serializer.short_url }).to contain_exactly(
|
||||
*thumbnail_urls,
|
||||
)
|
||||
expect(
|
||||
thumbnails.map { |upload_serializer| upload_serializer.short_url },
|
||||
).to contain_exactly(*thumbnail_urls)
|
||||
end
|
||||
end
|
||||
|
||||
context "when illustrate post model is dall_e_3" do
|
||||
before { SiteSetting.ai_helper_illustrate_post_model = "dall_e_3" }
|
||||
|
||||
let(:artifacts) do
|
||||
%w[
|
||||
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
|
||||
]
|
||||
end
|
||||
|
||||
let(:raw_content) do
|
||||
"Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
|
||||
end
|
||||
|
||||
it "returns an image sample" do
|
||||
post = Fabricate(:post)
|
||||
|
||||
data = [{ b64_json: artifacts.first, revised_prompt: "colors on a canvas" }]
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.openai.com/v1/images/generations")
|
||||
.with do |request|
|
||||
json = JSON.parse(request.body, symbolize_names: true)
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: { data: data }.to_json)
|
||||
|
||||
thumbnails = subject.commission_thumbnails(raw_content, user)
|
||||
thumbnail_urls = Upload.last(1).map(&:short_url)
|
||||
|
||||
expect(
|
||||
thumbnails.map { |upload_serializer| upload_serializer.short_url },
|
||||
).to contain_exactly(*thumbnail_urls)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -138,7 +138,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||
sign_in(user)
|
||||
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
|
||||
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
|
||||
SiteSetting.ai_stability_api_key = "foo"
|
||||
SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
|
||||
end
|
||||
|
||||
it "returns a list of prompts when no name_filter is provided" do
|
||||
|
|
Loading…
Reference in New Issue