FEATURE: Azure OpenAI support for DALL-E 3 (#313)

* FEATURE: Azure OpenAI support for DALL-E 3

Prior to this there was no way to add an inference endpoint for
DALL-E on Azure because it requires custom URLs

Also:

- On save, when editing a persona it would revert priority and enabled
- More forgiving parsing in command framework for array function calls
- By default generate HD images - they tend to be a bit better
- Improve the DALL-E prompt, which was getting very annoying by always echoing what it was about to do
- Add a bit of a sleep between retries on image generation
- Fix error handling in image_command
This commit is contained in:
Sam 2023-11-27 13:01:05 +11:00 committed by GitHub
parent dff9f33a97
commit 5a4598a7b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 89 additions and 22 deletions

View File

@@ -96,6 +96,7 @@ export default class PersonaEditor extends Component {
@action
async toggleEnabled() {
this.args.model.set("enabled", !this.args.model.enabled);
this.editingModel.set("enabled", this.args.model.enabled);
if (!this.args.model.isNew) {
try {
await this.args.model.update({ enabled: this.args.model.enabled });
@@ -108,6 +109,7 @@ export default class PersonaEditor extends Component {
@action
async togglePriority() {
this.args.model.set("priority", !this.args.model.priority);
this.editingModel.set("priority", this.args.model.priority);
if (!this.args.model.isNew) {
try {
await this.args.model.update({ priority: this.args.model.priority });

View File

@@ -41,6 +41,7 @@ en:
ai_openai_gpt35_16k_url: "Custom URL used for GPT 3.5 16k chat completions. (for Azure support)"
ai_openai_gpt4_url: "Custom URL used for GPT 4 chat completions. (for Azure support)"
ai_openai_gpt4_32k_url: "Custom URL used for GPT 4 32k chat completions. (for Azure support)"
ai_openai_dall_e_3_url: "Custom URL used for DALL-E 3 image generation. (for Azure support)"
ai_openai_organization: "(Optional, leave empty to omit) Organization id used for the OpenAI API. Passed in using the OpenAI-Organization header."
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
ai_openai_api_key: "API key for OpenAI API"

View File

@@ -92,6 +92,7 @@ discourse_ai:
ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
ai_openai_organization: ""
ai_openai_api_key:

View File

@@ -54,15 +54,21 @@ module DiscourseAi::AiBot::Commands
# this ensures multisite safety since background threads
# generate the images
api_key = SiteSetting.ai_openai_api_key
api_url = SiteSetting.ai_openai_dall_e_3_url
threads = []
prompts.each_with_index do |prompt, index|
threads << Thread.new(prompt) do |inner_prompt|
attempts = 0
begin
DiscourseAi::Inference::OpenAiImageGenerator.perform!(inner_prompt, api_key: api_key)
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
inner_prompt,
api_key: api_key,
api_url: api_url,
)
rescue => e
attempts += 1
sleep 2
retry if attempts < 3
Discourse.warn_exception(e, message: "Failed to generate image for prompt #{prompt}")
nil

View File

@@ -96,7 +96,7 @@ module DiscourseAi::AiBot::Commands
results = threads.map(&:value).compact
if !results.present?
return { prompt: prompt, error: "Something went wrong, could not generate image" }
return { prompts: prompts, error: "Something went wrong, could not generate image" }
end
uploads = []

View File

@@ -14,19 +14,23 @@ module DiscourseAi
def system_prompt
<<~PROMPT
You are a bot specializing in generating images using DALL-E-3
As a DALL-E-3 bot, you're tasked with generating images based on user prompts.
- A good prompt needs to be detailed and specific.
- You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it)
- You can specify details about lighting or time of day.
- You can specify a particular website you would like to emulate (artstation or deviantart)
- You can specify additional details such as "beutiful, dystopian, futuristic, etc."
- Prompts should generally be 40-80 words long, keep in mind API only accepts a maximum of 5000 chars per prompt
- You are extremely creative, when given short non descriptive prompts from a user you add your own details
- Be specific and detailed in your prompts. Include elements like subject, medium (e.g., oil on canvas), artist style, lighting, time of day, and website style (e.g., ArtStation, DeviantArt).
- Add adjectives for more detail (e.g., beautiful, dystopian, futuristic).
- Prompts should be 40-100 words long, but remember the API accepts a maximum of 5000 characters per prompt.
- Enhance short, vague user prompts with your own creative details.
- Unless specified, generate 4 images per prompt.
- Don't seek user permission before generating images or run the prompts by the user. Generate immediately to save tokens.
Example:
User: "a cow"
You: Generate images immediately, without telling the user anything. Details will be provided to user with the generated images.
DO NOT SAY "I will generate the following ... image 1 description ... image 2 description ... etc."
Just generate the images
- When generating images, usually opt to generate 4 images unless the user specifies otherwise.
- Be creative with your prompts, offer diverse options
- DALL-E-3 will rewrite your prompt to be more specific and detailed, use that one iterating on images
PROMPT
end
end

View File

@@ -50,7 +50,12 @@ module ::DiscourseAi
type = parameter[:type]
if type == "array"
arguments[name] = JSON.parse(value)
begin
arguments[name] = JSON.parse(value)
rescue JSON::ParserError
# maybe LLM chose a different shape for the array
arguments[name] = value.to_s.split("\n").map(&:strip).reject(&:blank?)
end
elsif type == "integer"
arguments[name] = value.to_i
elsif type == "float"

View File

@@ -5,23 +5,38 @@ module ::DiscourseAi
class OpenAiImageGenerator
TIMEOUT = 60
def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil)
def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil, api_url: nil)
api_key ||= SiteSetting.ai_openai_api_key
api_url ||= SiteSetting.ai_openai_dall_e_3_url
url = URI("https://api.openai.com/v1/images/generations")
headers = { "Content-Type" => "application/json", "Authorization" => "Bearer #{api_key}" }
uri = URI(api_url)
payload = { model: model, prompt: prompt, n: 1, size: size, response_format: "b64_json" }
headers = { "Content-Type" => "application/json" }
if uri.host.include?("azure")
headers["api-key"] = api_key
else
headers["Authorization"] = "Bearer #{api_key}"
end
payload = {
quality: "hd",
model: model,
prompt: prompt,
n: 1,
size: size,
response_format: "b64_json",
}
Net::HTTP.start(
url.host,
url.port,
use_ssl: url.scheme == "https",
uri.host,
uri.port,
use_ssl: uri.scheme == "https",
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
request = Net::HTTP::Post.new(url, headers)
request = Net::HTTP::Post.new(uri, headers)
request.body = payload.to_json
json = nil

View File

@@ -7,6 +7,39 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do
it "can generate correct info with azure" do
post = Fabricate(:post)
SiteSetting.ai_openai_api_key = "abc"
SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"
image =
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
prompts = ["a pink cow", "a red cow"]
WebMock
.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
.with do |request|
json = JSON.parse(request.body, symbolize_names: true)
expect(prompts).to include(json[:prompt])
expect(request.headers["Api-Key"]).to eq("abc")
true
end
.to_return(status: 200, body: { data: data }.to_json)
image = described_class.new(bot: bot, post: post, args: nil)
info = image.process(prompts: prompts).to_json
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
expect(image.custom_raw).to include("upload://")
expect(image.custom_raw).to include("[grid]")
expect(image.custom_raw).to include("a pink cow 1")
end
it "can generate correct info" do
post = Fabricate(:post)