FEATURE: AI Helper endpoint to generate a thumbnail from text. (#224)
We pass the text to the current LLM and ask them to generate a StableDifussion prompt. We'll use that to generate 4 samples, temporarily creating uploads and returning their short URLs.
This commit is contained in:
parent
1eb70c4f0a
commit
f57c1bb0f6
|
@ -6,6 +6,7 @@ module DiscourseAi
|
||||||
requires_plugin ::DiscourseAi::PLUGIN_NAME
|
requires_plugin ::DiscourseAi::PLUGIN_NAME
|
||||||
requires_login
|
requires_login
|
||||||
before_action :ensure_can_request_suggestions
|
before_action :ensure_can_request_suggestions
|
||||||
|
before_action :rate_limiter_performed!, except: %i[prompts]
|
||||||
|
|
||||||
def prompts
|
def prompts
|
||||||
render json:
|
render json:
|
||||||
|
@ -17,19 +18,13 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def suggest
|
def suggest
|
||||||
raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
|
input = get_text_param!
|
||||||
|
|
||||||
prompt = CompletionPrompt.find_by(id: params[:mode])
|
prompt = CompletionPrompt.find_by(id: params[:mode])
|
||||||
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
|
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
|
||||||
|
|
||||||
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
|
|
||||||
|
|
||||||
hijack do
|
hijack do
|
||||||
render json:
|
render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
|
||||||
DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
|
|
||||||
prompt,
|
|
||||||
params[:text],
|
|
||||||
),
|
|
||||||
status: 200
|
status: 200
|
||||||
end
|
end
|
||||||
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
|
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
|
||||||
|
@ -40,7 +35,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def suggest_title
|
def suggest_title
|
||||||
raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
|
input = get_text_param!
|
||||||
|
|
||||||
llm_prompt =
|
llm_prompt =
|
||||||
DiscourseAi::AiHelper::LlmPrompt
|
DiscourseAi::AiHelper::LlmPrompt
|
||||||
|
@ -50,14 +45,8 @@ module DiscourseAi
|
||||||
prompt = CompletionPrompt.find_by(id: llm_prompt[:id])
|
prompt = CompletionPrompt.find_by(id: llm_prompt[:id])
|
||||||
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
|
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
|
||||||
|
|
||||||
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
|
|
||||||
|
|
||||||
hijack do
|
hijack do
|
||||||
render json:
|
render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
|
||||||
DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
|
|
||||||
prompt,
|
|
||||||
params[:text],
|
|
||||||
),
|
|
||||||
status: 200
|
status: 200
|
||||||
end
|
end
|
||||||
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
|
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
|
||||||
|
@ -68,30 +57,39 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def suggest_category
|
def suggest_category
|
||||||
raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
|
input = get_text_param!
|
||||||
|
|
||||||
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
|
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).categories,
|
||||||
|
|
||||||
render json:
|
|
||||||
DiscourseAi::AiHelper::SemanticCategorizer.new(
|
|
||||||
params[:text],
|
|
||||||
current_user,
|
|
||||||
).categories,
|
|
||||||
status: 200
|
status: 200
|
||||||
end
|
end
|
||||||
|
|
||||||
def suggest_tags
|
def suggest_tags
|
||||||
raise Discourse::InvalidParameters.new(:text) if params[:text].blank?
|
input = get_text_param!
|
||||||
|
|
||||||
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
|
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).tags,
|
||||||
|
|
||||||
render json:
|
|
||||||
DiscourseAi::AiHelper::SemanticCategorizer.new(params[:text], current_user).tags,
|
|
||||||
status: 200
|
status: 200
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def suggest_thumbnails
|
||||||
|
input = get_text_param!
|
||||||
|
|
||||||
|
hijack do
|
||||||
|
thumbnails = DiscourseAi::AiHelper::Painter.new.commission_thumbnails(input, current_user)
|
||||||
|
|
||||||
|
render json: { thumbnails: thumbnails }, status: 200
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
|
def get_text_param!
|
||||||
|
params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? }
|
||||||
|
end
|
||||||
|
|
||||||
|
def rate_limiter_performed!
|
||||||
|
RateLimiter.new(current_user, "ai_assistant", 6, 3.minutes).performed!
|
||||||
|
end
|
||||||
|
|
||||||
def ensure_can_request_suggestions
|
def ensure_can_request_suggestions
|
||||||
user_group_ids = current_user.group_ids
|
user_group_ids = current_user.group_ids
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ DiscourseAi::Engine.routes.draw do
|
||||||
post "suggest_title" => "assistant#suggest_title"
|
post "suggest_title" => "assistant#suggest_title"
|
||||||
post "suggest_category" => "assistant#suggest_category"
|
post "suggest_category" => "assistant#suggest_category"
|
||||||
post "suggest_tags" => "assistant#suggest_tags"
|
post "suggest_tags" => "assistant#suggest_tags"
|
||||||
|
post "suggest_thumbnails" => "assistant#suggest_thumbnails"
|
||||||
end
|
end
|
||||||
|
|
||||||
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
|
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
|
||||||
|
|
|
@ -5,6 +5,7 @@ module DiscourseAi
|
||||||
def load_files
|
def load_files
|
||||||
require_relative "llm_prompt"
|
require_relative "llm_prompt"
|
||||||
require_relative "semantic_categorizer"
|
require_relative "semantic_categorizer"
|
||||||
|
require_relative "painter"
|
||||||
end
|
end
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module AiHelper
|
||||||
|
class Painter
|
||||||
|
def commission_thumbnails(theme, user)
|
||||||
|
stable_diffusion_prompt = difussion_prompt(theme)
|
||||||
|
|
||||||
|
return [] if stable_diffusion_prompt.blank?
|
||||||
|
|
||||||
|
base64_artifacts =
|
||||||
|
DiscourseAi::Inference::StabilityGenerator
|
||||||
|
.perform!(stable_diffusion_prompt)
|
||||||
|
.dig(:artifacts)
|
||||||
|
.to_a
|
||||||
|
.map { |art| art[:base64] }
|
||||||
|
|
||||||
|
base64_artifacts.each_with_index.map do |artifact, i|
|
||||||
|
f = Tempfile.new("v1_txt2img_#{i}.png")
|
||||||
|
f.binmode
|
||||||
|
f.write(Base64.decode64(artifact))
|
||||||
|
f.rewind
|
||||||
|
upload = UploadCreator.new(f, "ai_helper_image.png").create_for(user.id)
|
||||||
|
f.unlink
|
||||||
|
|
||||||
|
upload.short_url
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def difussion_prompt(text)
|
||||||
|
llm_prompt = LlmPrompt.new
|
||||||
|
prompt_for_provider =
|
||||||
|
completion_prompts.find { |prompt| prompt.provider == llm_prompt.enabled_provider }
|
||||||
|
|
||||||
|
return "" if prompt_for_provider.nil?
|
||||||
|
|
||||||
|
llm_prompt.generate_and_send_prompt(prompt_for_provider, text).dig(:suggestions).first
|
||||||
|
end
|
||||||
|
|
||||||
|
def completion_prompts
|
||||||
|
[
|
||||||
|
CompletionPrompt.new(
|
||||||
|
provider: "anthropic",
|
||||||
|
prompt_type: CompletionPrompt.prompt_types[:text],
|
||||||
|
messages: [{ role: "Human", content: <<~TEXT }],
|
||||||
|
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
|
||||||
|
The post is provided between <input> tags and the Stable Diffusion prompt string should be returned between <ai> tags.
|
||||||
|
TEXT
|
||||||
|
),
|
||||||
|
CompletionPrompt.new(
|
||||||
|
provider: "openai",
|
||||||
|
prompt_type: CompletionPrompt.prompt_types[:text],
|
||||||
|
messages: [{ role: "system", content: <<~TEXT }],
|
||||||
|
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
|
||||||
|
TEXT
|
||||||
|
),
|
||||||
|
CompletionPrompt.new(
|
||||||
|
provider: "huggingface",
|
||||||
|
prompt_type: CompletionPrompt.prompt_types[:text],
|
||||||
|
messages: [<<~TEXT],
|
||||||
|
### System:
|
||||||
|
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
|
||||||
|
|
||||||
|
### User:
|
||||||
|
{{user_input}}
|
||||||
|
|
||||||
|
### Assistant:
|
||||||
|
Here is a StableDiffusion prompt:
|
||||||
|
TEXT
|
||||||
|
),
|
||||||
|
]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,5 +1,7 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "../../../../support/stable_difussion_stubs"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
||||||
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||||
|
|
||||||
|
@ -13,16 +15,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
||||||
image =
|
image =
|
||||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||||
|
|
||||||
stub_request(
|
StableDiffusionStubs.new.stub_response("a pink cow", [image, image])
|
||||||
:post,
|
|
||||||
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
|
|
||||||
)
|
|
||||||
.with do |request|
|
|
||||||
json = JSON.parse(request.body)
|
|
||||||
expect(json["text_prompts"][0]["text"]).to eq("a pink cow")
|
|
||||||
true
|
|
||||||
end
|
|
||||||
.to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json)
|
|
||||||
|
|
||||||
image = described_class.new(bot_user: bot_user, post: post, args: nil)
|
image = described_class.new(bot_user: bot_user, post: post, args: nil)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "../../../support/openai_completions_inference_stubs"
|
||||||
|
require_relative "../../../support/stable_difussion_stubs"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::AiHelper::Painter do
|
||||||
|
subject(:painter) { described_class.new }
|
||||||
|
|
||||||
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.ai_stability_api_url = "https://api.stability.dev"
|
||||||
|
SiteSetting.ai_stability_api_key = "abc"
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#commission_thumbnails" do
|
||||||
|
let(:artifacts) do
|
||||||
|
%w[
|
||||||
|
iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==
|
||||||
|
iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8z8BQz0AEYBxVSF+FABJADveWkH6oAAAAAElFTkSuQmCC
|
||||||
|
iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC
|
||||||
|
iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:raw_content) do
|
||||||
|
"Poetry is a form of artistic expression that uses language aesthetically and rhythmically to evoke emotions and ideas."
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:expected_image_prompt) { <<~TEXT.strip }
|
||||||
|
Visualize a vibrant scene of an inkwell bursting, spreading colors across a blank canvas,
|
||||||
|
embodying words in tangible forms, symbolizing the rhythm and emotion evoked by poetry,
|
||||||
|
under the soft glow of a full moon.
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
it "returns 4 samples" do
|
||||||
|
expected_prompt = [
|
||||||
|
{ role: "system", content: <<~TEXT },
|
||||||
|
Provide me a StableDiffusion prompt to generate an image that illustrates the following post in 40 words or less, be creative.
|
||||||
|
TEXT
|
||||||
|
{ role: "user", content: raw_content },
|
||||||
|
]
|
||||||
|
|
||||||
|
OpenAiCompletionsInferenceStubs.stub_response(expected_prompt, expected_image_prompt)
|
||||||
|
StableDiffusionStubs.new.stub_response(expected_image_prompt, artifacts)
|
||||||
|
|
||||||
|
thumbnails = subject.commission_thumbnails(raw_content, user)
|
||||||
|
thumbnail_urls = Upload.last(4).map(&:short_url)
|
||||||
|
|
||||||
|
expect(thumbnails).to contain_exactly(*thumbnail_urls)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,21 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class StableDiffusionStubs
|
||||||
|
include RSpec::Matchers
|
||||||
|
|
||||||
|
def stub_response(prompt, images)
|
||||||
|
artifacts = images.map { |i| { base64: i } }
|
||||||
|
|
||||||
|
WebMock
|
||||||
|
.stub_request(
|
||||||
|
:post,
|
||||||
|
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
|
||||||
|
)
|
||||||
|
.with do |request|
|
||||||
|
json = JSON.parse(request.body, symbolize_names: true)
|
||||||
|
expect(json[:text_prompts][0][:text]).to eq(prompt)
|
||||||
|
true
|
||||||
|
end
|
||||||
|
.to_return(status: 200, body: { artifacts: artifacts }.to_json)
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue