diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
index a17df448..fdfa0b4e 100644
--- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
+++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb
@@ -6,6 +6,7 @@ module DiscourseAi
       requires_plugin ::DiscourseAi::PLUGIN_NAME
       requires_login
       before_action :ensure_can_request_suggestions
+      before_action :rate_limiter_performed!, except: %i[prompts]
 
       def prompts
         render json:
@@ -17,19 +18,13 @@ module DiscourseAi
       end
 
       def suggest
-        raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
+        input = get_text_param!
 
         prompt = CompletionPrompt.find_by(id: params[:mode])
         raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
 
-        RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
-
         hijack do
-          render json:
-                   DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
-                     prompt,
-                     params[:text],
-                   ),
+          render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
                  status: 200
         end
       rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@@ -40,7 +35,7 @@ module DiscourseAi
       end
 
       def suggest_title
-        raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
+        input = get_text_param!
 
         llm_prompt =
           DiscourseAi::AiHelper::LlmPrompt
@@ -50,14 +45,8 @@ module DiscourseAi
         prompt = CompletionPrompt.find_by(id: llm_prompt[:id])
         raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
 
-        RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
-
         hijack do
-          render json:
-                   DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
-                     prompt,
-                     params[:text],
-                   ),
+          render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
                  status: 200
         end
       rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@@ -68,30 +57,42 @@ module DiscourseAi
       end
 
       def suggest_category
-        raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
+        input = get_text_param!
 
-        RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
-
-        render json:
-                 DiscourseAi::AiHelper::SemanticCategorizer.new(
-                   params[:text],
-                   current_user,
-                 ).categories,
+        render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).categories,
                status: 200
       end
 
       def suggest_tags
-        raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
+        input = get_text_param!
 
-        RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
-
-        render json:
-                 DiscourseAi::AiHelper::SemanticCategorizer.new(params[:text], current_user).tags,
+        render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).tags,
                status: 200
       end
 
+      # Responds with the thumbnails Painter commissions for params[:text].
+      def suggest_thumbnails
+        input = get_text_param!
+
+        hijack do
+          thumbnails = DiscourseAi::AiHelper::Painter.new.commission_thumbnails(input, current_user)
+
+          render json: { thumbnails: thumbnails }, status: 200
+        end
+      end
+
       private
 
+      # Returns params[:text], raising Discourse::InvalidParameters when it is blank.
+      def get_text_param!
+        params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? }
+      end
+
+      # Registers this request with the "ai_assistant" RateLimiter for the current user.
+      def rate_limiter_performed!
+        RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
+      end
+
       def ensure_can_request_suggestions
         user_group_ids = current_user.group_ids
 
diff --git a/config/routes.rb b/config/routes.rb
index 51b7d506..ced1ca3c 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -7,6 +7,7 @@ DiscourseAi::Engine.routes.draw do
     post "suggest_title" => "assistant#suggest_title"
     post "suggest_category" => "assistant#suggest_category"
     post "suggest_tags" => "assistant#suggest_tags"
+    post "suggest_thumbnails" => "assistant#suggest_thumbnails"
   end
 
   scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
diff --git a/lib/modules/ai_helper/entry_point.rb b/lib/modules/ai_helper/entry_point.rb
index 14dbf29c..dcd012f6 100644
--- a/lib/modules/ai_helper/entry_point.rb
+++ b/lib/modules/ai_helper/entry_point.rb
@@ -5,6 +5,7 @@ module DiscourseAi
       def load_files
         require_relative "llm_prompt"
         require_relative "semantic_categorizer"
+        require_relative "painter"
       end
 
       def inject_into(plugin)
diff --git a/lib/modules/ai_helper/painter.rb b/lib/modules/ai_helper/painter.rb
new file mode 100644
index 00000000..f62362c2
--- /dev/null
+++ b/lib/modules/ai_helper/painter.rb
@@ -0,0 +1,91 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module AiHelper
+    # Turns a post into illustrative thumbnails: asks the enabled LLM provider for a
+    # Stable Diffusion prompt, renders it, and stores every artifact as an upload.
+    class Painter
+      # @param theme [String] text of the post to illustrate
+      # @param user [User] user the uploads are created for
+      # @return [Array<String>] short URLs of the uploads; empty when no prompt could be built
+      def commission_thumbnails(theme, user)
+        stable_diffusion_prompt = diffusion_prompt(theme)
+
+        return [] if stable_diffusion_prompt.blank?
+
+        base64_artifacts =
+          DiscourseAi::Inference::StabilityGenerator
+            .perform!(stable_diffusion_prompt)
+            .dig(:artifacts)
+            .to_a
+            .map { |art| art[:base64] }
+
+        base64_artifacts.each_with_index.map do |artifact, i|
+          # Array form keeps ".png" as the real file extension of the tempfile.
+          f = Tempfile.new(["v1_txt2img_#{i}", ".png"])
+
+          begin
+            f.binmode
+            f.write(Base64.decode64(artifact))
+            f.rewind
+
+            UploadCreator.new(f, "ai_helper_image.png").create_for(user.id).short_url
+          ensure
+            # Close and delete the tempfile even when decoding or the upload raises.
+            f.close!
+          end
+        end
+      end
+
+      private
+
+      # Asks the enabled LLM provider for a Stable Diffusion prompt describing +text+.
+      # Returns a blank value when no provider prompt matches or nothing was suggested.
+      def diffusion_prompt(text)
+        llm_prompt = LlmPrompt.new
+        provider = llm_prompt.enabled_provider
+        prompt_for_provider = completion_prompts.find { |prompt| prompt.provider == provider }
+
+        return "" if prompt_for_provider.nil?
+
+        # A response without suggestions yields nil, which the caller treats as blank.
+        llm_prompt.generate_and_send_prompt(prompt_for_provider, text).dig(:suggestions)&.first
+      end
+
+      # Provider-specific prompts; only the one matching the enabled provider is used.
+      def completion_prompts
+        [
+          CompletionPrompt.new(
+            provider: "anthropic",
+            prompt_type: CompletionPrompt.prompt_types[:text],
+            messages: [{ role: "Human", content: <<~TEXT }],
+              Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
+              The post is provided between <input> tags and the Stable Diffusion prompt string should be returned between <ai> tags.
+            TEXT
+          ),
+          CompletionPrompt.new(
+            provider: "openai",
+            prompt_type: CompletionPrompt.prompt_types[:text],
+            messages: [{ role: "system", content: <<~TEXT }],
+              Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
+            TEXT
+          ),
+          CompletionPrompt.new(
+            provider: "huggingface",
+            prompt_type: CompletionPrompt.prompt_types[:text],
+            messages: [<<~TEXT],
+              ### System:
+              Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
+
+              ### User:
+              {{user_input}}
+
+              ### Assistant:
+              Here is a StableDiffusion prompt:
+            TEXT
+          ),
+        ]
+      end
+    end
+  end
+end
diff --git a/spec/lib/modules/ai_bot/commands/image_command_spec.rb b/spec/lib/modules/ai_bot/commands/image_command_spec.rb
index d94b4be3..1d780d32 100644
--- a/spec/lib/modules/ai_bot/commands/image_command_spec.rb
+++ b/spec/lib/modules/ai_bot/commands/image_command_spec.rb
@@ -1,5 +1,7 @@
 #frozen_string_literal: true
 
