diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb index 6af4b10f..ce390d8e 100644 --- a/lib/completions/dialects/command.rb +++ b/lib/completions/dialects/command.rb @@ -110,7 +110,7 @@ module DiscourseAi end def user_msg(msg) - content = prompt.text_only(msg) + content = DiscourseAi::Completions::Prompt.text_only(msg) user_message = { role: "USER", message: content } user_message[:message] = "#{msg[:id]}: #{content}" if msg[:id] user_message diff --git a/lib/completions/dialects/nova.rb b/lib/completions/dialects/nova.rb index 9dc88097..10098c26 100644 --- a/lib/completions/dialects/nova.rb +++ b/lib/completions/dialects/nova.rb @@ -156,7 +156,7 @@ module DiscourseAi end end - { role: "user", content: prompt.text_only(msg), images: images } + { role: "user", content: DiscourseAi::Completions::Prompt.text_only(msg), images: images } end def model_msg(msg) diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb index fe31bc1d..60d58455 100644 --- a/lib/completions/dialects/ollama.rb +++ b/lib/completions/dialects/ollama.rb @@ -69,7 +69,7 @@ module DiscourseAi end def user_msg(msg) - user_message = { role: "user", content: prompt.text_only(msg) } + user_message = { role: "user", content: DiscourseAi::Completions::Prompt.text_only(msg) } encoded_uploads = prompt.encoded_uploads(msg) if encoded_uploads.present? diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 8810257c..0641b64f 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -8,6 +8,14 @@ module DiscourseAi attr_reader :messages, :tools, :system_message_text attr_accessor :topic_id, :post_id, :max_pixels, :tool_choice + def self.text_only(message) + if message[:content].is_a?(Array) + message[:content].map { |element| element if element.is_a?(String) }.compact.join + else + message[:content] + end + end + def initialize( system_message_text = nil, messages: [], @@ -146,14 +154,6 @@ module DiscourseAi [] end - def text_only(message) - if message[:content].is_a?(Array) - message[:content].map { |element| element if element.is_a?(String) }.compact.join - else - message[:content] - end - end - def encode_upload(upload_id) UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first end diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb index d1483f0a..ff9affcc 100644 --- a/lib/personas/persona.rb +++ b/lib/personas/persona.rb @@ -365,7 +365,7 @@ module DiscourseAi # first response if latest_interactions.length == 1 - consolidated_question = latest_interactions[0][:content] + consolidated_question = DiscourseAi::Completions::Prompt.text_only(latest_interactions[0]) else consolidated_question = DiscourseAi::Personas::QuestionConsolidator.consolidate_question( diff --git a/lib/personas/question_consolidator.rb b/lib/personas/question_consolidator.rb index f1e0c476..d89716a5 100644 --- a/lib/personas/question_consolidator.rb +++ b/lib/personas/question_consolidator.rb @@ -33,7 +33,7 @@ module DiscourseAi row = +"" row << ((message[:type] == :user) ? "user" : "model") - content = message[:content] + content = DiscourseAi::Completions::Prompt.text_only(message) current_tokens = @llm.tokenizer.tokenize(content).length allowed_tokens = @max_tokens - tokens diff --git a/spec/lib/personas/persona_spec.rb b/spec/lib/personas/persona_spec.rb index fe310ef8..6c226677 100644 --- a/spec/lib/personas/persona_spec.rb +++ b/spec/lib/personas/persona_spec.rb @@ -306,21 +306,24 @@ RSpec.describe DiscourseAi::Personas::Persona do fab!(:llm_model) { Fabricate(:fake_model) } - it "will run the question consolidator" do + fab!(:custom_ai_persona) do + Fabricate( + :ai_persona, + name: "custom", + rag_conversation_chunks: 3, + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], + question_consolidator_llm_id: llm_model.id, + ) + end + + before do context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) } EmbeddingsGenerationStubs.hugging_face_service(consolidated_question, context_embedding) - custom_ai_persona = - Fabricate( - :ai_persona, - name: "custom", - rag_conversation_chunks: 3, - allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], - question_consolidator_llm_id: llm_model.id, - ) - UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id]) + end + it "will run the question consolidator" do custom_persona = DiscourseAi::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new @@ -343,6 +346,36 @@ RSpec.describe DiscourseAi::Personas::Persona do expect(message).to include("the time is 1") expect(message).to include("in france?") end + + context "when there are messages with uploads" do + let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") } + let(:image_upload) do + UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id) + end + + it "the question consolidator works" do + custom_persona = + DiscourseAi::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new + + context.messages = [ + { content: "Tell me the time", type: :user }, + { content: "the time is 1", type: :model }, + { content: ["in france?", { upload_id: image_upload.id }], type: :user }, + ] + + DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do + custom_persona.craft_prompt(context).messages.first[:content] + end + + message = + DiscourseAi::Completions::Endpoints::Fake.last_call[:dialect].prompt.messages.last[ + :content + ] + expect(message).to include("Tell me the time") + expect(message).to include("the time is 1") + expect(message).to include("in france?") + end + end end context "when a persona has RAG uploads" do