From c6aeabbfc018163fcb5ad192419896450960f1d5 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Fri, 23 Aug 2024 16:41:57 -0300 Subject: [PATCH] FIX: Malformed message in systemless + inline img scenario (#771) --- lib/ai_helper/assistant.rb | 1 - .../dialects/open_ai_compatible.rb | 17 ++++--- lib/completions/prompt.rb | 4 -- .../dialects/open_ai_compatible_spec.rb | 49 ++++++++++++++++++- 4 files changed, 58 insertions(+), 13 deletions(-) diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb index e95b2841..2e2c303d 100644 --- a/lib/ai_helper/assistant.rb +++ b/lib/ai_helper/assistant.rb @@ -154,7 +154,6 @@ module DiscourseAi upload_ids: [upload.id], }, ], - skip_validations: true, ) DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_image_caption_model).generate( diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb index 0ed2a1d8..5a679b71 100644 --- a/lib/completions/dialects/open_ai_compatible.rb +++ b/lib/completions/dialects/open_ai_compatible.rb @@ -29,9 +29,16 @@ module DiscourseAi return translated unless llm_model.lookup_custom_param("disable_system_prompt") - system_and_user_msgs = translated.shift(2) - user_msg = system_and_user_msgs.last - user_msg[:content] = [system_and_user_msgs.first[:content], user_msg[:content]].join("\n") + system_msg, user_msg = translated.shift(2) + + if user_msg[:content].is_a?(Array) # Has inline images. + user_msg[:content].first[:text] = [ + system_msg[:content], + user_msg[:content].first[:text], + ].join("\n") + else + user_msg[:content] = [system_msg[:content], user_msg[:content]].join("\n") + end translated.unshift(user_msg) end @@ -79,7 +86,7 @@ module DiscourseAi return content if encoded_uploads.blank? content_w_imgs = - encoded_uploads.reduce([]) do |memo, details| + encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details| memo << { type: "image_url", image_url: { @@ -87,8 +94,6 @@ module DiscourseAi }, } end - - content_w_imgs << { type: "text", text: message[:content] } end end end diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 4e06d950..051818d2 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -12,7 +12,6 @@ module DiscourseAi system_message_text = nil, messages: [], tools: [], - skip_validations: false, topic_id: nil, post_id: nil, max_pixels: nil @@ -26,7 +25,6 @@ module DiscourseAi @post_id = post_id @messages = [] - @skip_validations = skip_validations if system_message_text system_message = { type: :system, content: system_message_text } @@ -68,7 +66,6 @@ module DiscourseAi private def validate_message(message) - return if @skip_validations valid_types = %i[system user model tool tool_call] if !valid_types.include?(message[:type]) raise ArgumentError, "message type must be one of #{valid_types}" @@ -91,7 +88,6 @@ module DiscourseAi end def validate_turn(last_turn, new_turn) - return if @skip_validations valid_types = %i[tool tool_call model user] raise INVALID_TURN if !valid_types.include?(new_turn[:type]) diff --git a/spec/lib/completions/dialects/open_ai_compatible_spec.rb b/spec/lib/completions/dialects/open_ai_compatible_spec.rb index 7da85afa..00dea4b0 100644 --- a/spec/lib/completions/dialects/open_ai_compatible_spec.rb +++ b/spec/lib/completions/dialects/open_ai_compatible_spec.rb @@ -2,6 +2,10 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do context "when system prompts are disabled" do + fab!(:model) do + Fabricate(:vllm_model, vision_enabled: true, provider_params: { disable_system_prompt: true }) + end + it "merges the system prompt into the first message" do system_msg = "This is a system message" user_msg = "user message" @@ -11,8 +15,6 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do messages: [{ type: :user, content: user_msg }], ) - model = Fabricate(:vllm_model, provider_params: { disable_system_prompt: true }) - translated_messages = described_class.new(prompt, model).translate expect(translated_messages.length).to eq(1) @@ -20,6 +22,49 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do { role: "user", content: [system_msg, user_msg].join("\n") }, ) end + + context "when the prompt has inline images" do + let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") } + + it "produces a valid message" do + upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id) + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot specializing in image captioning.", + messages: [ + { + type: :user, + content: "Describe this image in a single sentence.", + upload_ids: [upload.id], + }, + ], + ) + encoded_upload = + DiscourseAi::Completions::UploadEncoder.encode( + upload_ids: [upload.id], + max_pixels: prompt.max_pixels, + ).first + + translated_messages = described_class.new(prompt, model).translate + + expect(translated_messages.length).to eq(1) + + expected_user_message = { + role: "user", + content: [ + { type: "text", text: prompt.messages.map { |m| m[:content] }.join("\n") }, + { + type: "image_url", + image_url: { + url: "data:#{encoded_upload[:mime_type]};base64,#{encoded_upload[:base64]}", + }, + }, + ], + } + + expect(translated_messages).to contain_exactly(expected_user_message) + end + end end context "when system prompts are enabled" do