FEATURE: Azure OpenAI support for DALL-E 3 (#313)
* FEATURE: Azure OpenAI support for DALL-E 3 Previous to this there was no way to add an inference endpoint for DALL-E on Azure because it requires custom URLs Also: - On save, when editing a persona it would revert priority and enabled - More forgiving parsing in command framework for array function calls - By default generate HD images - they tend to be a bit better - Improve DALL-E prompt which was getting very annoying and always echoing what it is about to do - Add a bit of a sleep between retries on image generation - Fix error handling in image_command
This commit is contained in:
parent
dff9f33a97
commit
5a4598a7b4
|
@ -96,6 +96,7 @@ export default class PersonaEditor extends Component {
|
|||
@action
|
||||
async toggleEnabled() {
|
||||
this.args.model.set("enabled", !this.args.model.enabled);
|
||||
this.editingModel.set("enabled", this.args.model.enabled);
|
||||
if (!this.args.model.isNew) {
|
||||
try {
|
||||
await this.args.model.update({ enabled: this.args.model.enabled });
|
||||
|
@ -108,6 +109,7 @@ export default class PersonaEditor extends Component {
|
|||
@action
|
||||
async togglePriority() {
|
||||
this.args.model.set("priority", !this.args.model.priority);
|
||||
this.editingModel.set("priority", this.args.model.priority);
|
||||
if (!this.args.model.isNew) {
|
||||
try {
|
||||
await this.args.model.update({ priority: this.args.model.priority });
|
||||
|
|
|
@ -41,6 +41,7 @@ en:
|
|||
ai_openai_gpt35_16k_url: "Custom URL used for GPT 3.5 16k chat completions. (for Azure support)"
|
||||
ai_openai_gpt4_url: "Custom URL used for GPT 4 chat completions. (for Azure support)"
|
||||
ai_openai_gpt4_32k_url: "Custom URL used for GPT 4 32k chat completions. (for Azure support)"
|
||||
ai_openai_dall_e_3_url: "Custom URL used for DALL-E 3 image generation. (for Azure support)"
|
||||
ai_openai_organization: "(Optional, leave empty to omit) Organization id used for the OpenAI API. Passed in using the OpenAI-Organization header."
|
||||
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
|
||||
ai_openai_api_key: "API key for OpenAI API"
|
||||
|
|
|
@ -92,6 +92,7 @@ discourse_ai:
|
|||
ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
|
||||
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
|
||||
ai_openai_organization: ""
|
||||
ai_openai_api_key:
|
||||
|
|
|
@ -54,15 +54,21 @@ module DiscourseAi::AiBot::Commands
|
|||
# this ensures multisite safety since background threads
|
||||
# generate the images
|
||||
api_key = SiteSetting.ai_openai_api_key
|
||||
api_url = SiteSetting.ai_openai_dall_e_3_url
|
||||
|
||||
threads = []
|
||||
prompts.each_with_index do |prompt, index|
|
||||
threads << Thread.new(prompt) do |inner_prompt|
|
||||
attempts = 0
|
||||
begin
|
||||
DiscourseAi::Inference::OpenAiImageGenerator.perform!(inner_prompt, api_key: api_key)
|
||||
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
|
||||
inner_prompt,
|
||||
api_key: api_key,
|
||||
api_url: api_url,
|
||||
)
|
||||
rescue => e
|
||||
attempts += 1
|
||||
sleep 2
|
||||
retry if attempts < 3
|
||||
Discourse.warn_exception(e, message: "Failed to generate image for prompt #{prompt}")
|
||||
nil
|
||||
|
|
|
@ -96,7 +96,7 @@ module DiscourseAi::AiBot::Commands
|
|||
results = threads.map(&:value).compact
|
||||
|
||||
if !results.present?
|
||||
return { prompt: prompt, error: "Something went wrong, could not generate image" }
|
||||
return { prompts: prompts, error: "Something went wrong, could not generate image" }
|
||||
end
|
||||
|
||||
uploads = []
|
||||
|
|
|
@ -14,19 +14,23 @@ module DiscourseAi
|
|||
|
||||
def system_prompt
|
||||
<<~PROMPT
|
||||
You are a bot specializing in generating images using DALL-E-3
|
||||
As a DALL-E-3 bot, you're tasked with generating images based on user prompts.
|
||||
|
||||
- A good prompt needs to be detailed and specific.
|
||||
- You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it)
|
||||
- You can specify details about lighting or time of day.
|
||||
- You can specify a particular website you would like to emulate (artstation or deviantart)
|
||||
- You can specify additional details such as "beutiful, dystopian, futuristic, etc."
|
||||
- Prompts should generally be 40-80 words long, keep in mind API only accepts a maximum of 5000 chars per prompt
|
||||
- You are extremely creative, when given short non descriptive prompts from a user you add your own details
|
||||
- Be specific and detailed in your prompts. Include elements like subject, medium (e.g., oil on canvas), artist style, lighting, time of day, and website style (e.g., ArtStation, DeviantArt).
|
||||
- Add adjectives for more detail (e.g., beautiful, dystopian, futuristic).
|
||||
- Prompts should be 40-100 words long, but remember the API accepts a maximum of 5000 characters per prompt.
|
||||
- Enhance short, vague user prompts with your own creative details.
|
||||
- Unless specified, generate 4 images per prompt.
|
||||
- Don't seek user permission before generating images or run the prompts by the user. Generate immediately to save tokens.
|
||||
|
||||
Example:
|
||||
|
||||
User: "a cow"
|
||||
You: Generate images immediately, without telling the user anything. Details will be provided to user with the generated images.
|
||||
|
||||
DO NOT SAY "I will generate the following ... image 1 description ... image 2 description ... etc."
|
||||
Just generate the images
|
||||
|
||||
- When generating images, usually opt to generate 4 images unless the user specifies otherwise.
|
||||
- Be creative with your prompts, offer diverse options
|
||||
- DALL-E-3 will rewrite your prompt to be more specific and detailed, use that one iterating on images
|
||||
PROMPT
|
||||
end
|
||||
end
|
||||
|
|
|
@ -50,7 +50,12 @@ module ::DiscourseAi
|
|||
|
||||
type = parameter[:type]
|
||||
if type == "array"
|
||||
arguments[name] = JSON.parse(value)
|
||||
begin
|
||||
arguments[name] = JSON.parse(value)
|
||||
rescue JSON::ParserError
|
||||
# maybe LLM chose a different shape for the array
|
||||
arguments[name] = value.to_s.split("\n").map(&:strip).reject(&:blank?)
|
||||
end
|
||||
elsif type == "integer"
|
||||
arguments[name] = value.to_i
|
||||
elsif type == "float"
|
||||
|
|
|
@ -5,23 +5,38 @@ module ::DiscourseAi
|
|||
class OpenAiImageGenerator
|
||||
TIMEOUT = 60
|
||||
|
||||
def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil)
|
||||
def self.perform!(prompt, model: "dall-e-3", size: "1024x1024", api_key: nil, api_url: nil)
|
||||
api_key ||= SiteSetting.ai_openai_api_key
|
||||
api_url ||= SiteSetting.ai_openai_dall_e_3_url
|
||||
|
||||
url = URI("https://api.openai.com/v1/images/generations")
|
||||
headers = { "Content-Type" => "application/json", "Authorization" => "Bearer #{api_key}" }
|
||||
uri = URI(api_url)
|
||||
|
||||
payload = { model: model, prompt: prompt, n: 1, size: size, response_format: "b64_json" }
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
if uri.host.include?("azure")
|
||||
headers["api-key"] = api_key
|
||||
else
|
||||
headers["Authorization"] = "Bearer #{api_key}"
|
||||
end
|
||||
|
||||
payload = {
|
||||
quality: "hd",
|
||||
model: model,
|
||||
prompt: prompt,
|
||||
n: 1,
|
||||
size: size,
|
||||
response_format: "b64_json",
|
||||
}
|
||||
|
||||
Net::HTTP.start(
|
||||
url.host,
|
||||
url.port,
|
||||
use_ssl: url.scheme == "https",
|
||||
uri.host,
|
||||
uri.port,
|
||||
use_ssl: uri.scheme == "https",
|
||||
read_timeout: TIMEOUT,
|
||||
open_timeout: TIMEOUT,
|
||||
write_timeout: TIMEOUT,
|
||||
) do |http|
|
||||
request = Net::HTTP::Post.new(url, headers)
|
||||
request = Net::HTTP::Post.new(uri, headers)
|
||||
request.body = payload.to_json
|
||||
|
||||
json = nil
|
||||
|
|
|
@ -7,6 +7,39 @@ RSpec.describe DiscourseAi::AiBot::Commands::DallECommand do
|
|||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#process" do
|
||||
it "can generate correct info with azure" do
|
||||
post = Fabricate(:post)
|
||||
|
||||
SiteSetting.ai_openai_api_key = "abc"
|
||||
SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"
|
||||
|
||||
image =
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
|
||||
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
|
||||
prompts = ["a pink cow", "a red cow"]
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
|
||||
.with do |request|
|
||||
json = JSON.parse(request.body, symbolize_names: true)
|
||||
|
||||
expect(prompts).to include(json[:prompt])
|
||||
expect(request.headers["Api-Key"]).to eq("abc")
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: { data: data }.to_json)
|
||||
|
||||
image = described_class.new(bot: bot, post: post, args: nil)
|
||||
|
||||
info = image.process(prompts: prompts).to_json
|
||||
|
||||
expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
|
||||
expect(image.custom_raw).to include("upload://")
|
||||
expect(image.custom_raw).to include("[grid]")
|
||||
expect(image.custom_raw).to include("a pink cow 1")
|
||||
end
|
||||
|
||||
it "can generate correct info" do
|
||||
post = Fabricate(:post)
|
||||
|
||||
|
|
Loading…
Reference in New Issue