mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-09 11:48:47 +00:00
FEATURE: DALL-E-3 persona for image generation (#311)
* FIX: no selected persona should pick first prioritized one Previously we were looking at `.personaId` but there is only an id attribute so it failed * FEATURE: new DALL-E-3 persona This persona generates images using DALL-E-3 API and is enabled by default Keep in mind that we are still waiting on seeds/gen_id so we can not retain style consistently between turns. This will change as soon as a new Open AI API provides the missing parameters Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
parent
6282b6d21f
commit
dff9f33a97
@ -36,11 +36,14 @@ export default class BotSelector extends Component {
|
|||||||
super(...arguments);
|
super(...arguments);
|
||||||
|
|
||||||
if (this.botOptions && this.composer) {
|
if (this.botOptions && this.composer) {
|
||||||
const personaId = this.preferredPersonaStore.getObject("id");
|
let personaId = this.preferredPersonaStore.getObject("id");
|
||||||
|
|
||||||
|
this._value = this.botOptions[0].id;
|
||||||
if (personaId) {
|
if (personaId) {
|
||||||
this._value = parseInt(personaId, 10);
|
personaId = parseInt(personaId, 10);
|
||||||
} else {
|
if (this.botOptions.any((bot) => bot.id === personaId)) {
|
||||||
this._value = this.botOptions[0].personaId;
|
this._value = personaId;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this.composer.metaData = { ai_persona_id: this._value };
|
this.composer.metaData = { ai_persona_id: this._value };
|
||||||
|
@ -143,6 +143,9 @@ en:
|
|||||||
creative:
|
creative:
|
||||||
name: Creative
|
name: Creative
|
||||||
description: "AI Bot with no external integrations specialized in creative tasks"
|
description: "AI Bot with no external integrations specialized in creative tasks"
|
||||||
|
dall_e3:
|
||||||
|
name: "DALL-E 3"
|
||||||
|
description: "AI Bot specialized in generating images using DALL-E 3"
|
||||||
topic_not_found: "Summary unavailable, topic not found!"
|
topic_not_found: "Summary unavailable, topic not found!"
|
||||||
searching: "Searching for: '%{query}'"
|
searching: "Searching for: '%{query}'"
|
||||||
command_summary:
|
command_summary:
|
||||||
@ -157,10 +160,12 @@ en:
|
|||||||
setting_context: "Look up site setting context"
|
setting_context: "Look up site setting context"
|
||||||
schema: "Look up database schema"
|
schema: "Look up database schema"
|
||||||
search_settings: "Searching site settings"
|
search_settings: "Searching site settings"
|
||||||
|
dall_e: "Generate image"
|
||||||
command_description:
|
command_description:
|
||||||
read: "Reading: <a href='%{url}'>%{title}</a>"
|
read: "Reading: <a href='%{url}'>%{title}</a>"
|
||||||
time: "Time in %{timezone} is %{time}"
|
time: "Time in %{timezone} is %{time}"
|
||||||
summarize: "Summarized <a href='%{url}'>%{title}</a>"
|
summarize: "Summarized <a href='%{url}'>%{title}</a>"
|
||||||
|
dall_e: "%{prompt}"
|
||||||
image: "%{prompt}"
|
image: "%{prompt}"
|
||||||
categories:
|
categories:
|
||||||
one: "Found %{count} category"
|
one: "Found %{count} category"
|
||||||
|
116
lib/modules/ai_bot/commands/dall_e_command.rb
Normal file
116
lib/modules/ai_bot/commands/dall_e_command.rb
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi::AiBot::Commands
  # Bot command that renders images with DALL-E from one or more text prompts
  # and exposes the generated uploads as a `[grid]` of markdown images.
  class DallECommand < Command
    # Only the first MAX_PROMPTS prompts supplied by the model are honoured.
    MAX_PROMPTS = 4

    # Number of times a single image generation is tried before giving up.
    MAX_ATTEMPTS = 3

    class << self
      # Identifier of the command.
      def name
        "dall_e"
      end

      # Short description of what the command does.
      def desc
        "Renders images from supplied descriptions"
      end

      # The command takes a single required parameter: an array of prompt strings.
      def parameters
        [
          Parameter.new(
            name: "prompts",
            description:
              "The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
            type: "array",
            item_type: "string",
            required: true,
          ),
        ]
      end
    end

    def result_name
      "results"
    end

    # Interpolation arguments for the localized description: the first prompt
    # of the most recent #process call.
    def description_args
      { prompt: @last_prompt }
    end

    def chain_next_response
      false
    end

    # Markdown built by the last successful #process call (nil before that).
    def custom_raw
      @custom_raw
    end

    # Generates one image per prompt (in parallel threads), stores each image
    # as an upload owned by the bot user and builds the grid markdown.
    #
    # @param prompts [Array<String>] prompts to render; only the first
    #   MAX_PROMPTS are used
    # @return [Hash] `{ prompts: [...] }` holding the revised prompt of every
    #   generated image, or `{ prompts:, error: }` when nothing was generated
    def process(prompts:)
      prompts = prompts.take(MAX_PROMPTS)

      @last_prompt = prompts[0]

      show_progress(localized_description)

      # this ensures multisite safety since background threads
      # generate the images
      api_key = SiteSetting.ai_openai_api_key

      threads =
        prompts.map do |prompt|
          Thread.new(prompt) do |inner_prompt|
            attempts = 0
            begin
              DiscourseAi::Inference::OpenAiImageGenerator.perform!(inner_prompt, api_key: api_key)
            rescue => e
              attempts += 1
              retry if attempts < MAX_ATTEMPTS
              Discourse.warn_exception(
                e,
                message: "Failed to generate image for prompt #{inner_prompt}",
              )
              # a nil thread value marks this prompt as failed
              nil
            end
          end
        end

      # Emit a progress dot on every pass; Thread#join(2) returns nil while a
      # thread is still running, so we keep looping until all have finished.
      loop do
        show_progress(".", progress_caret: true)
        break if threads.all? { |t| t.join(2) }
      end

      # drop the prompts that failed (their thread value is nil)
      results = threads.filter_map(&:value)

      if results.blank?
        return { prompts: prompts, error: "Something went wrong, could not generate image" }
      end

      uploads = []

      results.each_with_index do |result, index|
        result[:data].each do |image|
          Tempfile.create("v1_txt2img_#{index}.png") do |file|
            file.binmode
            file.write(Base64.decode64(image[:b64_json]))
            file.rewind
            uploads << {
              prompt: image[:revised_prompt],
              upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
            }
          end
        end
      end

      images =
        uploads
          .map do |item|
            # Strip every pipe and quote character: any of them can break the
            # `![alt|WxH, scale](url)` image markup. `to_s` guards against a
            # missing revised prompt.
            alt_text = item[:prompt].to_s.gsub(/[|'"]/, "")
            "![#{alt_text}|512x512, 50%](#{item[:upload].short_url})"
          end
          .join(" ")

      @custom_raw = <<~RAW

        [grid]
        #{images}
        [/grid]
      RAW

      { prompts: uploads.map { |item| item[:prompt] } }
    end
  end
end
|
@ -45,6 +45,7 @@ module DiscourseAi
|
|||||||
require_relative "commands/setting_context_command"
|
require_relative "commands/setting_context_command"
|
||||||
require_relative "commands/search_settings_command"
|
require_relative "commands/search_settings_command"
|
||||||
require_relative "commands/db_schema_command"
|
require_relative "commands/db_schema_command"
|
||||||
|
require_relative "commands/dall_e_command"
|
||||||
require_relative "personas/persona"
|
require_relative "personas/persona"
|
||||||
require_relative "personas/artist"
|
require_relative "personas/artist"
|
||||||
require_relative "personas/general"
|
require_relative "personas/general"
|
||||||
@ -52,6 +53,7 @@ module DiscourseAi
|
|||||||
require_relative "personas/settings_explorer"
|
require_relative "personas/settings_explorer"
|
||||||
require_relative "personas/researcher"
|
require_relative "personas/researcher"
|
||||||
require_relative "personas/creative"
|
require_relative "personas/creative"
|
||||||
|
require_relative "personas/dall_e_3"
|
||||||
require_relative "site_settings_extension"
|
require_relative "site_settings_extension"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
35
lib/modules/ai_bot/personas/dall_e_3.rb
Normal file
35
lib/modules/ai_bot/personas/dall_e_3.rb
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
  module AiBot
    module Personas
      # Persona dedicated to image generation: it exposes only the DALL-E
      # command and marks that command as required.
      class DallE3 < Persona
        # Commands available to this persona.
        def commands
          [Commands::DallECommand]
        end

        # Commands this persona declares as required.
        def required_commands
          [Commands::DallECommand]
        end

        # Instructions given to the model on how to write image prompts.
        def system_prompt
          <<~PROMPT
            You are a bot specializing in generating images using DALL-E-3

            - A good prompt needs to be detailed and specific.
            - You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it)
            - You can specify details about lighting or time of day.
            - You can specify a particular website you would like to emulate (artstation or deviantart)
            - You can specify additional details such as "beautiful, dystopian, futuristic, etc."
            - Prompts should generally be 40-80 words long, keep in mind API only accepts a maximum of 5000 chars per prompt
            - You are extremely creative, when given short non descriptive prompts from a user you add your own details

            - When generating images, usually opt to generate 4 images unless the user specifies otherwise.
            - Be creative with your prompts, offer diverse options
            - DALL-E-3 will rewrite your prompt to be more specific and detailed, use that one iterating on images
          PROMPT
        end
      end
    end
  end
end
|
@ -11,6 +11,7 @@ module DiscourseAi
|
|||||||
Personas::SettingsExplorer => -4,
|
Personas::SettingsExplorer => -4,
|
||||||
Personas::Researcher => -5,
|
Personas::Researcher => -5,
|
||||||
Personas::Creative => -6,
|
Personas::Creative => -6,
|
||||||
|
Personas::DallE3 => -7,
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -145,6 +146,8 @@ module DiscourseAi
|
|||||||
|
|
||||||
all_commands << Commands::TagsCommand if SiteSetting.tagging_enabled
|
all_commands << Commands::TagsCommand if SiteSetting.tagging_enabled
|
||||||
all_commands << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present?
|
all_commands << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present?
|
||||||
|
|
||||||
|
all_commands << Commands::DallECommand if SiteSetting.ai_openai_api_key.present?
|
||||||
if SiteSetting.ai_google_custom_search_api_key.present? &&
|
if SiteSetting.ai_google_custom_search_api_key.present? &&
|
||||||
SiteSetting.ai_google_custom_search_cx.present?
|
SiteSetting.ai_google_custom_search_cx.present?
|
||||||
all_commands << Commands::GoogleCommand
|
all_commands << Commands::GoogleCommand
|
||||||
|
40
lib/shared/inference/openai_image_generator.rb
Normal file
40
lib/shared/inference/openai_image_generator.rb
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module ::DiscourseAi
  module Inference
    # Minimal client for the OpenAI image generation endpoint.
    class OpenAiImageGenerator
      # Seconds allowed for each of: opening the connection, writing, reading.
      TIMEOUT = 60

      # Requests a single base64-encoded image for +prompt+.
      #
      # @param prompt [String] text description of the image
      # @param model [String] image model name
      # @param size [String] image dimensions, e.g. "1024x1024"
      # @param api_key [String, nil] falls back to the site setting when nil
      # @return [Hash] parsed JSON response with symbolized keys
      # @raise [RuntimeError] when the API answers with a non-200 status
      def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil)
        api_key ||= SiteSetting.ai_openai_api_key

        endpoint = URI("https://api.openai.com/v1/images/generations")

        connection_options = {
          use_ssl: endpoint.scheme == "https",
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        }

        Net::HTTP.start(endpoint.host, endpoint.port, **connection_options) do |http|
          post =
            Net::HTTP::Post.new(
              endpoint,
              { "Content-Type" => "application/json", "Authorization" => "Bearer #{api_key}" },
            )
          post.body = {
            model: model,
            prompt: prompt,
            n: 1,
            size: size,
            response_format: "b64_json",
          }.to_json

          parsed = nil
          http.request(post) do |response|
            unless response.code.to_i == 200
              raise "OpenAI API returned #{response.code} #{response.body}"
            end

            parsed = JSON.parse(response.body, symbolize_names: true)
          end
          parsed
        end
      end
    end
  end
end
|
@ -36,6 +36,7 @@ after_initialize do
|
|||||||
require_relative "lib/shared/inference/discourse_reranker"
|
require_relative "lib/shared/inference/discourse_reranker"
|
||||||
require_relative "lib/shared/inference/openai_completions"
|
require_relative "lib/shared/inference/openai_completions"
|
||||||
require_relative "lib/shared/inference/openai_embeddings"
|
require_relative "lib/shared/inference/openai_embeddings"
|
||||||
|
require_relative "lib/shared/inference/openai_image_generator"
|
||||||
require_relative "lib/shared/inference/anthropic_completions"
|
require_relative "lib/shared/inference/anthropic_completions"
|
||||||
require_relative "lib/shared/inference/stability_generator"
|
require_relative "lib/shared/inference/stability_generator"
|
||||||
require_relative "lib/shared/inference/hugging_face_text_generation"
|
require_relative "lib/shared/inference/hugging_face_text_generation"
|
||||||
|
40
spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb
Normal file
40
spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
  let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }

  before { SiteSetting.ai_bot_enabled = true }

  describe "#process" do
    it "can generate correct info" do
      post = Fabricate(:post)

      SiteSetting.ai_openai_api_key = "abc"

      # base64 payload returned by the stubbed API for every request
      encoded_png =
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="

      api_data = [{ b64_json: encoded_png, revised_prompt: "a pink cow 1" }]
      prompts = ["a pink cow", "a red cow"]

      # every request must carry one of the prompts; all get the same response
      WebMock
        .stub_request(:post, "https://api.openai.com/v1/images/generations")
        .with do |request|
          sent = JSON.parse(request.body, symbolize_names: true)
          expect(prompts).to include(sent[:prompt])
          true
        end
        .to_return(status: 200, body: { data: api_data }.to_json)

      command = described_class.new(bot: bot, post: post, args: nil)

      info = command.process(prompts: prompts).to_json

      expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])

      raw = command.custom_raw
      expect(raw).to include("upload://")
      expect(raw).to include("[grid]")
      expect(raw).to include("a pink cow 1")
    end
  end
end
|
@ -14,6 +14,11 @@ RSpec.describe "AI personas", type: :system, js: true do
|
|||||||
visit "/"
|
visit "/"
|
||||||
find(".d-header .ai-bot-button").click()
|
find(".d-header .ai-bot-button").click()
|
||||||
persona_selector = PageObjects::Components::SelectKit.new(".persona-selector__dropdown")
|
persona_selector = PageObjects::Components::SelectKit.new(".persona-selector__dropdown")
|
||||||
|
|
||||||
|
id = DiscourseAi::AiBot::Personas.all(user: admin).first.id
|
||||||
|
|
||||||
|
expect(persona_selector).to have_selected_value(id)
|
||||||
|
|
||||||
persona_selector.expand
|
persona_selector.expand
|
||||||
persona_selector.select_row_by_value(-2)
|
persona_selector.select_row_by_value(-2)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user