diff --git a/lib/ai_bot/tools/dall_e.rb b/lib/ai_bot/tools/dall_e.rb index 3686ba97..7389a3b5 100644 --- a/lib/ai_bot/tools/dall_e.rb +++ b/lib/ai_bot/tools/dall_e.rb @@ -17,6 +17,13 @@ module DiscourseAi item_type: "string", required: true, }, + { + name: "aspect_ratio", + description: "The aspect ratio (optional, square by default)", + type: "string", + required: false, + enum: %w[tall square wide], + }, ], } end @@ -29,6 +36,10 @@ module DiscourseAi parameters[:prompts] end + def aspect_ratio + parameters[:aspect_ratio] + end + def chain_next_response? false end @@ -47,6 +58,13 @@ module DiscourseAi api_key = SiteSetting.ai_openai_api_key api_url = SiteSetting.ai_openai_dall_e_3_url + size = "1024x1024" + if aspect_ratio == "tall" + size = "1024x1792" + elsif aspect_ratio == "wide" + size = "1792x1024" + end + threads = [] max_prompts.each_with_index do |prompt, index| threads << Thread.new(prompt) do |inner_prompt| @@ -54,6 +72,7 @@ module DiscourseAi begin DiscourseAi::Inference::OpenAiImageGenerator.perform!( inner_prompt, + size: size, api_key: api_key, api_url: api_url, ) diff --git a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb index 3d5f5a13..1e536a07 100644 --- a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb +++ b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb @@ -11,19 +11,49 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {}) end - before { SiteSetting.ai_bot_enabled = true } + let(:base64_image) do + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" + end + + before do + SiteSetting.ai_bot_enabled = true + SiteSetting.ai_openai_api_key = "abc" + end describe "#process" do + it "can generate tall images" do + generator = + described_class.new( + { prompts: ["a cat"], aspect_ratio: "tall" }, + llm: llm, + bot_user: bot_user, + context: { + }, + ) + + data = [{ b64_json: base64_image, revised_prompt: "a tall cat" }] + + WebMock + .stub_request(:post, "https://api.openai.com/v1/images/generations") + .with do |request| + json = JSON.parse(request.body, symbolize_names: true) + + expect(json[:prompt]).to eq("a cat") + expect(json[:size]).to eq("1024x1792") + true + end + .to_return(status: 200, body: { data: data }.to_json) + + info = generator.invoke(&progress_blk).to_json + expect(JSON.parse(info)).to eq("prompts" => ["a tall cat"]) + end + it "can generate correct info with azure" do _post = Fabricate(:post) - SiteSetting.ai_openai_api_key = "abc" SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url" - image = - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" - - data = [{ b64_json: image, revised_prompt: "a pink cow 1" }] + data = [{ b64_json: base64_image, revised_prompt: "a pink cow 1" }] WebMock .stub_request(:post, SiteSetting.ai_openai_dall_e_3_url) @@ -45,12 +75,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do end it "can generate correct info" do - SiteSetting.ai_openai_api_key = "abc" - - image = - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" - - data = [{ b64_json: image, revised_prompt: "a pink cow 1" }] + data = [{ b64_json: base64_image, revised_prompt: "a pink cow 1" }] WebMock .stub_request(:post, "https://api.openai.com/v1/images/generations")