diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index f58cb080..ce4d3aec 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -59,6 +59,7 @@ en: ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper." ai_helper_automatic_chat_thread_title_delay: "Delay in minutes before the AI helper automatically sets the chat thread title." ai_helper_automatic_chat_thread_title: "Automatically set the chat thread titles based on thread contents." + ai_helper_illustrate_post_model: "Model to use for the composer AI helper's illustrate post feature" ai_embeddings_enabled: "Enable the embeddings module." ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module" @@ -129,7 +130,9 @@ en: explain: "Explain" illustrate_post: "Illustrate Post" painter: - attribution: "Image by Stable Diffusion XL" + attribution: + stable_diffusion_xl: "Image by Stable Diffusion XL" + dall_e_3: "Image by DALL-E 3" ai_bot: placeholder_reply: "I will reply shortly..." 
diff --git a/config/settings.yml b/config/settings.yml
index 8096e4cd..d6723eb0 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -199,6 +199,13 @@ discourse_ai:
     default: false
   ai_helper_automatic_chat_thread_title_delay:
     default: 5
+  ai_helper_illustrate_post_model:
+    default: disabled
+    type: enum
+    choices:
+      - stable_diffusion_xl
+      - dall_e_3
+      - disabled
 
   ai_embeddings_enabled:
     default: false
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index 038ef4bd..12a70268 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -11,9 +11,11 @@ module DiscourseAi
           prompts = [cp.enabled_by_name(name_filter)]
         else
           prompts = cp.where(enabled: true)
-          # Only show the illustrate_post prompt if the API key is present
+          # Hide illustrate_post if disabled
           prompts =
-            prompts.where.not(name: "illustrate_post") if !SiteSetting.ai_stability_api_key.present?
+            prompts.where.not(
+              name: "illustrate_post",
+            ) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
         end
 
         prompts.map do |prompt|
diff --git a/lib/ai_helper/painter.rb b/lib/ai_helper/painter.rb
index ec8358ee..56e51d14 100644
--- a/lib/ai_helper/painter.rb
+++ b/lib/ai_helper/painter.rb
@@ -3,35 +3,67 @@
 module DiscourseAi
   module AiHelper
     class Painter
-      def commission_thumbnails(theme, user)
-        stable_diffusion_prompt = difussion_prompt(theme, user)
+      # Generates illustration uploads for +input+ using the model selected by
+      # SiteSetting.ai_helper_illustrate_post_model.
+      # Returns an Array of UploadSerializer. The Array is empty when +input+ is blank,
+      # when the generated Stable Diffusion prompt is blank, or when no image model is selected.
+      def commission_thumbnails(input, user)
+        return [] if input.blank?
 
-        return [] if stable_diffusion_prompt.blank?
+        model = SiteSetting.ai_helper_illustrate_post_model
 
-        base64_artifacts =
-          DiscourseAi::Inference::StabilityGenerator
-            .perform!(stable_diffusion_prompt)
-            .dig(:artifacts)
-            .to_a
-            .map { |art| art[:base64] }
+        if model == "stable_diffusion_xl"
+          stable_diffusion_prompt = difussion_prompt(input, user)
+          return [] if stable_diffusion_prompt.blank?
 
-        base64_artifacts.each_with_index.map do |artifact, i|
+          artifacts =
+            DiscourseAi::Inference::StabilityGenerator
+              .perform!(stable_diffusion_prompt)
+              .dig(:artifacts)
+              .to_a
+              .map { |art| art[:base64] }
+
+          base64_to_image(artifacts, user.id)
+        elsif model == "dall_e_3"
+          api_key = SiteSetting.ai_openai_api_key
+          api_url = SiteSetting.ai_openai_dall_e_3_url
+
+          artifacts =
+            DiscourseAi::Inference::OpenAiImageGenerator
+              .perform!(input, api_key: api_key, api_url: api_url)
+              .dig(:data)
+              .to_a
+              .map { |art| art[:b64_json] }
+
+          base64_to_image(artifacts, user.id)
+        else
+          # "disabled" or an unrecognized value: return [] like the other early exits, not nil.
+          []
+        end
+      end
+
+      private
+
+      # Decodes each base64 string in +artifacts+ into a Tempfile, creates an upload for
+      # +user_id+ named with the configured model's attribution text, and returns the serializers.
+      def base64_to_image(artifacts, user_id)
+        attribution =
+          I18n.t(
+            "discourse_ai.ai_helper.painter.attribution.#{SiteSetting.ai_helper_illustrate_post_model}",
+          )
+
+        artifacts.each_with_index.map do |art, i|
           f = Tempfile.new("v1_txt2img_#{i}.png")
           f.binmode
-          f.write(Base64.decode64(artifact))
+          f.write(Base64.decode64(art))
           f.rewind
-          upload =
-            UploadCreator.new(f, I18n.t("discourse_ai.ai_helper.painter.attribution")).create_for(
-              user.id,
-            )
+          upload = UploadCreator.new(f, attribution).create_for(user_id)
           f.unlink
 
           UploadSerializer.new(upload, root: false)
         end
       end
 
-      private
-
       def difussion_prompt(text, user)
         prompt = { insts: <<~TEXT, input: text }
           Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
