From 5a4598a7b468c92150f0765f4aefcaa6aec99276 Mon Sep 17 00:00:00 2001
From: Sam
Date: Mon, 27 Nov 2023 13:01:05 +1100
Subject: [PATCH] FEATURE: Azure OpenAI support for DALL-E 3 (#313)

* FEATURE: Azure OpenAI support for DALL-E 3

Previously there was no way to add an inference endpoint for DALL-E on
Azure because it requires custom URLs.

Also:

- On save, editing a persona would revert its priority and enabled flags
- More forgiving parsing in the command framework for array function calls
- Generate HD images by default - they tend to be a bit better
- Improve the DALL-E prompt, which was getting annoying by always echoing what it was about to do
- Add a short sleep between retries on image generation
- Fix error handling in image_command
---
 .../components/ai-persona-editor.gjs          |  2 ++
 config/locales/server.en.yml                  |  1 +
 config/settings.yml                           |  1 +
 lib/modules/ai_bot/commands/dall_e_command.rb |  8 ++++-
 lib/modules/ai_bot/commands/image_command.rb  |  2 +-
 lib/modules/ai_bot/personas/dall_e_3.rb       | 26 ++++++++-------
 lib/shared/inference/function_list.rb         |  7 +++-
 .../inference/openai_image_generator.rb       | 31 ++++++++++++-----
 .../ai_bot/commands/dall_e_command_spec.rb    | 33 +++++++++++++++++++
 9 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs
index 05960f28..47651ed2 100644
--- a/assets/javascripts/discourse/components/ai-persona-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs
@@ -96,6 +96,7 @@ export default class PersonaEditor extends Component {
   @action
   async toggleEnabled() {
     this.args.model.set("enabled", !this.args.model.enabled);
+    this.editingModel.set("enabled", this.args.model.enabled);
     if (!this.args.model.isNew) {
       try {
         await this.args.model.update({ enabled: this.args.model.enabled });
@@ -108,6 +109,7 @@ export default class PersonaEditor extends Component {
   @action
   async togglePriority() {
     this.args.model.set("priority", !this.args.model.priority);
+    this.editingModel.set("priority", this.args.model.priority);
     if (!this.args.model.isNew) {
       try {
         await this.args.model.update({ priority: this.args.model.priority });
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 116b17c1..f7357ff4 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -41,6 +41,7 @@ en:
       ai_openai_gpt35_16k_url: "Custom URL used for GPT 3.5 16k chat completions. (for Azure support)"
       ai_openai_gpt4_url: "Custom URL used for GPT 4 chat completions. (for Azure support)"
       ai_openai_gpt4_32k_url: "Custom URL used for GPT 4 32k chat completions. (for Azure support)"
+      ai_openai_dall_e_3_url: "Custom URL used for DALL-E 3 image generation. (for Azure support)"
       ai_openai_organization: "(Optional, leave empty to omit) Organization id used for the OpenAI API. Passed in using the OpenAI-Organization header."
       ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
       ai_openai_api_key: "API key for OpenAI API"
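By analogy with the embeddings URL documented above, the new ai_openai_dall_e_3_url setting would presumably point at a deployment-scoped images/generations endpoint when used with Azure. A minimal sketch from a Rails console; COMPANY, DEPLOYMENT and the api-version are placeholders and assumptions, not values taken from this patch:

    # hypothetical Azure OpenAI endpoint for DALL-E 3; adjust resource, deployment and api-version
    SiteSetting.ai_openai_api_key = "YOUR_AZURE_KEY"
    SiteSetting.ai_openai_dall_e_3_url =
      "https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/images/generations?api-version=2023-12-01-preview"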
diff --git a/config/settings.yml b/config/settings.yml
index 49b43be1..3b78e2d9 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -92,6 +92,7 @@ discourse_ai:
   ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
   ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
   ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
+  ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
   ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
   ai_openai_organization: ""
   ai_openai_api_key:
diff --git a/lib/modules/ai_bot/commands/dall_e_command.rb b/lib/modules/ai_bot/commands/dall_e_command.rb
index a6ad8f2c..8d92398a 100644
--- a/lib/modules/ai_bot/commands/dall_e_command.rb
+++ b/lib/modules/ai_bot/commands/dall_e_command.rb
@@ -54,15 +54,21 @@ module DiscourseAi::AiBot::Commands
       # this ensures multisite safety since background threads
       # generate the images
       api_key = SiteSetting.ai_openai_api_key
+      api_url = SiteSetting.ai_openai_dall_e_3_url

       threads = []
       prompts.each_with_index do |prompt, index|
         threads << Thread.new(prompt) do |inner_prompt|
           attempts = 0
           begin
-            DiscourseAi::Inference::OpenAiImageGenerator.perform!(inner_prompt, api_key: api_key)
+            DiscourseAi::Inference::OpenAiImageGenerator.perform!(
+              inner_prompt,
+              api_key: api_key,
+              api_url: api_url,
+            )
           rescue => e
             attempts += 1
+            sleep 2
             retry if attempts < 3
             Discourse.warn_exception(e, message: "Failed to generate image for prompt #{prompt}")
             nil
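In isolation, the per-prompt thread plus retry loop added above behaves like the following standalone sketch; fake_generate is an illustrative stand-in for the real API call, and the 2 second sleep mirrors the patch:

    # standalone sketch of the thread + retry pattern used in dall_e_command.rb
    def fake_generate(prompt)
      raise "transient API error" if rand < 0.3
      "image for #{prompt}"
    end

    threads =
      ["a pink cow", "a red cow"].map do |prompt|
        Thread.new(prompt) do |inner_prompt|
          attempts = 0
          begin
            fake_generate(inner_prompt)
          rescue => e
            attempts += 1
            sleep 2
            retry if attempts < 3
            nil # give up after three attempts and drop this prompt
          end
        end
      end

    # Thread#value joins each thread; failed prompts come back as nil and are dropped
    results = threads.map(&:value).compact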
diff --git a/lib/modules/ai_bot/commands/image_command.rb b/lib/modules/ai_bot/commands/image_command.rb
index a170cef8..cedf9208 100644
--- a/lib/modules/ai_bot/commands/image_command.rb
+++ b/lib/modules/ai_bot/commands/image_command.rb
@@ -96,7 +96,7 @@ module DiscourseAi::AiBot::Commands
       results = threads.map(&:value).compact

       if !results.present?
-        return { prompt: prompt, error: "Something went wrong, could not generate image" }
+        return { prompts: prompts, error: "Something went wrong, could not generate image" }
       end

       uploads = []
diff --git a/lib/modules/ai_bot/personas/dall_e_3.rb b/lib/modules/ai_bot/personas/dall_e_3.rb
index d490c5d0..75fb035f 100644
--- a/lib/modules/ai_bot/personas/dall_e_3.rb
+++ b/lib/modules/ai_bot/personas/dall_e_3.rb
@@ -14,19 +14,23 @@ module DiscourseAi
       def system_prompt
         <<~PROMPT
-          You are a bot specializing in generating images using DALL-E-3
+          As a DALL-E-3 bot, you're tasked with generating images based on user prompts.

-          - A good prompt needs to be detailed and specific.
-          - You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it)
-          - You can specify details about lighting or time of day.
-          - You can specify a particular website you would like to emulate (artstation or deviantart)
-          - You can specify additional details such as "beutiful, dystopian, futuristic, etc."
-          - Prompts should generally be 40-80 words long, keep in mind API only accepts a maximum of 5000 chars per prompt
-          - You are extremely creative, when given short non descriptive prompts from a user you add your own details
+          - Be specific and detailed in your prompts. Include elements like subject, medium (e.g., oil on canvas), artist style, lighting, time of day, and website style (e.g., ArtStation, DeviantArt).
+          - Add adjectives for more detail (e.g., beautiful, dystopian, futuristic).
+          - Prompts should be 40-100 words long, but remember the API accepts a maximum of 5000 characters per prompt.
+          - Enhance short, vague user prompts with your own creative details.
+          - Unless specified, generate 4 images per prompt.
+          - Don't seek user permission before generating images or run the prompts by the user. Generate immediately to save tokens.
+
+          Example:
+
+          User: "a cow"
+          You: Generate images immediately, without telling the user anything. Details will be provided to user with the generated images.
+
+          DO NOT SAY "I will generate the following ... image 1 description ... image 2 description ... etc."
+          Just generate the images
-          - When generating images, usually opt to generate 4 images unless the user specifies otherwise.
-          - Be creative with your prompts, offer diverse options
-          - DALL-E-3 will rewrite your prompt to be more specific and detailed, use that one iterating on images
         PROMPT
       end
     end
diff --git a/lib/shared/inference/function_list.rb b/lib/shared/inference/function_list.rb
index 71b77efa..713565a7 100644
--- a/lib/shared/inference/function_list.rb
+++ b/lib/shared/inference/function_list.rb
@@ -50,7 +50,12 @@ module ::DiscourseAi
         type = parameter[:type]
         if type == "array"
-          arguments[name] = JSON.parse(value)
+          begin
+            arguments[name] = JSON.parse(value)
+          rescue JSON::ParserError
+            # maybe LLM chose a different shape for the array
+            arguments[name] = value.to_s.split("\n").map(&:strip).reject(&:blank?)
+          end
         elsif type == "integer"
           arguments[name] = value.to_i
         elsif type == "float"
diff --git a/lib/shared/inference/openai_image_generator.rb b/lib/shared/inference/openai_image_generator.rb
index 683ba016..f8cc203b 100644
--- a/lib/shared/inference/openai_image_generator.rb
+++ b/lib/shared/inference/openai_image_generator.rb
@@ -5,23 +5,38 @@ module ::DiscourseAi
   class OpenAiImageGenerator
     TIMEOUT = 60

-    def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil)
+    def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil, api_url: nil)
       api_key ||= SiteSetting.ai_openai_api_key
+      api_url ||= SiteSetting.ai_openai_dall_e_3_url

-      url = URI("https://api.openai.com/v1/images/generations")
-      headers = { "Content-Type" => "application/json", "Authorization" => "Bearer #{api_key}" }
+      uri = URI(api_url)

-      payload = { model: model, prompt: prompt, n: 1, size: size, response_format: "b64_json" }
+      headers = { "Content-Type" => "application/json" }
+
+      if uri.host.include?("azure")
+        headers["api-key"] = api_key
+      else
+        headers["Authorization"] = "Bearer #{api_key}"
+      end
+
+      payload = {
+        quality: "hd",
+        model: model,
+        prompt: prompt,
+        n: 1,
+        size: size,
+        response_format: "b64_json",
+      }

       Net::HTTP.start(
-        url.host,
-        url.port,
-        use_ssl: url.scheme == "https",
+        uri.host,
+        uri.port,
+        use_ssl: uri.scheme == "https",
         read_timeout: TIMEOUT,
         open_timeout: TIMEOUT,
         write_timeout: TIMEOUT,
       ) do |http|
-        request = Net::HTTP::Post.new(url, headers)
+        request = Net::HTTP::Post.new(uri, headers)
         request.body = payload.to_json

         json = nil
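With the new api_url keyword in place, the generator can be exercised directly. A minimal sketch, assuming the plugin is loaded and the site settings above are configured; the prompt text is illustrative, and the comment on the return shape reflects how the commands above consume it rather than anything guaranteed by this hunk:

    result =
      DiscourseAi::Inference::OpenAiImageGenerator.perform!(
        "a watercolor painting of a lighthouse at dawn",
        api_key: SiteSetting.ai_openai_api_key,
        api_url: SiteSetting.ai_openai_dall_e_3_url,
      )

    # the calling commands treat the return value as the parsed API response,
    # reading the base64 image from result[:data][0][:b64_json]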
diff --git a/spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb b/spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb
index 5165b9db..ac22c7a1 100644
--- a/spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb
+++ b/spec/lib/modules/ai_bot/commands/dall_e_command_spec.rb
@@ -7,6 +7,39 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
   before { SiteSetting.ai_bot_enabled = true }

   describe "#process" do
+    it "can generate correct info with azure" do
+      post = Fabricate(:post)
+
+      SiteSetting.ai_openai_api_key = "abc"
+      SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"
+
+      image =
+        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
+
+      data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
+      prompts = ["a pink cow", "a red cow"]
+
+      WebMock
+        .stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
+        .with do |request|
+          json = JSON.parse(request.body, symbolize_names: true)
+
+          expect(prompts).to include(json[:prompt])
+          expect(request.headers["Api-Key"]).to eq("abc")
+          true
+        end
+        .to_return(status: 200, body: { data: data }.to_json)
+
+      image = described_class.new(bot: bot, post: post, args: nil)
+
+      info = image.process(prompts: prompts).to_json
+
+      expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
+      expect(image.custom_raw).to include("upload://")
+      expect(image.custom_raw).to include("[grid]")
+      expect(image.custom_raw).to include("a pink cow 1")
+    end
+
     it "can generate correct info" do
       post = Fabricate(:post)
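The Api-Key assertion in the stub above exercises the Azure branch added to openai_image_generator.rb. On its own, that header selection reduces to something like the following sketch; auth_headers_for is an illustrative helper, not part of the plugin:

    # standalone sketch of the Azure vs. OpenAI header selection
    require "uri"

    def auth_headers_for(api_url, api_key)
      headers = { "Content-Type" => "application/json" }
      if URI(api_url).host.include?("azure")
        headers["api-key"] = api_key # Azure OpenAI authenticates with an api-key header
      else
        headers["Authorization"] = "Bearer #{api_key}" # openai.com expects a bearer token
      end
      headers
    end

    auth_headers_for("https://test.azure.com/some_url", "abc")
    # => { "Content-Type" => "application/json", "api-key" => "abc" }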