FEATURE: add aspect ratio support to DallE 3 (#647)
DallE 3 supports tall/square and wide images. This adds support to the 3 variants. (wide / tall / square)
This commit is contained in:
parent
dae9d6f14e
commit
309280cbb6
|
@ -17,6 +17,13 @@ module DiscourseAi
|
||||||
item_type: "string",
|
item_type: "string",
|
||||||
required: true,
|
required: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "aspect_ratio",
|
||||||
|
description: "The aspect ratio (optional, square by default)",
|
||||||
|
type: "string",
|
||||||
|
required: false,
|
||||||
|
enum: %w[tall square wide],
|
||||||
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
@ -29,6 +36,10 @@ module DiscourseAi
|
||||||
parameters[:prompts]
|
parameters[:prompts]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def aspect_ratio
|
||||||
|
parameters[:aspect_ratio]
|
||||||
|
end
|
||||||
|
|
||||||
def chain_next_response?
|
def chain_next_response?
|
||||||
false
|
false
|
||||||
end
|
end
|
||||||
|
@ -47,6 +58,13 @@ module DiscourseAi
|
||||||
api_key = SiteSetting.ai_openai_api_key
|
api_key = SiteSetting.ai_openai_api_key
|
||||||
api_url = SiteSetting.ai_openai_dall_e_3_url
|
api_url = SiteSetting.ai_openai_dall_e_3_url
|
||||||
|
|
||||||
|
size = "1024x1024"
|
||||||
|
if aspect_ratio == "tall"
|
||||||
|
size = "1024x1792"
|
||||||
|
elsif aspect_ratio == "wide"
|
||||||
|
size = "1792x1024"
|
||||||
|
end
|
||||||
|
|
||||||
threads = []
|
threads = []
|
||||||
max_prompts.each_with_index do |prompt, index|
|
max_prompts.each_with_index do |prompt, index|
|
||||||
threads << Thread.new(prompt) do |inner_prompt|
|
threads << Thread.new(prompt) do |inner_prompt|
|
||||||
|
@ -54,6 +72,7 @@ module DiscourseAi
|
||||||
begin
|
begin
|
||||||
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
|
DiscourseAi::Inference::OpenAiImageGenerator.perform!(
|
||||||
inner_prompt,
|
inner_prompt,
|
||||||
|
size: size,
|
||||||
api_key: api_key,
|
api_key: api_key,
|
||||||
api_url: api_url,
|
api_url: api_url,
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,19 +11,49 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
|
||||||
described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
|
described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
|
||||||
end
|
end
|
||||||
|
|
||||||
before { SiteSetting.ai_bot_enabled = true }
|
let(:base64_image) do
|
||||||
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||||
|
end
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.ai_bot_enabled = true
|
||||||
|
SiteSetting.ai_openai_api_key = "abc"
|
||||||
|
end
|
||||||
|
|
||||||
describe "#process" do
|
describe "#process" do
|
||||||
|
it "can generate tall images" do
|
||||||
|
generator =
|
||||||
|
described_class.new(
|
||||||
|
{ prompts: ["a cat"], aspect_ratio: "tall" },
|
||||||
|
llm: llm,
|
||||||
|
bot_user: bot_user,
|
||||||
|
context: {
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = [{ b64_json: base64_image, revised_prompt: "a tall cat" }]
|
||||||
|
|
||||||
|
WebMock
|
||||||
|
.stub_request(:post, "https://api.openai.com/v1/images/generations")
|
||||||
|
.with do |request|
|
||||||
|
json = JSON.parse(request.body, symbolize_names: true)
|
||||||
|
|
||||||
|
expect(json[:prompt]).to eq("a cat")
|
||||||
|
expect(json[:size]).to eq("1024x1792")
|
||||||
|
true
|
||||||
|
end
|
||||||
|
.to_return(status: 200, body: { data: data }.to_json)
|
||||||
|
|
||||||
|
info = generator.invoke(&progress_blk).to_json
|
||||||
|
expect(JSON.parse(info)).to eq("prompts" => ["a tall cat"])
|
||||||
|
end
|
||||||
|
|
||||||
it "can generate correct info with azure" do
|
it "can generate correct info with azure" do
|
||||||
_post = Fabricate(:post)
|
_post = Fabricate(:post)
|
||||||
|
|
||||||
SiteSetting.ai_openai_api_key = "abc"
|
|
||||||
SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"
|
SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"
|
||||||
|
|
||||||
image =
|
data = [{ b64_json: base64_image, revised_prompt: "a pink cow 1" }]
|
||||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
|
||||||
|
|
||||||
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
|
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
|
.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url)
|
||||||
|
@ -45,12 +75,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
|
||||||
end
|
end
|
||||||
|
|
||||||
it "can generate correct info" do
|
it "can generate correct info" do
|
||||||
SiteSetting.ai_openai_api_key = "abc"
|
data = [{ b64_json: base64_image, revised_prompt: "a pink cow 1" }]
|
||||||
|
|
||||||
image =
|
|
||||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
|
||||||
|
|
||||||
data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
|
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.openai.com/v1/images/generations")
|
.stub_request(:post, "https://api.openai.com/v1/images/generations")
|
||||||
|
|
Loading…
Reference in New Issue