diff --git a/spec/lib/modules/ai_helper/assistant_spec.rb b/spec/lib/modules/ai_helper/assistant_spec.rb
index 75085bdd..4f40aa04 100644
--- a/spec/lib/modules/ai_helper/assistant_spec.rb
+++ b/spec/lib/modules/ai_helper/assistant_spec.rb
@@ -36,8 +36,8 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
       end
     end
 
-    context "when stability API key is present" do
-      before { SiteSetting.ai_stability_api_key = "foo" }
+    context "when illustrate post model is enabled" do
+      before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
 
       it "returns the illustrate_post prompt in the list of all prompts" do
         prompts = subject.available_prompts
diff --git a/spec/lib/modules/ai_helper/painter_spec.rb b/spec/lib/modules/ai_helper/painter_spec.rb
index e0b0e665..bf11d24b 100644
--- a/spec/lib/modules/ai_helper/painter_spec.rb
+++ b/spec/lib/modules/ai_helper/painter_spec.rb
@@ -8,41 +8,76 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
   before do
     SiteSetting.ai_stability_api_url = "https://api.stability.dev"
     SiteSetting.ai_stability_api_key = "abc"
+    SiteSetting.ai_openai_api_key = "abc"
+    SiteSetting.ai_openai_dall_e_3_url = "https://api.openai.com/v1/images/generations"
   end
 
   describe "#commission_thumbnails" do
-    let(:artifacts) do
-      %w[
-        iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
-        iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8z8BQz0AEYBxVSF+FABJADveWkH6oAAAAAElFTkSuQmCC
-        iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC
-        iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC
-      ]
+    context "when illustrate post model is stable_diffusion_xl" do
+      before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
+
+      let(:artifacts) do
+        %w[
+          iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
+          iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8z8BQz0AEYBxVSF+FABJADveWkH6oAAAAAElFTkSuQmCC
+          iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC
+          iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC
+        ]
+      end
+
+      let(:raw_content) do
+        "Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
+      end
+
+      let(:expected_image_prompt) { <<~TEXT.strip }
+        Visualize a vibrant scene of an inkwell bursting, spreading colors across a blank canvas,
+        embodying words in tangible forms, symbolizing the rhythm and emotion evoked by poetry,
+        under the soft glow of a full moon.
+      TEXT
+
+      it "returns 4 samples" do
+        StableDiffusionStubs.new.stub_response(expected_image_prompt, artifacts)
+
+        thumbnails =
+          DiscourseAi::Completions::Llm.with_prepared_responses([expected_image_prompt]) do
+            subject.commission_thumbnails(raw_content, user)
+          end
+
+        thumbnail_urls = Upload.last(4).map(&:short_url)
+
+        expect(
+          thumbnails.map { |upload_serializer| upload_serializer.short_url },
+        ).to contain_exactly(*thumbnail_urls)
+      end
     end
 
-    let(:raw_content) do
-      "Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
-    end
+    context "when illustrate post model is dall_e_3" do
+      before { SiteSetting.ai_helper_illustrate_post_model = "dall_e_3" }
 
-    let(:expected_image_prompt) { <<~TEXT.strip }
-      Visualize a vibrant scene of an inkwell bursting, spreading colors across a blank canvas,
-      embodying words in tangible forms, symbolizing the rhythm and emotion evoked by poetry,
-      under the soft glow of a full moon.
-    TEXT
+      let(:artifacts) do
+        %w[
+          iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
+        ]
+      end
 
-    it "returns 4 samples" do
-      StableDiffusionStubs.new.stub_response(expected_image_prompt, artifacts)
+      let(:raw_content) do
+        "Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
+      end
 
-      thumbnails =
-        DiscourseAi::Completions::Llm.with_prepared_responses([expected_image_prompt]) do
-          thumbnails = subject.commission_thumbnails(raw_content, user)
-        end
+      it "returns an image sample" do
+        # Stub the configured OpenAI images URL to answer with a single base64 artifact.
+        data = [{ b64_json: artifacts.first, revised_prompt: "colors on a canvas" }]
+        WebMock
+          .stub_request(:post, "https://api.openai.com/v1/images/generations")
+          .to_return(status: 200, body: { data: data }.to_json)
 
-      thumbnail_urls = Upload.last(4).map(&:short_url)
+        thumbnails = subject.commission_thumbnails(raw_content, user)
+        thumbnail_urls = Upload.last(1).map(&:short_url)
 
-      expect(thumbnails.map { |upload_serializer| upload_serializer.short_url }).to contain_exactly(
-        *thumbnail_urls,
-      )
+        expect(
+          thumbnails.map { |upload_serializer| upload_serializer.short_url },
+        ).to contain_exactly(*thumbnail_urls)
+      end
     end
   end
 end
diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb
index 5e034caa..e7238399 100644
--- a/spec/requests/ai_helper/assistant_controller_spec.rb
+++ b/spec/requests/ai_helper/assistant_controller_spec.rb
@@ -138,7 +138,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
       sign_in(user)
       user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
       SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
-      SiteSetting.ai_stability_api_key = "foo"
+      SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
     end
 
     it "returns a list of prompts when no name_filter is provided" do