FEATURE: AI Helper endpoint to generate a thumbnail from text. (#224)

We pass the text to the current LLM and ask them to generate a StableDifussion prompt.
We'll use that to generate 4 samples, temporarily creating uploads and returning their short URLs.
This commit is contained in:
Roman Rizzi 2023-09-14 12:53:44 -03:00 committed by GitHub
parent 1eb70c4f0a
commit f57c1bb0f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 183 additions and 39 deletions

View File

@ -6,6 +6,7 @@ module DiscourseAi
requires_plugin ::DiscourseAi::PLUGIN_NAME requires_plugin ::DiscourseAi::PLUGIN_NAME
requires_login requires_login
before_action :ensure_can_request_suggestions before_action :ensure_can_request_suggestions
before_action :rate_limiter_performed!, except: %i[prompts]
def prompts def prompts
render json: render json:
@ -17,19 +18,13 @@ module DiscourseAi
end end
def suggest def suggest
raise Discourse::InvalidParameters.new(:text) if params[:text].blank? input = get_text_param!
prompt = CompletionPrompt.find_by(id: params[:mode]) prompt = CompletionPrompt.find_by(id: params[:mode])
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
hijack do hijack do
render json: render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
prompt,
params[:text],
),
status: 200 status: 200
end end
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed, rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@ -40,7 +35,7 @@ module DiscourseAi
end end
def suggest_title def suggest_title
raise Discourse::InvalidParameters.new(:text) if params[:text].blank? input = get_text_param!
llm_prompt = llm_prompt =
DiscourseAi::AiHelper::LlmPrompt DiscourseAi::AiHelper::LlmPrompt
@ -50,14 +45,8 @@ module DiscourseAi
prompt = CompletionPrompt.find_by(id: llm_prompt[:id]) prompt = CompletionPrompt.find_by(id: llm_prompt[:id])
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
hijack do hijack do
render json: render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
prompt,
params[:text],
),
status: 200 status: 200
end end
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed, rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@ -68,30 +57,39 @@ module DiscourseAi
end end
def suggest_category def suggest_category
raise Discourse::InvalidParameters.new(:text) if params[:text].blank? input = get_text_param!
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed! render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).categories,
render json:
DiscourseAi::AiHelper::SemanticCategorizer.new(
params[:text],
current_user,
).categories,
status: 200 status: 200
end end
def suggest_tags def suggest_tags
raise Discourse::InvalidParameters.new(:text) if params[:text].blank? input = get_text_param!
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed! render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).tags,
render json:
DiscourseAi::AiHelper::SemanticCategorizer.new(params[:text], current_user).tags,
status: 200 status: 200
end end
def suggest_thumbnails
input = get_text_param!
hijack do
thumbnails = DiscourseAi::AiHelper::Painter.new.commission_thumbnails(input, current_user)
render json: { thumbnails: thumbnails }, status: 200
end
end
private private
def get_text_param!
params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? }
end
def rate_limiter_performed!
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
end
def ensure_can_request_suggestions def ensure_can_request_suggestions
user_group_ids = current_user.group_ids user_group_ids = current_user.group_ids

View File

@ -7,6 +7,7 @@ DiscourseAi::Engine.routes.draw do
post "suggest_title" => "assistant#suggest_title" post "suggest_title" => "assistant#suggest_title"
post "suggest_category" => "assistant#suggest_category" post "suggest_category" => "assistant#suggest_category"
post "suggest_tags" => "assistant#suggest_tags" post "suggest_tags" => "assistant#suggest_tags"
post "suggest_thumbnails" => "assistant#suggest_thumbnails"
end end
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do

View File

@ -5,6 +5,7 @@ module DiscourseAi
def load_files def load_files
require_relative "llm_prompt" require_relative "llm_prompt"
require_relative "semantic_categorizer" require_relative "semantic_categorizer"
require_relative "painter"
end end
def inject_into(plugin) def inject_into(plugin)

View File

