FIX: Malformed message in systemless + inline img scenario (#771)

This commit is contained in:
Roman Rizzi 2024-08-23 16:41:57 -03:00 committed by GitHub
parent eac83eb619
commit c6aeabbfc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 13 deletions

View File

@@ -154,7 +154,6 @@ module DiscourseAi
upload_ids: [upload.id], upload_ids: [upload.id],
}, },
], ],
skip_validations: true,
) )
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate( DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(

View File

@@ -29,9 +29,16 @@ module DiscourseAi
return translated unless llm_model.lookup_custom_param("disable_system_prompt") return translated unless llm_model.lookup_custom_param("disable_system_prompt")
system_and_user_msgs = translated.shift(2) system_msg, user_msg = translated.shift(2)
user_msg = system_and_user_msgs.last
user_msg[:content] = [system_and_user_msgs.first[:content], user_msg[:content]].join("\n") if user_msg[:content].is_a?(Array) # Has inline images.
user_msg[:content].first[:text] = [
system_msg[:content],
user_msg[:content].first[:text],
].join("\n")
else
user_msg[:content] = [system_msg[:content], user_msg[:content]].join("\n")
end
translated.unshift(user_msg) translated.unshift(user_msg)
end end
@@ -79,7 +86,7 @@ module DiscourseAi
return content if encoded_uploads.blank? return content if encoded_uploads.blank?
content_w_imgs = content_w_imgs =
encoded_uploads.reduce([]) do |memo, details| encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
memo << { memo << {
type: "image_url", type: "image_url",
image_url: { image_url: {
@@ -87,8 +94,6 @@ module DiscourseAi
}, },
} }
end end
content_w_imgs << { type: "text", text: message[:content] }
end end
end end
end end

View File

@@ -12,7 +12,6 @@ module DiscourseAi
system_message_text = nil, system_message_text = nil,
messages: [], messages: [],
tools: [], tools: [],
skip_validations: false,
topic_id: nil, topic_id: nil,
post_id: nil, post_id: nil,
max_pixels: nil max_pixels: nil
@@ -26,7 +25,6 @@ module DiscourseAi
@post_id = post_id @post_id = post_id
@messages = [] @messages = []
@skip_validations = skip_validations
if system_message_text if system_message_text
system_message = { type: :system, content: system_message_text } system_message = { type: :system, content: system_message_text }
@@ -68,7 +66,6 @@ module DiscourseAi
private private
def validate_message(message) def validate_message(message)
return if @skip_validations
valid_types = %i[system user model tool tool_call] valid_types = %i[system user model tool tool_call]
if !valid_types.include?(message[:type]) if !valid_types.include?(message[:type])
raise ArgumentError, "message type must be one of #{valid_types}" raise ArgumentError, "message type must be one of #{valid_types}"
@@ -91,7 +88,6 @@ module DiscourseAi
end end
def validate_turn(last_turn, new_turn) def validate_turn(last_turn, new_turn)
return if @skip_validations
valid_types = %i[tool tool_call model user] valid_types = %i[tool tool_call model user]
raise INVALID_TURN if !valid_types.include?(new_turn[:type]) raise INVALID_TURN if !valid_types.include?(new_turn[:type])

View File

@@ -2,6 +2,10 @@
RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
context "when system prompts are disabled" do context "when system prompts are disabled" do
fab!(:model) do
Fabricate(:vllm_model, vision_enabled: true, provider_params: { disable_system_prompt: true })
end
it "merges the system prompt into the first message" do it "merges the system prompt into the first message" do
system_msg = "This is a system message" system_msg = "This is a system message"
user_msg = "user message" user_msg = "user message"
@@ -11,8 +15,6 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
messages: [{ type: :user, content: user_msg }], messages: [{ type: :user, content: user_msg }],
) )
model = Fabricate(:vllm_model, provider_params: { disable_system_prompt: true })
translated_messages = described_class.new(prompt, model).translate translated_messages = described_class.new(prompt, model).translate
expect(translated_messages.length).to eq(1) expect(translated_messages.length).to eq(1)
@@ -20,6 +22,49 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
{ role: "user", content: [system_msg, user_msg].join("\n") }, { role: "user", content: [system_msg, user_msg].join("\n") },
) )
end end
context "when the prompt has inline images" do
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
it "produces a valid message" do
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a bot specializing in image captioning.",
messages: [
{
type: :user,
content: "Describe this image in a single sentence.",
upload_ids: [upload.id],
},
],
)
encoded_upload =
DiscourseAi::Completions::UploadEncoder.encode(
upload_ids: [upload.id],
max_pixels: prompt.max_pixels,
).first
translated_messages = described_class.new(prompt, model).translate
expect(translated_messages.length).to eq(1)
expected_user_message = {
role: "user",
content: [
{ type: "text", text: prompt.messages.map { |m| m[:content] }.join("\n") },
{
type: "image_url",
image_url: {
url: "data:#{encoded_upload[:mime_type]};base64,#{encoded_upload[:base64]}",
},
},
],
}
expect(translated_messages).to contain_exactly(expected_user_message)
end
end
end end
context "when system prompts are enabled" do context "when system prompts are enabled" do