FIX: Malformed message in systemless + inline img scenario (#771)
This commit is contained in:
parent
eac83eb619
commit
c6aeabbfc0
|
@ -154,7 +154,6 @@ module DiscourseAi
|
||||||
upload_ids: [upload.id],
|
upload_ids: [upload.id],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
skip_validations: true,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(
|
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate(
|
||||||
|
|
|
@ -29,9 +29,16 @@ module DiscourseAi
|
||||||
|
|
||||||
return translated unless llm_model.lookup_custom_param("disable_system_prompt")
|
return translated unless llm_model.lookup_custom_param("disable_system_prompt")
|
||||||
|
|
||||||
system_and_user_msgs = translated.shift(2)
|
system_msg, user_msg = translated.shift(2)
|
||||||
user_msg = system_and_user_msgs.last
|
|
||||||
user_msg[:content] = [system_and_user_msgs.first[:content], user_msg[:content]].join("\n")
|
if user_msg[:content].is_a?(Array) # Has inline images.
|
||||||
|
user_msg[:content].first[:text] = [
|
||||||
|
system_msg[:content],
|
||||||
|
user_msg[:content].first[:text],
|
||||||
|
].join("\n")
|
||||||
|
else
|
||||||
|
user_msg[:content] = [system_msg[:content], user_msg[:content]].join("\n")
|
||||||
|
end
|
||||||
|
|
||||||
translated.unshift(user_msg)
|
translated.unshift(user_msg)
|
||||||
end
|
end
|
||||||
|
@ -79,7 +86,7 @@ module DiscourseAi
|
||||||
return content if encoded_uploads.blank?
|
return content if encoded_uploads.blank?
|
||||||
|
|
||||||
content_w_imgs =
|
content_w_imgs =
|
||||||
encoded_uploads.reduce([]) do |memo, details|
|
encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
|
||||||
memo << {
|
memo << {
|
||||||
type: "image_url",
|
type: "image_url",
|
||||||
image_url: {
|
image_url: {
|
||||||
|
@ -87,8 +94,6 @@ module DiscourseAi
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
content_w_imgs << { type: "text", text: message[:content] }
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -12,7 +12,6 @@ module DiscourseAi
|
||||||
system_message_text = nil,
|
system_message_text = nil,
|
||||||
messages: [],
|
messages: [],
|
||||||
tools: [],
|
tools: [],
|
||||||
skip_validations: false,
|
|
||||||
topic_id: nil,
|
topic_id: nil,
|
||||||
post_id: nil,
|
post_id: nil,
|
||||||
max_pixels: nil
|
max_pixels: nil
|
||||||
|
@ -26,7 +25,6 @@ module DiscourseAi
|
||||||
@post_id = post_id
|
@post_id = post_id
|
||||||
|
|
||||||
@messages = []
|
@messages = []
|
||||||
@skip_validations = skip_validations
|
|
||||||
|
|
||||||
if system_message_text
|
if system_message_text
|
||||||
system_message = { type: :system, content: system_message_text }
|
system_message = { type: :system, content: system_message_text }
|
||||||
|
@ -68,7 +66,6 @@ module DiscourseAi
|
||||||
private
|
private
|
||||||
|
|
||||||
def validate_message(message)
|
def validate_message(message)
|
||||||
return if @skip_validations
|
|
||||||
valid_types = %i[system user model tool tool_call]
|
valid_types = %i[system user model tool tool_call]
|
||||||
if !valid_types.include?(message[:type])
|
if !valid_types.include?(message[:type])
|
||||||
raise ArgumentError, "message type must be one of #{valid_types}"
|
raise ArgumentError, "message type must be one of #{valid_types}"
|
||||||
|
@ -91,7 +88,6 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def validate_turn(last_turn, new_turn)
|
def validate_turn(last_turn, new_turn)
|
||||||
return if @skip_validations
|
|
||||||
valid_types = %i[tool tool_call model user]
|
valid_types = %i[tool tool_call model user]
|
||||||
raise INVALID_TURN if !valid_types.include?(new_turn[:type])
|
raise INVALID_TURN if !valid_types.include?(new_turn[:type])
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,10 @@
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
|
RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
|
||||||
context "when system prompts are disabled" do
|
context "when system prompts are disabled" do
|
||||||
|
fab!(:model) do
|
||||||
|
Fabricate(:vllm_model, vision_enabled: true, provider_params: { disable_system_prompt: true })
|
||||||
|
end
|
||||||
|
|
||||||
it "merges the system prompt into the first message" do
|
it "merges the system prompt into the first message" do
|
||||||
system_msg = "This is a system message"
|
system_msg = "This is a system message"
|
||||||
user_msg = "user message"
|
user_msg = "user message"
|
||||||
|
@ -11,8 +15,6 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
|
||||||
messages: [{ type: :user, content: user_msg }],
|
messages: [{ type: :user, content: user_msg }],
|
||||||
)
|
)
|
||||||
|
|
||||||
model = Fabricate(:vllm_model, provider_params: { disable_system_prompt: true })
|
|
||||||
|
|
||||||
translated_messages = described_class.new(prompt, model).translate
|
translated_messages = described_class.new(prompt, model).translate
|
||||||
|
|
||||||
expect(translated_messages.length).to eq(1)
|
expect(translated_messages.length).to eq(1)
|
||||||
|
@ -20,6 +22,49 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
|
||||||
{ role: "user", content: [system_msg, user_msg].join("\n") },
|
{ role: "user", content: [system_msg, user_msg].join("\n") },
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
context "when the prompt has inline images" do
|
||||||
|
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
|
||||||
|
|
||||||
|
it "produces a valid message" do
|
||||||
|
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
|
||||||
|
prompt =
|
||||||
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
"You are a bot specializing in image captioning.",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
type: :user,
|
||||||
|
content: "Describe this image in a single sentence.",
|
||||||
|
upload_ids: [upload.id],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
encoded_upload =
|
||||||
|
DiscourseAi::Completions::UploadEncoder.encode(
|
||||||
|
upload_ids: [upload.id],
|
||||||
|
max_pixels: prompt.max_pixels,
|
||||||
|
).first
|
||||||
|
|
||||||
|
translated_messages = described_class.new(prompt, model).translate
|
||||||
|
|
||||||
|
expect(translated_messages.length).to eq(1)
|
||||||
|
|
||||||
|
expected_user_message = {
|
||||||
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{ type: "text", text: prompt.messages.map { |m| m[:content] }.join("\n") },
|
||||||
|
{
|
||||||
|
type: "image_url",
|
||||||
|
image_url: {
|
||||||
|
url: "data:#{encoded_upload[:mime_type]};base64,#{encoded_upload[:base64]}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(translated_messages).to contain_exactly(expected_user_message)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
context "when system prompts are enabled" do
|
context "when system prompts are enabled" do
|
||||||
|
|
Loading…
Reference in New Issue