# frozen_string_literal: true

# Specs for DiscourseAi::AiBot::Tools::DallE. Every example stubs the image
# generation endpoint with WebMock, so no real HTTP calls leave the test.
RSpec.describe DiscourseAi::AiBot::Tools::DallE do
  fab!(:gpt_35_turbo) { Fabricate(:llm_model, name: "gpt-3.5-turbo") }

  let(:prompts) { ["a pink cow", "a red cow"] }
  let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
  # No-op progress callback handed to #invoke.
  let(:progress_blk) { proc {} }
  let(:openai_images_url) { "https://api.openai.com/v1/images/generations" }

  let(:dall_e) do
    described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
  end

  # A 1x1 PNG, base64-encoded, used as the image payload of every stubbed response.
  let(:base64_image) do
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
  end

  before do
    SiteSetting.ai_bot_enabled = true
    toggle_enabled_bots(bots: [gpt_35_turbo])
    SiteSetting.ai_openai_api_key = "abc"
  end

  # Stubs POST requests to +url+ so each one is answered with a single image
  # (base64_image) carrying +revised_prompt+.
  #
  # @param url [String] endpoint the tool is expected to call
  # @param revised_prompt [String] revised_prompt returned in every response
  # @yieldparam json [Hash] the request body parsed with symbolized keys
  # @yieldparam request [WebMock::RequestSignature] the intercepted request
  def stub_image_generation(url, revised_prompt:, &assert_request)
    data = [{ b64_json: base64_image, revised_prompt: revised_prompt }]

    WebMock
      .stub_request(:post, url)
      .with do |request|
        assert_request&.call(JSON.parse(request.body, symbolize_names: true), request)
        # WebMock only serves the stubbed response when this block is truthy.
        true
      end
      .to_return(status: 200, body: { data: data }.to_json)
  end

  describe "#invoke" do
    it "can generate tall images" do
      generator =
        described_class.new(
          { prompts: ["a cat"], aspect_ratio: "tall" },
          llm: llm,
          bot_user: bot_user,
          context: {},
        )

      # The "tall" aspect ratio is expected to be sent as size 1024x1792.
      stub_image_generation(openai_images_url, revised_prompt: "a tall cat") do |json|
        expect(json[:prompt]).to eq("a cat")
        expect(json[:size]).to eq("1024x1792")
      end

      info = generator.invoke(&progress_blk).to_json

      expect(JSON.parse(info)).to eq("prompts" => ["a tall cat"])
    end

    it "can generate correct info with azure" do
      # NOTE(review): this post is never referenced by the example; presumably
      # vestigial — confirm nothing depends on it before removing.
      _post = Fabricate(:post)

      azure_url = "https://test.azure.com/some_url"
      SiteSetting.ai_openai_dall_e_3_url = azure_url

      # With a custom URL configured, the key is expected in an Api-Key header.
      stub_image_generation(azure_url, revised_prompt: "a pink cow 1") do |json, request|
        expect(prompts).to include(json[:prompt])
        expect(request.headers["Api-Key"]).to eq("abc")
      end

      info = dall_e.invoke(&progress_blk).to_json

      # Both prompts hit the same stub, so both come back as "a pink cow 1".
      expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
      expect(dall_e.custom_raw).to include("upload://")
      expect(dall_e.custom_raw).to include("[grid]")
      expect(dall_e.custom_raw).to include("a pink cow 1")
    end

    it "can generate correct info" do
      stub_image_generation(openai_images_url, revised_prompt: "a pink cow 1") do |json|
        expect(prompts).to include(json[:prompt])
      end

      info = dall_e.invoke(&progress_blk).to_json

      # Both prompts hit the same stub, so both come back as "a pink cow 1".
      expect(JSON.parse(info)).to eq("prompts" => ["a pink cow 1", "a pink cow 1"])
      expect(dall_e.custom_raw).to include("upload://")
      expect(dall_e.custom_raw).to include("[grid]")
      expect(dall_e.custom_raw).to include("a pink cow 1")
    end
  end
end
|