mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-09 11:48:47 +00:00
FEATURE: Add DallE support to AI helper's illustrate post (#404)
This commit is contained in:
parent
23b2809638
commit
7201d482d5
@ -59,6 +59,7 @@ en:
|
|||||||
ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
|
ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
|
||||||
ai_helper_automatic_chat_thread_title_delay: "Delay in minutes before the AI helper automatically sets the chat thread title."
|
ai_helper_automatic_chat_thread_title_delay: "Delay in minutes before the AI helper automatically sets the chat thread title."
|
||||||
ai_helper_automatic_chat_thread_title: "Automatically set the chat thread titles based on thread contents."
|
ai_helper_automatic_chat_thread_title: "Automatically set the chat thread titles based on thread contents."
|
||||||
|
ai_helper_illustrate_post_model: "Model to use for the composer AI helper's illustrate post feature"
|
||||||
|
|
||||||
ai_embeddings_enabled: "Enable the embeddings module."
|
ai_embeddings_enabled: "Enable the embeddings module."
|
||||||
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
||||||
@ -129,7 +130,9 @@ en:
|
|||||||
explain: "Explain"
|
explain: "Explain"
|
||||||
illustrate_post: "Illustrate Post"
|
illustrate_post: "Illustrate Post"
|
||||||
painter:
|
painter:
|
||||||
attribution: "Image by Stable Diffusion XL"
|
attribution:
|
||||||
|
stable_diffusion_xl: "Image by Stable Diffusion XL"
|
||||||
|
dall_e_3: "Image by DALL-E 3"
|
||||||
|
|
||||||
ai_bot:
|
ai_bot:
|
||||||
placeholder_reply: "I will reply shortly..."
|
placeholder_reply: "I will reply shortly..."
|
||||||
|
@ -199,6 +199,13 @@ discourse_ai:
|
|||||||
default: false
|
default: false
|
||||||
ai_helper_automatic_chat_thread_title_delay:
|
ai_helper_automatic_chat_thread_title_delay:
|
||||||
default: 5
|
default: 5
|
||||||
|
ai_helper_illustrate_post_model:
|
||||||
|
default: disabled
|
||||||
|
type: enum
|
||||||
|
choices:
|
||||||
|
- stable_diffusion_xl
|
||||||
|
- dall_e_3
|
||||||
|
- disabled
|
||||||
|
|
||||||
ai_embeddings_enabled:
|
ai_embeddings_enabled:
|
||||||
default: false
|
default: false
|
||||||
|
@ -11,9 +11,11 @@ module DiscourseAi
|
|||||||
prompts = [cp.enabled_by_name(name_filter)]
|
prompts = [cp.enabled_by_name(name_filter)]
|
||||||
else
|
else
|
||||||
prompts = cp.where(enabled: true)
|
prompts = cp.where(enabled: true)
|
||||||
# Only show the illustrate_post prompt if the API key is present
|
# Hide illustrate_post if disabled
|
||||||
prompts =
|
prompts =
|
||||||
prompts.where.not(name: "illustrate_post") if !SiteSetting.ai_stability_api_key.present?
|
prompts.where.not(
|
||||||
|
name: "illustrate_post",
|
||||||
|
) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
|
||||||
end
|
end
|
||||||
|
|
||||||
prompts.map do |prompt|
|
prompts.map do |prompt|
|
||||||
|
@ -3,35 +3,59 @@
|
|||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module AiHelper
|
module AiHelper
|
||||||
class Painter
|
class Painter
|
||||||
def commission_thumbnails(theme, user)
|
def commission_thumbnails(input, user)
|
||||||
stable_diffusion_prompt = difussion_prompt(theme, user)
|
return [] if input.blank?
|
||||||
|
|
||||||
|
model = SiteSetting.ai_helper_illustrate_post_model
|
||||||
|
attribution = "discourse_ai.ai_helper.painter.attribution.#{model}"
|
||||||
|
|
||||||
|
if model == "stable_diffusion_xl"
|
||||||
|
stable_diffusion_prompt = difussion_prompt(input, user)
|
||||||
return [] if stable_diffusion_prompt.blank?
|
return [] if stable_diffusion_prompt.blank?
|
||||||
|
|
||||||
base64_artifacts =
|
artifacts =
|
||||||
DiscourseAi::Inference::StabilityGenerator
|
DiscourseAi::Inference::StabilityGenerator
|
||||||
.perform!(stable_diffusion_prompt)
|
.perform!(stable_diffusion_prompt)
|
||||||
.dig(:artifacts)
|
.dig(:artifacts)
|
||||||
.to_a
|
.to_a
|
||||||
.map { |art| art[:base64] }
|
.map { |art| art[:base64] }
|
||||||
|
|
||||||
base64_artifacts.each_with_index.map do |artifact, i|
|
base64_to_image(artifacts, user.id)
|
||||||
|
elsif model == "dall_e_3"
|
||||||
|
api_key = SiteSetting.ai_openai_api_key
|
||||||
|
api_url = SiteSetting.ai_openai_dall_e_3_url
|
||||||
|
|
||||||
|
artifacts =
|
||||||
|
DiscourseAi::Inference::OpenAiImageGenerator
|
||||||
|
.perform!(input, api_key: api_key, api_url: api_url)
|
||||||
|
.dig(:data)
|
||||||
|
.to_a
|
||||||
|
.map { |art| art[:b64_json] }
|
||||||
|
|
||||||
|
base64_to_image(artifacts, user.id)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def base64_to_image(artifacts, user_id)
|
||||||
|
attribution =
|
||||||
|
I18n.t(
|
||||||
|
"discourse_ai.ai_helper.painter.attribution.#{SiteSetting.ai_helper_illustrate_post_model}",
|
||||||
|
)
|
||||||
|
|
||||||
|
artifacts.each_with_index.map do |art, i|
|
||||||
f = Tempfile.new("v1_txt2img_#{i}.png")
|
f = Tempfile.new("v1_txt2img_#{i}.png")
|
||||||
f.binmode
|
f.binmode
|
||||||
f.write(Base64.decode64(artifact))
|
f.write(Base64.decode64(art))
|
||||||
f.rewind
|
f.rewind
|
||||||
upload =
|
upload = UploadCreator.new(f, attribution).create_for(user_id)
|
||||||
UploadCreator.new(f, I18n.t("discourse_ai.ai_helper.painter.attribution")).create_for(
|
|
||||||
user.id,
|
|
||||||
)
|
|
||||||
f.unlink
|
f.unlink
|
||||||
|
|
||||||
UploadSerializer.new(upload, root: false)
|
UploadSerializer.new(upload, root: false)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
private
|
|
||||||
|
|
||||||
def difussion_prompt(text, user)
|
def difussion_prompt(text, user)
|
||||||
prompt = { insts: <<~TEXT, input: text }
|
prompt = { insts: <<~TEXT, input: text }
|
||||||
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
|
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
|
||||||
|
@ -36,8 +36,8 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
context "when stability API key is present" do
|
context "when illustrate post model is enabled" do
|
||||||
before { SiteSetting.ai_stability_api_key = "foo" }
|
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
|
||||||
|
|
||||||
it "returns the illustrate_post prompt in the list of all prompts" do
|
it "returns the illustrate_post prompt in the list of all prompts" do
|
||||||
prompts = subject.available_prompts
|
prompts = subject.available_prompts
|
||||||
|
@ -8,9 +8,14 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
|
|||||||
before do
|
before do
|
||||||
SiteSetting.ai_stability_api_url = "https://api.stability.dev"
|
SiteSetting.ai_stability_api_url = "https://api.stability.dev"
|
||||||
SiteSetting.ai_stability_api_key = "abc"
|
SiteSetting.ai_stability_api_key = "abc"
|
||||||
|
SiteSetting.ai_openai_api_key = "abc"
|
||||||
|
SiteSetting.ai_openai_dall_e_3_url = "https://api.openai.com/v1/images/generations"
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#commission_thumbnails" do
|
describe "#commission_thumbnails" do
|
||||||
|
context "when illustrate post model is stable_diffusion_xl" do
|
||||||
|
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
|
||||||
|
|
||||||
let(:artifacts) do
|
let(:artifacts) do
|
||||||
%w[
|
%w[
|
||||||
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
|
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
|
||||||
@ -40,9 +45,44 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
|
|||||||
|
|
||||||
thumbnail_urls = Upload.last(4).map(&:short_url)
|
thumbnail_urls = Upload.last(4).map(&:short_url)
|
||||||
|
|
||||||
expect(thumbnails.map { |upload_serializer| upload_serializer.short_url }).to contain_exactly(
|
expect(
|
||||||
*thumbnail_urls,
|
thumbnails.map { |upload_serializer| upload_serializer.short_url },
|
||||||
)
|
).to contain_exactly(*thumbnail_urls)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when illustrate post model is dall_e_3" do
|
||||||
|
before { SiteSetting.ai_helper_illustrate_post_model = "dall_e_3" }
|
||||||
|
|
||||||
|
let(:artifacts) do
|
||||||
|
%w[
|
||||||
|
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:raw_content) do
|
||||||
|
"Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
|
||||||
|
end
|
||||||
|
|
||||||
|
it "returns an image sample" do
|
||||||
|
post = Fabricate(:post)
|
||||||
|
|
||||||
|
data = [{ b64_json: artifacts.first, revised_prompt: "colors on a canvas" }]
|
||||||
|
WebMock
|
||||||
|
.stub_request(:post, "https://api.openai.com/v1/images/generations")
|
||||||
|
.with do |request|
|
||||||
|
json = JSON.parse(request.body, symbolize_names: true)
|
||||||
|
true
|
||||||
|
end
|
||||||
|
.to_return(status: 200, body: { data: data }.to_json)
|
||||||
|
|
||||||
|
thumbnails = subject.commission_thumbnails(raw_content, user)
|
||||||
|
thumbnail_urls = Upload.last(1).map(&:short_url)
|
||||||
|
|
||||||
|
expect(
|
||||||
|
thumbnails.map { |upload_serializer| upload_serializer.short_url },
|
||||||
|
).to contain_exactly(*thumbnail_urls)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -138,7 +138,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||||||
sign_in(user)
|
sign_in(user)
|
||||||
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
|
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
|
||||||
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
|
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
|
||||||
SiteSetting.ai_stability_api_key = "foo"
|
SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
|
||||||
end
|
end
|
||||||
|
|
||||||
it "returns a list of prompts when no name_filter is provided" do
|
it "returns a list of prompts when no name_filter is provided" do
|
||||||
|
Loading…
x
Reference in New Issue
Block a user