From dff9f33a97d9a8707dc432d31df922ee416ac4d1 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 24 Nov 2023 18:08:08 +1100 Subject: [PATCH] FEATURE: DALL-E-3 persona for image generation (#311) * FIX: no selected persona should pick first prioritized one Previously we were looking at `.personaId` but there is only an id attribute so it failed * FEATURE: new DALL-E-3 persona This persona generates images using DALL-E-3 API and is enabled by default Keep in mind that we are still waiting on seeds/gen_id so we can not retain style consistently between turns. This will change as soon as a new Open AI API provides the missing parameters Co-authored-by: Martin Brennan --- .../composer-fields/persona-selector.gjs | 11 +- config/locales/server.en.yml | 5 + lib/modules/ai_bot/commands/dall_e_command.rb | 116 ++++++++++++++++++ lib/modules/ai_bot/entry_point.rb | 2 + lib/modules/ai_bot/personas/dall_e_3.rb | 35 ++++++ lib/modules/ai_bot/personas/persona.rb | 3 + .../inference/openai_image_generator.rb | 40 ++++++ plugin.rb | 1 + .../ai_bot/commands/dall_e_command_spec.rb | 40 ++++++ spec/system/ai_bot/persona_spec.rb | 5 + 10 files changed, 254 insertions(+), 4 deletions(-) create mode 100644 lib/modules/ai_bot/commands/dall_e_command.rb create mode 100644 lib/modules/ai_bot/personas/dall_e_3.rb create mode 100644 lib/shared/inference/openai_image_generator.rb create mode 100644 spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb diff --git a/assets/javascripts/discourse/connectors/composer-fields/persona-selector.gjs b/assets/javascripts/discourse/connectors/composer-fields/persona-selector.gjs index e15922b0..0209aa81 100644 --- a/assets/javascripts/discourse/connectors/composer-fields/persona-selector.gjs +++ b/assets/javascripts/discourse/connectors/composer-fields/persona-selector.gjs @@ -36,11 +36,14 @@ export default class BotSelector extends Component { super(...arguments); if (this.botOptions && this.composer) { - const personaId = 
# frozen_string_literal: true

module DiscourseAi::AiBot::Commands
  # Bot command that renders images from text prompts.
  #
  # Each prompt (at most MAX_PROMPTS) is sent to
  # DiscourseAi::Inference::OpenAiImageGenerator on its own thread. Every
  # returned image is stored as an upload owned by the bot user, and the
  # uploads are exposed as a [grid] of image markdown through #custom_raw.
  class DallECommand < Command
    # Prompts beyond this count are silently dropped.
    MAX_PROMPTS = 4
    # Number of times a single prompt is attempted before giving up on it.
    MAX_ATTEMPTS = 3
    # Characters removed from the alt text of the generated image markdown.
    ALT_TEXT_UNSAFE = /[|'"]/

    class << self
      def name
        "dall_e"
      end

      def desc
        "Renders images from supplied descriptions"
      end

      def parameters
        [
          Parameter.new(
            name: "prompts",
            description:
              "The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
            type: "array",
            item_type: "string",
            required: true,
          ),
        ]
      end
    end

    # Markdown produced by the last #process call (nil until images were stored).
    attr_reader :custom_raw

    def result_name
      "results"
    end

    def description_args
      { prompt: @last_prompt }
    end

    def chain_next_response
      false
    end

    # Generates one image per prompt and stores the results as uploads.
    #
    # @param prompts [Array<String>] image descriptions; only the first
    #   MAX_PROMPTS are used
    # @return [Hash] `{ prompts: [...] }` holding the revised prompt of every
    #   stored image, or `{ prompts:, error: }` when no image could be generated
    def process(prompts:)
      prompts = prompts.take(MAX_PROMPTS)
      @last_prompt = prompts.first

      show_progress(localized_description)

      # this ensures multisite safety since background threads
      # generate the images
      api_key = SiteSetting.ai_openai_api_key

      threads =
        prompts.map do |prompt|
          Thread.new(prompt) { |inner_prompt| generate_with_retries(inner_prompt, api_key) }
        end

      # Emit a progress tick, then wait up to 2 seconds per unfinished thread.
      loop do
        show_progress(".", progress_caret: true)
        break if threads.all? { |thread| thread.join(2) }
      end

      # Threads whose prompt failed on every attempt evaluate to nil.
      results = threads.filter_map(&:value)

      if results.blank?
        return { prompts: prompts, error: "Something went wrong, could not generate image" }
      end

      uploads = store_images(results)

      @custom_raw = <<~RAW

        [grid]
        #{uploads.map { |item| image_markdown(item) }.join(" ")}
        [/grid]
      RAW

      { prompts: uploads.map { |item| item[:prompt] } }
    end

    private

    # Calls the image API for one prompt, retrying up to MAX_ATTEMPTS times.
    # Returns the parsed API response, or nil (after logging a warning) when
    # every attempt raised.
    def generate_with_retries(prompt, api_key)
      attempts = 0
      begin
        DiscourseAi::Inference::OpenAiImageGenerator.perform!(prompt, api_key: api_key)
      rescue => e
        attempts += 1
        retry if attempts < MAX_ATTEMPTS
        Discourse.warn_exception(e, message: "Failed to generate image for prompt #{prompt}")
        nil
      end
    end

    # Decodes every base64 image in the API responses into a temp file and
    # creates an upload for it on behalf of the bot user.
    # Returns an array of `{ prompt:, upload: }` hashes.
    def store_images(results)
      results.each_with_index.flat_map do |result, index|
        result[:data].map do |image|
          Tempfile.create("v1_txt2img_#{index}.png") do |file|
            file.binmode
            file.write(Base64.decode64(image[:b64_json]))
            file.rewind
            {
              prompt: image[:revised_prompt],
              upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
            }
          end
        end
      end
    end

    # Builds the image markdown for one stored image. `|`, `'` and `"` are
    # stripped from the alt text individually (the previous pattern only
    # matched the three characters in sequence), and a missing revised prompt
    # yields an empty alt text instead of raising.
    def image_markdown(item)
      alt_text = item[:prompt].to_s.gsub(ALT_TEXT_UNSAFE, "")
      "![#{alt_text}|512x512, 50%](#{item[:upload].short_url})"
    end
  end
end
# frozen_string_literal: true

module DiscourseAi
  module AiBot
    module Personas
      # Persona whose only available (and required) command is
      # Commands::DallECommand, with a system prompt that coaches the model on
      # writing image-generation prompts.
      class DallE3 < Persona
        def commands
          [Commands::DallECommand]
        end

        def required_commands
          [Commands::DallECommand]
        end

        # Instructions sent to the language model for this persona.
        def system_prompt
          <<~PROMPT
            You are a bot specializing in generating images using DALL-E-3

            - A good prompt needs to be detailed and specific.
            - You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it)
            - You can specify details about lighting or time of day.
            - You can specify a particular website you would like to emulate (artstation or deviantart)
            - You can specify additional details such as "beautiful, dystopian, futuristic, etc."
            - Prompts should generally be 40-80 words long, keep in mind API only accepts a maximum of 5000 chars per prompt
            - You are extremely creative, when given short non descriptive prompts from a user you add your own details

            - When generating images, usually opt to generate 4 images unless the user specifies otherwise.
            - Be creative with your prompts, offer diverse options
            - DALL-E-3 will rewrite your prompt to be more specific and detailed, use that one when iterating on images
          PROMPT
        end
      end
    end
  end
end
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    # Thin client for the OpenAI image generation endpoint.
    class OpenAiImageGenerator
      # Seconds allowed for opening the connection, writing and reading.
      TIMEOUT = 60

      # Requests a single base64-encoded image for +prompt+.
      #
      # Falls back to SiteSetting.ai_openai_api_key when no +api_key+ is given.
      # Returns the response body parsed as JSON with symbolized keys and
      # raises a RuntimeError for any non-200 response.
      def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil)
        api_key ||= SiteSetting.ai_openai_api_key

        endpoint = URI("https://api.openai.com/v1/images/generations")

        request =
          Net::HTTP::Post.new(
            endpoint,
            { "Content-Type" => "application/json", "Authorization" => "Bearer #{api_key}" },
          )
        request.body =
          { model: model, prompt: prompt, n: 1, size: size, response_format: "b64_json" }.to_json

        connection_options = {
          use_ssl: endpoint.scheme == "https",
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        }

        Net::HTTP.start(endpoint.host, endpoint.port, **connection_options) do |http|
          response = http.request(request)

          unless response.code.to_i == 200
            raise "OpenAI API returned #{response.code} #{response.body}"
          end

          JSON.parse(response.body, symbolize_names: true)
        end
      end
    end
  end
end
# frozen_string_literal: true

RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
  let(:bot) { DiscourseAi::AiBot::OpenAiBot.new(bot_user) }

  before { SiteSetting.ai_bot_enabled = true }

  describe "#process" do
    it "can generate correct info" do
      post = Fabricate(:post)

      SiteSetting.ai_openai_api_key = "abc"

      # base64 image payload returned by the stubbed API for every request
      png_base64 =
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="

      api_images = [{ b64_json: png_base64, revised_prompt: "a pink cow 1" }]
      prompts = ["a pink cow", "a red cow"]

      # Every request must carry one of the supplied prompts; both receive the
      # same canned response.
      WebMock
        .stub_request(:post, "https://api.openai.com/v1/images/generations")
        .with do |request|
          payload = JSON.parse(request.body, symbolize_names: true)
          expect(prompts).to include(payload[:prompt])
          true
        end
        .to_return(status: 200, body: { data: api_images }.to_json)

      command = described_class.new(bot: bot, post: post, args: nil)

      serialized = command.process(prompts: prompts).to_json

      expect(JSON.parse(serialized)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])

      raw = command.custom_raw
      expect(raw).to include("upload://")
      expect(raw).to include("[grid]")
      expect(raw).to include("a pink cow 1")
    end
  end
end