+require_relative "../../../../support/stable_difussion_stubs"
+
 RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
   fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
 
@@ -13,16 +15,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
       image =
         "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
 
-      stub_request(
-        :post,
-        "https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
-      )
-        .with do |request|
-          json = JSON.parse(request.body)
-          expect(json["text_prompts"][0]["text"]).to eq("a pink cow")
-          true
-        end
-        .to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json)
+      StableDiffusionStubs.new.stub_response("a pink cow", [image, image])
 
       image = described_class.new(bot_user: bot_user, post: post, args: nil)
 
diff --git a/spec/lib/modules/ai_helper/painter_spec.rb b/spec/lib/modules/ai_helper/painter_spec.rb
new file mode 100644
index 00000000..ed865ca2
--- /dev/null
+++ b/spec/lib/modules/ai_helper/painter_spec.rb
@@ -0,0 +1,53 @@
+# frozen_string_literal: true
+
+require_relative "../../../support/openai_completions_inference_stubs"
+require_relative "../../../support/stable_difussion_stubs"
+
+RSpec.describe DiscourseAi::AiHelper::Painter do
+  subject(:painter) { described_class.new }
+
+  fab!(:user) { Fabricate(:user) }
+
+  before do
+    SiteSetting.ai_stability_api_url = "https://api.stability.dev"
+    SiteSetting.ai_stability_api_key = "abc"
+  end
+
+  describe "#commission_thumbnails" do
+    let(:artifacts) do
+      %w[
+        iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
+        iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8z8BQz0AEYBxVSF+FABJADveWkH6oAAAAAElFTkSuQmCC
+        iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC
+        iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC
+      ]
+    end
+
+    let(:raw_content) do
+      "Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
+    end
+
+    let(:expected_image_prompt) { <<~TEXT.strip }
+    Visualize a vibrant scene of an inkwell bursting, spreading colors across a blank canvas,
+    embodying words in tangible forms, symbolizing the rhythm and emotion evoked by poetry,
+    under the soft glow of a full moon.
+    TEXT
+
+    it "returns 4 samples" do
+      expected_prompt = [
+        { role: "system", content: <<~TEXT },
+      Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
+      TEXT
+        { role: "user", content: raw_content },
+      ]
+
+      OpenAiCompletionsInferenceStubs.stub_response(expected_prompt, expected_image_prompt)
+      StableDiffusionStubs.new.stub_response(expected_image_prompt, artifacts)
+
+      thumbnails = painter.commission_thumbnails(raw_content, user)
+      thumbnail_urls = Upload.last(4).map(&:short_url)
+
+      expect(thumbnails).to contain_exactly(*thumbnail_urls)
+    end
+  end
+end
diff --git a/spec/support/stable_difussion_stubs.rb b/spec/support/stable_difussion_stubs.rb
new file mode 100644
index 00000000..da32248c
--- /dev/null
+++ b/spec/support/stable_difussion_stubs.rb
@@ -0,0 +1,23 @@
+# frozen_string_literal: true
+
+# WebMock helper that fakes the Stability text-to-image endpoint in specs.
+class StableDiffusionStubs
+  include RSpec::Matchers
+
+  # Stubs the endpoint to return +images+ (base64 strings) and expects the request's first text prompt to equal +prompt+.
+  def stub_response(prompt, images)
+    artifacts = images.map { |i| { base64: i } }
+
+    WebMock
+      .stub_request(
+        :post,
+        "https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
+      )
+      .with do |request|
+        json = JSON.parse(request.body, symbolize_names: true)
+        expect(json[:text_prompts][0][:text]).to eq(prompt)
+        true
+      end
+      .to_return(status: 200, body: { artifacts: artifacts }.to_json)
+  end
+end