FEATURE: add aspect ratio support to DallE 3 (#647)

DallE 3 supports tall/square and wide images.

This adds support to the 3 variants. (wide / tall / square)
This commit is contained in:
Sam 2024-05-28 16:21:40 +10:00 committed by GitHub
parent dae9d6f14e
commit 309280cbb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 12 deletions

View File

@ -17,6 +17,13 @@ module DiscourseAi
item_type: "string", item_type: "string",
required: true, required: true,
}, },
{
name: "aspect_ratio",
description: "The aspect ratio (optional, square by default)",
type: "string",
required: false,
enum: %w[tall square wide],
},
], ],
} }
end end
@ -29,6 +36,10 @@ module DiscourseAi
parameters[:prompts] parameters[:prompts]
end end
def aspect_ratio
parameters[:aspect_ratio]
end
def chain_next_response? def chain_next_response?
false false
end end
@ -47,6 +58,13 @@ module DiscourseAi
api_key = SiteSetting.ai_openai_api_key api_key = SiteSetting.ai_openai_api_key
api_url = SiteSetting.ai_openai_dall_e_3_url api_url = SiteSetting.ai_openai_dall_e_3_url
size = "1024x1024"
if aspect_ratio == "tall"
size = "1024x1792"
elsif aspect_ratio == "wide"
size = "1792x1024"
end
threads = [] threads = []
max_prompts.each_with_index do |prompt, index| max_prompts.each_with_index do |prompt, index|
threads << Thread.new(prompt) do |inner_prompt| threads << Thread.new(prompt) do |inner_prompt|
@ -54,6 +72,7 @@ module DiscourseAi
begin begin
DiscourseAi::Inference::OpenAiImageGenerator.perform!( DiscourseAi::Inference::OpenAiImageGenerator.perform!(
inner_prompt, inner_prompt,
size: size,
api_key: api_key, api_key: api_key,
api_url: api_url, api_url: api_url,
) )

View File

@ -11,19 +11,49 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {}) described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
end end
before { SiteSetting.ai_bot_enabled = true } let(:base64_image) do
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
end
before do
SiteSetting.ai_bot_enabled = true
SiteSetting.ai_openai_api_key = "abc"
end
describe "#process" do describe "#process" do
it "can generate tall images" do
generator =
described_class.new(
{ prompts: ["a cat"], aspect_ratio: "tall" },
llm: llm,
bot_user: bot_user,
context: {
},
)
data = [{ b64_json: base64_image, revised_prompt: "a tall cat" }]
WebMock
.stub_request(:post, "https://api.openai.com/v1/images/generations")
.with do |request|
json = JSON.parse(request.body, symbolize_names: true)
expect(json[:prompt]).to eq("a cat")
expect(json[:size]).to eq("1024x1792")
true
end
.to_return(status: 200, body: { data: data }.to_json)
info = generator.invoke(&progress_blk).to_json
expect(JSON.parse(info)).to eq("prompts" => ["a tall cat"])
end
it "can generate correct info with azure" do it "can generate correct info with azure" do
_post = Fabricate(:post) _post = Fabricate(:post)
SiteSetting.ai_openai_api_key = "abc"
SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url" SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"
image = data = [{ b64_json: base64_image, revised_prompt: "a pink cow 1" }]
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
WebMock WebMock
.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url) .stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
@ -45,12 +75,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
end end
it "can generate correct info" do it "can generate correct info" do
SiteSetting.ai_openai_api_key = "abc" data = [{ b64_json: base64_image, revised_prompt: "a pink cow 1" }]
image =
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
WebMock WebMock
.stub_request(:post, "https://api.openai.com/v1/images/generations") .stub_request(:post, "https://api.openai.com/v1/images/generations")