FEATURE: Add DallE support to AI helper's illustrate post (#404)

This commit is contained in:
Keegan George 2024-01-05 09:03:23 -08:00 committed by GitHub
parent 23b2809638
commit 7201d482d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 124 additions and 48 deletions

View File

@ -59,6 +59,7 @@ en:
ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
ai_helper_automatic_chat_thread_title_delay: "Delay in minutes before the AI helper automatically sets the chat thread title."
ai_helper_automatic_chat_thread_title: "Automatically set the chat thread titles based on thread contents."
ai_helper_illustrate_post_model: "Model to use for the composer AI helper's illustrate post feature"
ai_embeddings_enabled: "Enable the embeddings module."
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
@ -129,7 +130,9 @@ en:
explain: "Explain"
illustrate_post: "Illustrate Post"
painter:
attribution: "Image by Stable Diffusion XL"
attribution:
stable_diffusion_xl: "Image by Stable Diffusion XL"
dall_e_3: "Image by DALL-E 3"
ai_bot:
placeholder_reply: "I will reply shortly..."

View File

@ -199,6 +199,13 @@ discourse_ai:
default: false
ai_helper_automatic_chat_thread_title_delay:
default: 5
ai_helper_illustrate_post_model:
default: disabled
type: enum
choices:
- stable_diffusion_xl
- dall_e_3
- disabled
ai_embeddings_enabled:
default: false

View File

@ -11,9 +11,11 @@ module DiscourseAi
prompts = [cp.enabled_by_name(name_filter)]
else
prompts = cp.where(enabled: true)
# Only show the illustrate_post prompt if the API key is present
# Hide illustrate_post if disabled
prompts =
prompts.where.not(name: "illustrate_post") if !SiteSetting.ai_stability_api_key.present?
prompts.where.not(
name: "illustrate_post",
) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
end
prompts.map do |prompt|

View File

@ -3,35 +3,59 @@
module DiscourseAi
module AiHelper
class Painter
def commission_thumbnails(theme, user)
stable_diffusion_prompt = difussion_prompt(theme, user)
def commission_thumbnails(input, user)
return [] if input.blank?
model = SiteSetting.ai_helper_illustrate_post_model
attribution = "discourse_ai.ai_helper.painter.attribution.#{model}"
if model == "stable_diffusion_xl"
stable_diffusion_prompt = difussion_prompt(input, user)
return [] if stable_diffusion_prompt.blank?
base64_artifacts =
artifacts =
DiscourseAi::Inference::StabilityGenerator
.perform!(stable_diffusion_prompt)
.dig(:artifacts)
.to_a
.map { |art| art[:base64] }
base64_artifacts.each_with_index.map do |artifact, i|
base64_to_image(artifacts, user.id)
elsif model == "dall_e_3"
api_key = SiteSetting.ai_openai_api_key
api_url = SiteSetting.ai_openai_dall_e_3_url
artifacts =
DiscourseAi::Inference::OpenAiImageGenerator
.perform!(input, api_key: api_key, api_url: api_url)
.dig(:data)
.to_a
.map { |art| art[:b64_json] }
base64_to_image(artifacts, user.id)
end
end
private
def base64_to_image(artifacts, user_id)
attribution =
I18n.t(
"discourse_ai.ai_helper.painter.attribution.#{SiteSetting.ai_helper_illustrate_post_model}",
)
artifacts.each_with_index.map do |art, i|
f = Tempfile.new("v1_txt2img_#{i}.png")
f.binmode
f.write(Base64.decode64(artifact))
f.write(Base64.decode64(art))
f.rewind
upload =
UploadCreator.new(f, I18n.t("discourse_ai.ai_helper.painter.attribution")).create_for(
user.id,
)
upload = UploadCreator.new(f, attribution).create_for(user_id)
f.unlink
UploadSerializer.new(upload, root: false)
end
end
private
def difussion_prompt(text, user)
prompt = { insts: <<~TEXT, input: text }
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.

View File

@ -36,8 +36,8 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
end
end
context "when stability API key is present" do
before { SiteSetting.ai_stability_api_key = "foo" }
context "when illustrate post model is enabled" do
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
it "returns the illustrate_post prompt in the list of all prompts" do
prompts = subject.available_prompts

View File

@ -8,9 +8,14 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
before do
SiteSetting.ai_stability_api_url = "https://api.stability.dev"
SiteSetting.ai_stability_api_key = "abc"
SiteSetting.ai_openai_api_key = "abc"
SiteSetting.ai_openai_dall_e_3_url = "https://api.openai.com/v1/images/generations"
end
describe "#commission_thumbnails" do
context "when illustrate post model is stable_diffusion_xl" do
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
let(:artifacts) do
%w[
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
@ -40,9 +45,44 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
thumbnail_urls = Upload.last(4).map(&:short_url)
expect(thumbnails.map { |upload_serializer| upload_serializer.short_url }).to contain_exactly(
*thumbnail_urls,
)
expect(
thumbnails.map { |upload_serializer| upload_serializer.short_url },
).to contain_exactly(*thumbnail_urls)
end
end
context "when illustrate post model is dall_e_3" do
before { SiteSetting.ai_helper_illustrate_post_model = "dall_e_3" }
let(:artifacts) do
%w[
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
]
end
let(:raw_content) do
"Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
end
it "returns an image sample" do
post = Fabricate(:post)
data = [{ b64_json: artifacts.first, revised_prompt: "colors on a canvas" }]
WebMock
.stub_request(:post, "https://api.openai.com/v1/images/generations")
.with do |request|
json = JSON.parse(request.body, symbolize_names: true)
true
end
.to_return(status: 200, body: { data: data }.to_json)
thumbnails = subject.commission_thumbnails(raw_content, user)
thumbnail_urls = Upload.last(1).map(&:short_url)
expect(
thumbnails.map { |upload_serializer| upload_serializer.short_url },
).to contain_exactly(*thumbnail_urls)
end
end
end
end

View File

@ -138,7 +138,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
sign_in(user)
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
SiteSetting.ai_stability_api_key = "foo"
SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
end
it "returns a list of prompts when no name_filter is provided" do