mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-09 11:48:47 +00:00
FEATURE: DALL-E-3 persona for image generation (#311)
* FIX: no selected persona should pick first prioritized one Previously we were looking at `.personaId` but there is only an id attribute so it failed * FEATURE: new DALL-E-3 persona This persona generates images using DALL-E-3 API and is enabled by default Keep in mind that we are still waiting on seeds/gen_id so we can not retain style consistently between turns. This will change as soon as a new Open AI API provides the missing parameters Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
parent
6282b6d21f
commit
dff9f33a97
@ -36,11 +36,14 @@ export default class BotSelector extends Component {
|
||||
super(...arguments);
|
||||
|
||||
if (this.botOptions && this.composer) {
|
||||
const personaId = this.preferredPersonaStore.getObject("id");
|
||||
let personaId = this.preferredPersonaStore.getObject("id");
|
||||
|
||||
this._value = this.botOptions[0].id;
|
||||
if (personaId) {
|
||||
this._value = parseInt(personaId, 10);
|
||||
} else {
|
||||
this._value = this.botOptions[0].personaId;
|
||||
personaId = parseInt(personaId, 10);
|
||||
if (this.botOptions.any((bot) => bot.id === personaId)) {
|
||||
this._value = personaId;
|
||||
}
|
||||
}
|
||||
|
||||
this.composer.metaData = { ai_persona_id: this._value };
|
||||
|
@ -143,6 +143,9 @@ en:
|
||||
creative:
|
||||
name: Creative
|
||||
description: "AI Bot with no external integrations specialized in creative tasks"
|
||||
dall_e3:
|
||||
name: "DALL-E 3"
|
||||
description: "AI Bot specialized in generating images using DALL-E 3"
|
||||
topic_not_found: "Summary unavailable, topic not found!"
|
||||
searching: "Searching for: '%{query}'"
|
||||
command_summary:
|
||||
@ -157,10 +160,12 @@ en:
|
||||
setting_context: "Look up site setting context"
|
||||
schema: "Look up database schema"
|
||||
search_settings: "Searching site settings"
|
||||
dall_e: "Generate image"
|
||||
command_description:
|
||||
read: "Reading: <a href='%{url}'>%{title}</a>"
|
||||
time: "Time in %{timezone} is %{time}"
|
||||
summarize: "Summarized <a href='%{url}'>%{title}</a>"
|
||||
dall_e: "%{prompt}"
|
||||
image: "%{prompt}"
|
||||
categories:
|
||||
one: "Found %{count} category"
|
||||
|
116
lib/modules/ai_bot/commands/dall_e_command.rb
Normal file
116
lib/modules/ai_bot/commands/dall_e_command.rb
Normal file
@ -0,0 +1,116 @@
|
||||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi::AiBot::Commands
  # Bot command that turns text prompts into images through
  # DiscourseAi::Inference::OpenAiImageGenerator and exposes the resulting
  # uploads as a "[grid]" block of markdown images via #custom_raw.
  class DallECommand < Command
    class << self
      # Identifier of this command.
      def name
        "dall_e"
      end

      # Short description of what the command does.
      def desc
        "Renders images from supplied descriptions"
      end

      # The command takes a single required array-of-strings parameter.
      def parameters
        [
          Parameter.new(
            name: "prompts",
            description:
              "The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
            type: "array",
            item_type: "string",
            required: true,
          ),
        ]
      end
    end

    def result_name
      "results"
    end

    # Interpolation arguments for the localized description; only the first
    # prompt of the last #process call is exposed.
    def description_args
      { prompt: @last_prompt }
    end

    def chain_next_response
      false
    end

    # Markdown built by the last successful #process call (nil before that).
    def custom_raw
      @custom_raw
    end

    # Generates one image per prompt (at most 4 prompts are used), stores every
    # returned image as an upload owned by the bot user and builds #custom_raw.
    #
    # @param prompts [Array<String>] image descriptions
    # @return [Hash] `{ prompts: [...] }` holding the revised prompts reported
    #   by the API, or `{ prompts:, error: }` when no image could be generated
    def process(prompts:)
      # max 4 prompts
      prompts = prompts.take(4)

      @last_prompt = prompts.first

      show_progress(localized_description)

      results = generate_images(prompts)

      if results.blank?
        return { prompts: prompts, error: "Something went wrong, could not generate image" }
      end

      uploads = create_uploads(results)
      images_markdown = uploads.map { |item| image_markdown(item) }.join(" ")

      @custom_raw = <<~RAW

        [grid]
        #{images_markdown}
        [/grid]
      RAW

      { prompts: uploads.map { |item| item[:prompt] } }
    end

    private

    # Requests one image per prompt, each on its own thread, emitting a
    # progress caret until every thread has finished.
    # Returns the non-nil API responses; failed prompts are dropped.
    def generate_images(prompts)
      # this ensures multisite safety since background threads
      # generate the images
      api_key = SiteSetting.ai_openai_api_key

      threads =
        prompts.map do |prompt|
          Thread.new(prompt) do |inner_prompt|
            attempts = 0
            begin
              DiscourseAi::Inference::OpenAiImageGenerator.perform!(inner_prompt, api_key: api_key)
            rescue => e
              attempts += 1
              # bounded retry: 3 attempts in total, then log and give up on this prompt
              retry if attempts < 3
              Discourse.warn_exception(
                e,
                message: "Failed to generate image for prompt #{inner_prompt}",
              )
              nil
            end
          end
        end

      loop do
        show_progress(".", progress_caret: true)
        # Thread#join(2) returns nil on timeout, so this only breaks once all are done
        break if threads.all? { |t| t.join(2) }
      end

      threads.filter_map(&:value)
    end

    # Decodes every base64 image of every API response into a temp file and
    # creates an upload from it.
    # Returns an array of `{ prompt:, upload: }` hashes.
    def create_uploads(results)
      results.each_with_index.flat_map do |result, index|
        result[:data].map do |image|
          Tempfile.create("v1_txt2img_#{index}.png") do |file|
            file.binmode
            file.write(Base64.decode64(image[:b64_json]))
            file.rewind
            {
              prompt: image[:revised_prompt],
              upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
            }
          end
        end
      end
    end

    # Builds the markdown image tag for one upload. Each of the characters
    # | ' " is stripped from the alt text individually (a character class, not
    # the literal sequence |'"), since "|" separates the alt text from the
    # "512x512, 50%" size suffix.
    def image_markdown(item)
      alt_text = item[:prompt].to_s.gsub(/[|'"]/, "")
      "![#{alt_text}|512x512, 50%](#{item[:upload].short_url})"
    end
  end
end
|
@ -45,6 +45,7 @@ module DiscourseAi
|
||||
require_relative "commands/setting_context_command"
|
||||
require_relative "commands/search_settings_command"
|
||||
require_relative "commands/db_schema_command"
|
||||
require_relative "commands/dall_e_command"
|
||||
require_relative "personas/persona"
|
||||
require_relative "personas/artist"
|
||||
require_relative "personas/general"
|
||||
@ -52,6 +53,7 @@ module DiscourseAi
|
||||
require_relative "personas/settings_explorer"
|
||||
require_relative "personas/researcher"
|
||||
require_relative "personas/creative"
|
||||
require_relative "personas/dall_e_3"
|
||||
require_relative "site_settings_extension"
|
||||
end
|
||||
|
||||
|
35
lib/modules/ai_bot/personas/dall_e_3.rb
Normal file
35
lib/modules/ai_bot/personas/dall_e_3.rb
Normal file
@ -0,0 +1,35 @@
|
||||
#frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
  module AiBot
    module Personas
      # Persona dedicated to image generation: the only command it exposes,
      # and the only one it requires, is Commands::DallECommand.
      class DallE3 < Persona
        # Commands that must be available for this persona.
        def required_commands
          [Commands::DallECommand]
        end

        # Commands this persona makes available.
        def commands
          [Commands::DallECommand]
        end

        # Instructions given to the model on how to write image prompts.
        def system_prompt
          <<~PROMPT
            You are a bot specializing in generating images using DALL-E-3

            - A good prompt needs to be detailed and specific.
            - You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it)
            - You can specify details about lighting or time of day.
            - You can specify a particular website you would like to emulate (artstation or deviantart)
            - You can specify additional details such as "beutiful, dystopian, futuristic, etc."
            - Prompts should generally be 40-80 words long, keep in mind API only accepts a maximum of 5000 chars per prompt
            - You are extremely creative, when given short non descriptive prompts from a user you add your own details

            - When generating images, usually opt to generate 4 images unless the user specifies otherwise.
            - Be creative with your prompts, offer diverse options
            - DALL-E-3 will rewrite your prompt to be more specific and detailed, use that one iterating on images
          PROMPT
        end
      end
    end
  end
end
|
@ -11,6 +11,7 @@ module DiscourseAi
|
||||
Personas::SettingsExplorer => -4,
|
||||
Personas::Researcher => -5,
|
||||
Personas::Creative => -6,
|
||||
Personas::DallE3 => -7,
|
||||
}
|
||||
end
|
||||
|
||||
@ -145,6 +146,8 @@ module DiscourseAi
|
||||
|
||||
all_commands << Commands::TagsCommand if SiteSetting.tagging_enabled
|
||||
all_commands << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present?
|
||||
|
||||
all_commands << Commands::DallECommand if SiteSetting.ai_openai_api_key.present?
|
||||
if SiteSetting.ai_google_custom_search_api_key.present? &&
|
||||
SiteSetting.ai_google_custom_search_cx.present?
|
||||
all_commands << Commands::GoogleCommand
|
||||
|
40
lib/shared/inference/openai_image_generator.rb
Normal file
40
lib/shared/inference/openai_image_generator.rb
Normal file
@ -0,0 +1,40 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAi
  module Inference
    # Minimal client for the OpenAI image generation endpoint.
    class OpenAiImageGenerator
      # Seconds allowed for each of opening the connection, writing and reading.
      TIMEOUT = 60

      # Requests a single base64-encoded image for +prompt+.
      #
      # @param prompt [String] image description sent to the API
      # @param model [String] model name sent in the payload
      # @param size [String] image dimensions sent in the payload
      # @param api_key [String, nil] bearer token; falls back to
      #   SiteSetting.ai_openai_api_key when nil
      # @return [Hash] the parsed JSON response with symbolized keys
      # @raise [RuntimeError] when the API answers with a non-200 status
      def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil)
        api_key ||= SiteSetting.ai_openai_api_key

        endpoint = URI("https://api.openai.com/v1/images/generations")

        connection_options = {
          use_ssl: endpoint.scheme == "https",
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        }

        Net::HTTP.start(endpoint.host, endpoint.port, **connection_options) do |http|
          request =
            Net::HTTP::Post.new(
              endpoint,
              { "Content-Type" => "application/json", "Authorization" => "Bearer #{api_key}" },
            )
          request.body = {
            model: model,
            prompt: prompt,
            n: 1,
            size: size,
            response_format: "b64_json",
          }.to_json

          parsed = nil
          http.request(request) do |response|
            unless response.code.to_i == 200
              raise "OpenAI API returned #{response.code} #{response.body}"
            end

            parsed = JSON.parse(response.body, symbolize_names: true)
          end
          parsed
        end
      end
    end
  end
end
|
@ -36,6 +36,7 @@ after_initialize do
|
||||
require_relative "lib/shared/inference/discourse_reranker"
|
||||
require_relative "lib/shared/inference/openai_completions"
|
||||
require_relative "lib/shared/inference/openai_embeddings"
|
||||
require_relative "lib/shared/inference/openai_image_generator"
|
||||
require_relative "lib/shared/inference/anthropic_completions"
|
||||
require_relative "lib/shared/inference/stability_generator"
|
||||
require_relative "lib/shared/inference/hugging_face_text_generation"
|
||||
|
40
spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb
Normal file
40
spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb
Normal file
@ -0,0 +1,40 @@
|
||||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
  let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }

  before { SiteSetting.ai_bot_enabled = true }

  describe "#process" do
    it "can generate correct info" do
      post = Fabricate(:post)

      SiteSetting.ai_openai_api_key = "abc"

      # base64 payload returned by the stubbed API for every request
      png_base64 =
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="

      api_data = [{ b64_json: png_base64, revised_prompt: "a pink cow 1" }]
      prompts = ["a pink cow", "a red cow"]

      # every request must carry one of the supplied prompts; both requests
      # receive the same canned response
      WebMock
        .stub_request(:post, "https://api.openai.com/v1/images/generations")
        .with do |request|
          body = JSON.parse(request.body, symbolize_names: true)
          expect(prompts).to include(body[:prompt])
          true
        end
        .to_return(status: 200, body: { data: api_data }.to_json)

      command = described_class.new(bot: bot, post: post, args: nil)

      info = command.process(prompts: prompts).to_json

      expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])

      raw = command.custom_raw
      expect(raw).to include("upload://")
      expect(raw).to include("[grid]")
      expect(raw).to include("a pink cow 1")
    end
  end
end
|
@ -14,6 +14,11 @@ RSpec.describe "AI personas", type: :system, js: true do
|
||||
visit "/"
|
||||
find(".d-header .ai-bot-button").click()
|
||||
persona_selector = PageObjects::Components::SelectKit.new(".persona-selector__dropdown")
|
||||
|
||||
id = DiscourseAi::AiBot::Personas.all(user: admin).first.id
|
||||
|
||||
expect(persona_selector).to have_selected_value(id)
|
||||
|
||||
persona_selector.expand
|
||||
persona_selector.select_row_by_value(-2)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user