@ -0,0 +1,77 @@
# frozen_string_literal: true
module DiscourseAi
module AiHelper
class Painter
def commission_thumbnails(theme, user)
stable_diffusion_prompt = difussion_prompt(theme)
return [] if stable_diffusion_prompt.blank?
base64_artifacts =
DiscourseAi::Inference::StabilityGenerator
.perform!(stable_diffusion_prompt)
.dig(:artifacts)
.to_a
.map { |art| art[:base64] }
base64_artifacts.each_with_index.map do |artifact, i|
f = Tempfile.new("v1_txt2img_#{i}.png")
f.binmode
f.write(Base64.decode64(artifact))
f.rewind
upload = UploadCreator.new(f, "ai_helper_image.png").create_for(user.id)
f.unlink
upload.short_url
end
end
private
def difussion_prompt(text)
llm_prompt = LlmPrompt.new
prompt_for_provider =
completion_prompts.find { |prompt| prompt.provider == llm_prompt.enabled_provider }
return "" if prompt_for_provider.nil?
llm_prompt.generate_and_send_prompt(prompt_for_provider, text).dig(:suggestions).first
end
def completion_prompts
[
CompletionPrompt.new(
provider: "anthropic",
prompt_type: CompletionPrompt.prompt_types[:text],
messages: [{ role: "Human", content: <<~TEXT }],
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
The post is provided between <input> tags and the Stable Diffusion prompt string should be returned between <ai> tags.
TEXT
),
CompletionPrompt.new(
provider: "openai",
prompt_type: CompletionPrompt.prompt_types[:text],
messages: [{ role: "system", content: <<~TEXT }],
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
TEXT
),
CompletionPrompt.new(
provider: "huggingface",
prompt_type: CompletionPrompt.prompt_types[:text],
messages: [<<~TEXT],
### System:
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
### User:
{{user_input}}
### Assistant:
Here is a StableDiffusion prompt:
TEXT
),
]
end
end
end
end

View File

@ -1,5 +1,7 @@
#frozen_string_literal: true #frozen_string_literal: true
require_relative "../../../../support/stable_difussion_stubs"
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
@ -13,16 +15,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
image = image =
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
stub_request( StableDiffusionStubs.new.stub_response("a pink cow", [image, image])
:post,
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
)
.with do |request|
json = JSON.parse(request.body)
expect(json["text_prompts"][0]["text"]).to eq("a pink cow")
true
end
.to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json)
image = described_class.new(bot_user: bot_user, post: post, args: nil) image = described_class.new(bot_user: bot_user, post: post, args: nil)

View File

@ -0,0 +1,53 @@
# frozen_string_literal: true
require_relative "../../../support/openai_completions_inference_stubs"
require_relative "../../../support/stable_difussion_stubs"
RSpec.describe DiscourseAi::AiHelper::Painter do
subject(:painter) { described_class.new }
fab!(:user) { Fabricate(:user) }
before do
SiteSetting.ai_stability_api_url = "https://api.stability.dev"
SiteSetting.ai_stability_api_key = "abc"
end
describe "#commission_thumbnails" do
let(:artifacts) do
%w[
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8z8BQz0AEYBxVSF+FABJADveWkH6oAAAAAElFTkSuQmCC
iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC
iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC
]
end
let(:raw_content) do
"Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
end
let(:expected_image_prompt) { <<~TEXT.strip }
Visualize a vibrant scene of an inkwell bursting, spreading colors across a blank canvas,
embodying words in tangible forms, symbolizing the rhythm and emotion evoked by poetry,
under the soft glow of a full moon.
TEXT
it "returns 4 samples" do
expected_prompt = [
{ role: "system", content: <<~TEXT },
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
TEXT
{ role: "user", content: raw_content },
]
OpenAiCompletionsInferenceStubs.stub_response(expected_prompt, expected_image_prompt)
StableDiffusionStubs.new.stub_response(expected_image_prompt, artifacts)
thumbnails = subject.commission_thumbnails(raw_content, user)
thumbnail_urls = Upload.last(4).map(&:short_url)
expect(thumbnails).to contain_exactly(*thumbnail_urls)
end
end
end

View File

@ -0,0 +1,21 @@
# frozen_string_literal: true
class StableDiffusionStubs
include RSpec::Matchers
def stub_response(prompt, images)
artifacts = images.map { |i| { base64: i } }
WebMock
.stub_request(
:post,
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
)
.with do |request|
json = JSON.parse(request.body, symbolize_names: true)
expect(json[:text_prompts][0][:text]).to eq(prompt)
true
end
.to_return(status: 200, body: { artifacts: artifacts }.to_json)
end
end