diff --git a/app/controllers/discourse_ai/admin/ai_tools_controller.rb b/app/controllers/discourse_ai/admin/ai_tools_controller.rb index caf7fae6..6c148f1a 100644 --- a/app/controllers/discourse_ai/admin/ai_tools_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_tools_controller.rb @@ -55,7 +55,7 @@ module DiscourseAi # we need an llm so we have a tokenizer # but will do without if none is available llm = LlmModel.first&.to_llm - runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user, context: {}) + runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user) result = runner.invoke if result.is_a?(Hash) && result[:error] diff --git a/app/jobs/regular/stream_discover_reply.rb b/app/jobs/regular/stream_discover_reply.rb index b242616e..48183f6e 100644 --- a/app/jobs/regular/stream_discover_reply.rb +++ b/app/jobs/regular/stream_discover_reply.rb @@ -30,9 +30,13 @@ module Jobs base = { query: query, model_used: llm_model.display_name } - bot.reply( - { conversation_context: [{ type: :user, content: query }], skip_tool_details: true }, - ) do |partial| + context = + DiscourseAi::AiBot::BotContext.new( + messages: [{ type: :user, content: query }], + skip_tool_details: true, + ) + + bot.reply(context) do |partial| streamed_reply << partial # Throttle updates. diff --git a/app/models/ai_tool.rb b/app/models/ai_tool.rb index ba3b4098..d6fde8b2 100644 --- a/app/models/ai_tool.rb +++ b/app/models/ai_tool.rb @@ -35,7 +35,7 @@ class AiTool < ActiveRecord::Base tool_name.presence || name end - def runner(parameters, llm:, bot_user:, context: {}) + def runner(parameters, llm:, bot_user:, context: nil) DiscourseAi::AiBot::ToolRunner.new( parameters: parameters, llm: llm, @@ -59,86 +59,166 @@ class AiTool < ActiveRecord::Base def self.preamble <<~JS - /** - * Tool API Quick Reference - * - * Entry Functions - * - * invoke(parameters): Main function. Receives parameters (Object). Must return a JSON-serializable value. - * Example: - * function invoke(parameters) { return "result"; } - * - * details(): Optional. Returns a string describing the tool. - * Example: - * function details() { return "Tool description."; } - * - * Provided Objects - * - * 1. http - * http.get(url, options?): Performs an HTTP GET request. - * Parameters: - * url (string): The request URL. - * options (Object, optional): - * headers (Object): Request headers. - * Returns: - * { status: number, body: string } - * - * http.post(url, options?): Performs an HTTP POST request. - * Parameters: - * url (string): The request URL. - * options (Object, optional): - * headers (Object): Request headers. - * body (string): Request body. - * Returns: - * { status: number, body: string } - * - * (also available: http.put, http.patch, http.delete) - * - * Note: Max 20 HTTP requests per execution. - * - * 2. llm - * llm.truncate(text, length): Truncates text to a specified token length. - * Parameters: - * text (string): Text to truncate. - * length (number): Max tokens. - * Returns: - * Truncated string. - * - * 3. index - * index.search(query, options?): Searches indexed documents. - * Parameters: - * query (string): Search query. - * options (Object, optional): - * filenames (Array): Limit search to specific files. - * limit (number): Max fragments (up to 200). - * Returns: - * Array of { fragment: string, metadata: string } - * - * 4. upload - * upload.create(filename, base_64_content): Uploads a file. - * Parameters: - * filename (string): Name of the file. 
- * base_64_content (string): Base64 encoded file content. - * Returns: - * { id: number, short_url: string } - * - * 5. chain - * chain.setCustomRaw(raw): Sets the body of the post and exist chain. - * Parameters: - * raw (string): raw content to add to post. - * - * Constraints - * - * Execution Time: ≤ 2000ms - * Memory: ≤ 10MB - * HTTP Requests: ≤ 20 per execution - * Exceeding limits will result in errors or termination. - * - * Security - * - * Sandboxed Environment: No access to system or global objects. - * No File System Access: Cannot read or write files. - */ + /** + * Tool API Quick Reference + * + * Entry Functions + * + * invoke(parameters): Main function. Receives parameters defined in the tool's signature (Object). + * Must return a JSON-serializable value (e.g., string, number, object, array). + * Example: + * function invoke(parameters) { return { result: "Data processed", input: parameters.query }; } + * + * details(): Optional function. Returns a string (can include basic HTML) describing + * the tool's action after invocation, often using data from the invocation. + * This is displayed in the chat interface. + * Example: + * let lastUrl; + * function invoke(parameters) { + * lastUrl = parameters.url; + * // ... perform action ... + * return { success: true, content: "..." }; + * } + * function details() { + * return `Browsed: ${lastUrl}`; + * } + * + * Provided Objects & Functions + * + * 1. http + * Performs HTTP requests. Max 20 requests per execution. + * + * http.get(url, options?): Performs GET request. + * Parameters: + * url (string): The request URL. + * options (Object, optional): + * headers (Object): Request headers (e.g., { "Authorization": "Bearer key" }). + * Returns: { status: number, body: string } + * + * http.post(url, options?): Performs POST request. + * Parameters: + * url (string): The request URL. + * options (Object, optional): + * headers (Object): Request headers. + * body (string | Object): Request body. If an object, it's stringified as JSON. + * Returns: { status: number, body: string } + * + * http.put(url, options?): Performs PUT request (similar to POST). + * http.patch(url, options?): Performs PATCH request (similar to POST). + * http.delete(url, options?): Performs DELETE request (similar to GET/POST). + * + * 2. llm + * Interacts with the Language Model. + * + * llm.truncate(text, length): Truncates text to a specified token length based on the configured LLM's tokenizer. + * Parameters: + * text (string): Text to truncate. + * length (number): Maximum number of tokens. + * Returns: string (truncated text) + * + * llm.generate(prompt): Generates text using the configured LLM associated with the tool runner. + * Parameters: + * prompt (string | Object): The prompt. Can be a simple string or an object + * like { messages: [{ type: "system", content: "..." }, { type: "user", content: "..." }] }. + * Returns: string (generated text) + * + * 3. index + * Searches attached RAG (Retrieval-Augmented Generation) documents linked to this tool. + * + * index.search(query, options?): Searches indexed document fragments. + * Parameters: + * query (string): The search query used for semantic search. + * options (Object, optional): + * filenames (Array): Filter search to fragments from specific uploaded filenames. + * limit (number): Maximum number of fragments to return (default: 10, max: 200). + * Returns: Array<{ fragment: string, metadata: string | null }> - Ordered by relevance. + * + * 4. upload + * Handles file uploads within Discourse. 
+ * + * upload.create(filename, base_64_content): Uploads a file created by the tool, making it available in Discourse. + * Parameters: + * filename (string): The desired name for the file (basename is used for security). + * base_64_content (string): Base64 encoded content of the file. + * Returns: { id: number, url: string, short_url: string } - Details of the created upload record. + * + * 5. chain + * Controls the execution flow. + * + * chain.setCustomRaw(raw): Sets the final raw content of the bot's post and immediately + * stops the tool execution chain. Useful for tools that directly + * generate the full response content (e.g., image generation tools attaching the image markdown). + * Parameters: + * raw (string): The raw Markdown content for the post. + * Returns: void + * + * 6. discourse + * Interacts with Discourse specific features. Access is generally performed as the SystemUser. + * + * discourse.search(params): Performs a Discourse search. + * Parameters: + * params (Object): Search parameters (e.g., { search_query: "keyword", with_private: true, max_results: 10 }). + * `with_private: true` searches across all posts visible to the SystemUser. `result_style: 'detailed'` is used by default. + * Returns: Object (Discourse search results structure, includes posts, topics, users etc.) + * + * discourse.getPost(post_id): Retrieves details for a specific post. + * Parameters: + * post_id (number): The ID of the post. + * Returns: Object (Post details including `raw`, nested `topic` object with ListableTopicSerializer structure) or null if not found/accessible. + * + * discourse.getTopic(topic_id): Retrieves details for a specific topic. + * Parameters: + * topic_id (number): The ID of the topic. + * Returns: Object (Topic details using ListableTopicSerializer structure) or null if not found/accessible. + * + * discourse.getUser(user_id_or_username): Retrieves details for a specific user. + * Parameters: + * user_id_or_username (number | string): The ID or username of the user. + * Returns: Object (User details using UserSerializer structure) or null if not found. + * + * discourse.getPersona(name): Gets an object representing another AI Persona configured on the site. + * Parameters: + * name (string): The name of the target persona. + * Returns: Object { respondTo: function(params) } or null if persona not found. + * respondTo(params): Instructs the target persona to generate a response within the current context (e.g., replying to the same post or chat message). + * Parameters: + * params (Object, optional): { instructions: string, whisper: boolean } + * Returns: { success: boolean, post_id?: number, post_number?: number, message_id?: number } or { error: string } + * + * discourse.createChatMessage(params): Creates a new message in a Discourse Chat channel. + * Parameters: + * params (Object): { channel_name: string, username: string, message: string } + * `channel_name` can be the channel name or slug. + * `username` specifies the user who should appear as the sender. The user must exist. + * The sending user must have permission to post in the channel. + * Returns: { success: boolean, message_id?: number, message?: string, created_at?: string } or { error: string } + * + * 7. context + * An object containing information about the environment where the tool is being run. + * Available properties depend on the invocation context, but may include: + * post_id (number): ID of the post triggering the tool (if in a Post context). 
+       *    topic_id (number): ID of the topic (if in a Post context).
+       *    private_message (boolean): Whether the context is a private message (in Post context).
+       *    message_id (number): ID of the chat message triggering the tool (if in Chat context).
+       *    channel_id (number): ID of the chat channel (if in Chat context).
+       *    user (Object): Details of the user invoking the tool/persona (structure may vary, often null or SystemUser details unless explicitly passed).
+       *    participants (string): Comma-separated list of usernames in a PM (if applicable).
+       *    // ... other potential context-specific properties added by the calling environment.
+       *
+       * Constraints
+       *
+       *   Execution Time: ≤ 2000ms (default timeout in milliseconds) - This timer *pauses* during external HTTP requests or LLM calls initiated via `http.*` or `llm.generate`, but applies to the script's own processing time.
+       *   Memory: ≤ 10MB (V8 heap limit)
+       *   Stack Depth: ≤ 20 (Marshal stack depth limit for Ruby interop)
+       *   HTTP Requests: ≤ 20 per execution
+       *   Exceeding limits will result in errors or termination (e.g., timeout error, out-of-memory error, TooManyRequestsError).
+       *
+       * Security
+       *
+       *   Sandboxed Environment: The script runs in a restricted V8 JavaScript environment (via MiniRacer).
+       *   No direct access to the host environment, browser globals (like `window` or `document`), or the host system's file system.
+       *   Network requests are proxied through the Discourse backend, not made directly from the sandbox.
+       */
     JS
   end

diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index bb4bb07e..3248d031 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -75,9 +75,9 @@ module DiscourseAi
       def force_tool_if_needed(prompt, context)
         return if prompt.tool_choice == :none

-        context[:chosen_tools] ||= []
+        context.chosen_tools ||= []
         forced_tools = persona.force_tool_use.map { |tool| tool.name }
-        force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
+        force_tool = forced_tools.find { |name| !context.chosen_tools.include?(name) }

         if force_tool && persona.forced_tool_count > 0
           user_turns = prompt.messages.select { |m| m[:type] == :user }.length
@@ -85,7 +85,7 @@ module DiscourseAi
         end

         if force_tool
-          context[:chosen_tools] << force_tool
+          context.chosen_tools << force_tool
           prompt.tool_choice = force_tool
         else
           prompt.tool_choice = nil
@@ -93,6 +93,9 @@ module DiscourseAi
       end

       def reply(context, &update_blk)
+        unless context.is_a?(BotContext)
+          raise ArgumentError, "context must be an instance of BotContext"
+        end
         llm = DiscourseAi::Completions::Llm.proxy(model)
         prompt = persona.craft_prompt(context, llm: llm)
@@ -100,7 +103,7 @@ module DiscourseAi
         ongoing_chain = true
         raw_context = []

-        user = context[:user]
+        user = context.user

         llm_kwargs = { user: user }
         llm_kwargs[:temperature] = persona.temperature if persona.temperature
@@ -277,27 +280,15 @@ module DiscourseAi
           name: tool.name,
         }

-        if tool.standalone?
- standalone_context = - context.dup.merge( - conversation_context: [ - context[:conversation_context].last, - tool_call_message, - tool_message, - ], - ) - prompt = persona.craft_prompt(standalone_context) - else - prompt.push(**tool_call_message) - prompt.push(**tool_message) - end + prompt.push(**tool_call_message) + prompt.push(**tool_message) raw_context << [tool_call_message[:content], tool_call_id, "tool_call", tool.name] raw_context << [invocation_result_json, tool_call_id, "tool", tool.name] end def invoke_tool(tool, llm, cancel, context, &update_blk) - show_placeholder = !context[:skip_tool_details] && !tool.class.allow_partial_tool_calls? + show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls? update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder diff --git a/lib/ai_bot/bot_context.rb b/lib/ai_bot/bot_context.rb new file mode 100644 index 00000000..ca32262a --- /dev/null +++ b/lib/ai_bot/bot_context.rb @@ -0,0 +1,107 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class BotContext + attr_accessor :messages, + :topic_id, + :post_id, + :private_message, + :custom_instructions, + :user, + :skip_tool_details, + :participants, + :chosen_tools, + :message_id, + :channel_id, + :context_post_ids + + def initialize( + post: nil, + participants: nil, + user: nil, + skip_tool_details: nil, + messages: [], + custom_instructions: nil, + site_url: nil, + site_title: nil, + site_description: nil, + time: nil, + message_id: nil, + channel_id: nil, + context_post_ids: nil + ) + @participants = participants + @user = user + @skip_tool_details = skip_tool_details + @messages = messages + @custom_instructions = custom_instructions + + @message_id = message_id + @channel_id = channel_id + @context_post_ids = context_post_ids + + @site_url = site_url + @site_title = site_title + @site_description = site_description + @time = time + + if post + @post_id = post.id + @topic_id = post.topic_id + @private_message = post.topic.private_message? + @participants ||= post.topic.allowed_users.map(&:username).join(", ") if @private_message + @user = post.user + end + end + + # these are strings that can be safely interpolated into templates + TEMPLATE_PARAMS = %w[time site_url site_title site_description participants] + + def lookup_template_param(key) + public_send(key.to_sym) if TEMPLATE_PARAMS.include?(key) + end + + def time + @time ||= Time.zone.now + end + + def site_url + @site_url ||= Discourse.base_url + end + + def site_title + @site_title ||= SiteSetting.title + end + + def site_description + @site_description ||= SiteSetting.site_description + end + + def private_message? 
+ @private_message + end + + def to_json + { + messages: @messages, + topic_id: @topic_id, + post_id: @post_id, + private_message: @private_message, + custom_instructions: @custom_instructions, + username: @user&.username, + user_id: @user&.id, + participants: @participants, + chosen_tools: @chosen_tools, + message_id: @message_id, + channel_id: @channel_id, + context_post_ids: @context_post_ids, + site_url: @site_url, + site_title: @site_title, + site_description: @site_description, + skip_tool_details: @skip_tool_details, + } + end + end + end +end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 33c2907e..0d6745a3 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -163,7 +163,7 @@ module DiscourseAi def craft_prompt(context, llm: nil) system_insts = system_prompt.gsub(/\{(\w+)\}/) do |match| - found = context[match[1..-2].to_sym] + found = context.lookup_template_param(match[1..-2]) found.nil? ? match : found.to_s end @@ -180,16 +180,16 @@ module DiscourseAi ) end - if context[:custom_instructions].present? + if context.custom_instructions.present? prompt_insts << "\n" - prompt_insts << context[:custom_instructions] + prompt_insts << context.custom_instructions end fragments_guidance = rag_fragments_prompt( - context[:conversation_context].to_a, + context.messages, llm: question_consolidator_llm, - user: context[:user], + user: context.user, )&.strip prompt_insts << fragments_guidance if fragments_guidance.present? @@ -197,9 +197,9 @@ module DiscourseAi prompt = DiscourseAi::Completions::Prompt.new( prompt_insts, - messages: context[:conversation_context].to_a, - topic_id: context[:topic_id], - post_id: context[:post_id], + messages: context.messages, + topic_id: context.topic_id, + post_id: context.post_id, ) prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 59533efd..685683fd 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -227,90 +227,17 @@ module DiscourseAi schedule_bot_reply(post) if can_attach?(post) end - def conversation_context(post, style: nil) - # Pay attention to the `post_number <= ?` here. - # We want to inject the last post as context because they are translated differently. - - # also setting default to 40, allowing huge contexts costs lots of tokens - max_posts = 40 - if bot.persona.class.respond_to?(:max_context_posts) - max_posts = bot.persona.class.max_context_posts || 40 - end - - post_types = [Post.types[:regular]] - post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper] - - context = - post - .topic - .posts - .joins(:user) - .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id") - .where("post_number <= ?", post.post_number) - .order("post_number desc") - .where("post_type in (?)", post_types) - .limit(max_posts) - .pluck( - "posts.raw", - "users.username", - "post_custom_prompts.custom_prompt", - "( - SELECT array_agg(ref.upload_id) - FROM upload_references ref - WHERE ref.target_type = 'Post' AND ref.target_id = posts.id - ) as upload_ids", - ) - - builder = DiscourseAi::Completions::PromptMessagesBuilder.new - builder.topic = post.topic - - context.reverse_each do |raw, username, custom_prompt, upload_ids| - custom_prompt_translation = - Proc.new do |message| - # We can't keep backwards-compatibility for stored functions. - # Tool syntax requires a tool_call_id which we don't have. 
- if message[2] != "function" - custom_context = { - content: message[0], - type: message[2].present? ? message[2].to_sym : :model, - } - - custom_context[:id] = message[1] if custom_context[:type] != :model - custom_context[:name] = message[3] if message[3] - - thinking = message[4] - custom_context[:thinking] = thinking if thinking - - builder.push(**custom_context) - end - end - - if custom_prompt.present? - custom_prompt.each(&custom_prompt_translation) - else - context = { - content: raw, - type: (available_bot_usernames.include?(username) ? :model : :user), - } - - context[:id] = username if context[:type] == :user - - if upload_ids.present? && context[:type] == :user && bot.persona.class.vision_enabled - context[:upload_ids] = upload_ids.compact - end - - builder.push(**context) - end - end - - builder.to_a(style: style || (post.topic.private_message? ? :bot : :topic)) - end - def title_playground(post, user) - context = conversation_context(post) + messages = + DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post( + post, + max_posts: 5, + bot_usernames: available_bot_usernames, + include_uploads: bot.persona.class.vision_enabled, + ) bot - .get_updated_title(context, post, user) + .get_updated_title(messages, post, user) .tap do |new_title| PostRevisor.new(post.topic.first_post, post.topic).revise!( bot.bot_user, @@ -326,83 +253,6 @@ module DiscourseAi ) end - def chat_context(message, channel, persona_user, context_post_ids) - has_vision = bot.persona.class.vision_enabled - include_thread_titles = !channel.direct_message_channel? && !message.thread_id - - current_id = message.id - if !channel.direct_message_channel? - # we are interacting via mentions ... strip mention - instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip - end - - messages = nil - - max_messages = 40 - if bot.persona.class.respond_to?(:max_context_posts) - max_messages = bot.persona.class.max_context_posts || 40 - end - - if !message.thread_id && channel.direct_message_channel? - messages = [message] - elsif !channel.direct_message_channel? && !message.thread_id - messages = - Chat::Message - .joins("left join chat_threads on chat_threads.id = chat_messages.thread_id") - .where(chat_channel_id: channel.id) - .where( - "chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id", - ) - .order(id: :desc) - .limit(max_messages) - .to_a - .reverse - end - - messages ||= - ChatSDK::Thread.last_messages( - thread_id: message.thread_id, - guardian: Discourse.system_user.guardian, - page_size: max_messages, - ) - - builder = DiscourseAi::Completions::PromptMessagesBuilder.new - - guardian = Guardian.new(message.user) - if context_post_ids - builder.set_chat_context_posts(context_post_ids, guardian, include_uploads: has_vision) - end - - messages.each do |m| - # restore stripped message - m.message = instruction_message if m.id == current_id && instruction_message - - if available_bot_user_ids.include?(m.user_id) - builder.push(type: :model, content: m.message) - else - upload_ids = nil - upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present? 
- mapped_message = m.message - - thread_title = nil - thread_title = m.thread&.title if include_thread_titles && m.thread_id - mapped_message = "(#{thread_title})\n#{m.message}" if thread_title - - builder.push( - type: :user, - content: mapped_message, - name: m.user.username, - upload_ids: upload_ids, - ) - end - end - - builder.to_a( - limit: max_messages, - style: channel.direct_message_channel? ? :chat_with_context : :chat, - ) - end - def reply_to_chat_message(message, channel, context_post_ids) persona_user = User.find(bot.persona.class.user_id) @@ -410,10 +260,32 @@ module DiscourseAi context_post_ids = nil if !channel.direct_message_channel? + max_chat_messages = 40 + if bot.persona.class.respond_to?(:max_context_posts) + max_chat_messages = bot.persona.class.max_context_posts || 40 + end + + if !channel.direct_message_channel? + # we are interacting via mentions ... strip mention + instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip + end + context = - get_context( - participants: participants.join(", "), - conversation_context: chat_context(message, channel, persona_user, context_post_ids), + BotContext.new( + participants: participants, + message_id: message.id, + channel_id: channel.id, + context_post_ids: context_post_ids, + messages: + DiscourseAi::Completions::PromptMessagesBuilder.messages_from_chat( + message, + channel: channel, + context_post_ids: context_post_ids, + include_uploads: bot.persona.class.vision_enabled, + max_messages: max_chat_messages, + bot_user_ids: available_bot_user_ids, + instruction_message: instruction_message, + ), user: message.user, skip_tool_details: true, ) @@ -460,22 +332,6 @@ module DiscourseAi reply end - def get_context(participants:, conversation_context:, user:, skip_tool_details: nil) - result = { - site_url: Discourse.base_url, - site_title: SiteSetting.title, - site_description: SiteSetting.site_description, - time: Time.zone.now, - participants: participants, - conversation_context: conversation_context, - user: user, - } - - result[:skip_tool_details] = true if skip_tool_details - - result - end - def reply_to( post, custom_instructions: nil, @@ -509,16 +365,25 @@ module DiscourseAi end ) + # safeguard + max_context_posts = 40 + if bot.persona.class.respond_to?(:max_context_posts) + max_context_posts = bot.persona.class.max_context_posts || 40 + end + context = - get_context( - participants: post.topic.allowed_users.map(&:username).join(", "), - conversation_context: conversation_context(post, style: context_style), - user: post.user, + BotContext.new( + post: post, + custom_instructions: custom_instructions, + messages: + DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post( + post, + style: context_style, + max_posts: max_context_posts, + include_uploads: bot.persona.class.vision_enabled, + bot_usernames: available_bot_usernames, + ), ) - context[:post_id] = post.id - context[:topic_id] = post.topic_id - context[:private_message] = post.topic.private_message? - context[:custom_instructions] = custom_instructions reply_user = bot.bot_user if bot.persona.class.respond_to?(:user_id) @@ -562,7 +427,7 @@ module DiscourseAi Discourse.redis.setex(redis_stream_key, 60, 1) end - context[:skip_tool_details] ||= !bot.persona.class.tool_details + context.skip_tool_details ||= !bot.persona.class.tool_details post_streamer = PostStreamer.new(delay: Rails.env.test? ? 
0 : 0.5) if stream_reply diff --git a/lib/ai_bot/tool_runner.rb b/lib/ai_bot/tool_runner.rb index df785749..8a42a494 100644 --- a/lib/ai_bot/tool_runner.rb +++ b/lib/ai_bot/tool_runner.rb @@ -13,7 +13,13 @@ module DiscourseAi MARSHAL_STACK_DEPTH = 20 MAX_HTTP_REQUESTS = 20 - def initialize(parameters:, llm:, bot_user:, context: {}, tool:, timeout: nil) + def initialize(parameters:, llm:, bot_user:, context: nil, tool:, timeout: nil) + if context && !context.is_a?(DiscourseAi::AiBot::BotContext) + raise ArgumentError, "context must be a BotContext object" + end + + context ||= DiscourseAi::AiBot::BotContext.new + @parameters = parameters @llm = llm @bot_user = bot_user @@ -99,7 +105,7 @@ module DiscourseAi }, }; - const context = #{JSON.generate(@context)}; + const context = #{JSON.generate(@context.to_json)}; function details() { return ""; }; JS @@ -240,13 +246,13 @@ module DiscourseAi def llm_user @llm_user ||= begin - @context[:llm_user] || post&.user || @bot_user + post&.user || @bot_user end end def post return @post if defined?(@post) - post_id = @context[:post_id] + post_id = @context.post_id @post = post_id && Post.find_by(id: post_id) end @@ -336,8 +342,8 @@ module DiscourseAi bot = DiscourseAi::AiBot::Bot.as(@bot_user || persona.user, persona: persona) playground = DiscourseAi::AiBot::Playground.new(bot) - if @context[:post_id] - post = Post.find_by(id: @context[:post_id]) + if @context.post_id + post = Post.find_by(id: @context.post_id) return { error: "Post not found" } if post.nil? reply_post = @@ -354,13 +360,13 @@ module DiscourseAi else return { error: "Failed to create reply" } end - elsif @context[:message_id] && @context[:channel_id] - message = Chat::Message.find_by(id: @context[:message_id]) - channel = Chat::Channel.find_by(id: @context[:channel_id]) + elsif @context.message_id && @context.channel_id + message = Chat::Message.find_by(id: @context.message_id) + channel = Chat::Channel.find_by(id: @context.channel_id) return { error: "Message or channel not found" } if message.nil? || channel.nil? 
reply = - playground.reply_to_chat_message(message, channel, @context[:context_post_ids]) + playground.reply_to_chat_message(message, channel, @context.context_post_ids) if reply return { success: true, message_id: reply.id } @@ -457,7 +463,7 @@ module DiscourseAi UploadCreator.new( file, filename, - for_private_message: @context[:private_message], + for_private_message: @context.private_message, ).create_for(@bot_user.id) { id: upload.id, short_url: upload.short_url, url: upload.url } diff --git a/lib/ai_bot/tools/create_artifact.rb b/lib/ai_bot/tools/create_artifact.rb index b9175e8d..54221b40 100644 --- a/lib/ai_bot/tools/create_artifact.rb +++ b/lib/ai_bot/tools/create_artifact.rb @@ -108,7 +108,7 @@ module DiscourseAi end def invoke - post = Post.find_by(id: context[:post_id]) + post = Post.find_by(id: context.post_id) return error_response("No post context found") unless post partial_response = +"" diff --git a/lib/ai_bot/tools/dall_e.rb b/lib/ai_bot/tools/dall_e.rb index 7389a3b5..4564cfbb 100644 --- a/lib/ai_bot/tools/dall_e.rb +++ b/lib/ai_bot/tools/dall_e.rb @@ -111,7 +111,7 @@ module DiscourseAi UploadCreator.new( file, "image.png", - for_private_message: context[:private_message], + for_private_message: context.private_message?, ).create_for(bot_user.id), } end diff --git a/lib/ai_bot/tools/image.rb b/lib/ai_bot/tools/image.rb index 72e575c3..34e5e3f2 100644 --- a/lib/ai_bot/tools/image.rb +++ b/lib/ai_bot/tools/image.rb @@ -131,7 +131,7 @@ module DiscourseAi UploadCreator.new( file, "image.png", - for_private_message: context[:private_message], + for_private_message: context.private_message, ).create_for(bot_user.id), seed: image[:seed], } diff --git a/lib/ai_bot/tools/read.rb b/lib/ai_bot/tools/read.rb index 8bb3afc5..d7a0186f 100644 --- a/lib/ai_bot/tools/read.rb +++ b/lib/ai_bot/tools/read.rb @@ -48,7 +48,7 @@ module DiscourseAi def invoke not_found = { topic_id: topic_id, description: "Topic not found" } - guardian = Guardian.new(context[:user]) if options[:read_private] && context[:user] + guardian = Guardian.new(context.user) if options[:read_private] && context.user guardian ||= Guardian.new @title = "" diff --git a/lib/ai_bot/tools/read_artifact.rb b/lib/ai_bot/tools/read_artifact.rb index 7fb5646c..98b7944f 100644 --- a/lib/ai_bot/tools/read_artifact.rb +++ b/lib/ai_bot/tools/read_artifact.rb @@ -59,7 +59,7 @@ module DiscourseAi end def post - @post ||= Post.find_by(id: context[:post_id]) + @post ||= Post.find_by(id: context.post_id) end def handle_discourse_artifact(uri) diff --git a/lib/ai_bot/tools/search.rb b/lib/ai_bot/tools/search.rb index a14e77d6..402eca05 100644 --- a/lib/ai_bot/tools/search.rb +++ b/lib/ai_bot/tools/search.rb @@ -130,7 +130,7 @@ module DiscourseAi after: parameters[:after], status: parameters[:status], max_results: max_results, - current_user: options[:search_private] ? context[:user] : nil, + current_user: options[:search_private] ? context.user : nil, ) @last_num_results = results[:rows]&.length || 0 diff --git a/lib/ai_bot/tools/summarize.rb b/lib/ai_bot/tools/summarize.rb index 214d45fa..5076635f 100644 --- a/lib/ai_bot/tools/summarize.rb +++ b/lib/ai_bot/tools/summarize.rb @@ -40,10 +40,6 @@ module DiscourseAi false end - def standalone? 
-        true
-      end
-
       def custom_raw
         @last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
       end
diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb
index ff95f16e..ba0bc69c 100644
--- a/lib/ai_bot/tools/tool.rb
+++ b/lib/ai_bot/tools/tool.rb
@@ -56,14 +56,17 @@ module DiscourseAi
          persona_options: {},
          bot_user:,
          llm:,
-          context: {}
+          context: nil
        )
          @parameters = parameters
          @tool_call_id = tool_call_id
          @persona_options = persona_options
          @bot_user = bot_user
          @llm = llm
-          @context = context
+          @context = context.nil? ? DiscourseAi::AiBot::BotContext.new(messages: []) : context
+          if !@context.is_a?(DiscourseAi::AiBot::BotContext)
+            raise ArgumentError, "context must be a DiscourseAi::AiBot::BotContext"
+          end
        end

        def name
@@ -108,10 +111,6 @@ module DiscourseAi
          true
        end

-        def standalone?
-          false
-        end
-
        protected

        def fetch_default_branch(repo)
diff --git a/lib/ai_bot/tools/update_artifact.rb b/lib/ai_bot/tools/update_artifact.rb
index 4b62bd01..2c9e544b 100644
--- a/lib/ai_bot/tools/update_artifact.rb
+++ b/lib/ai_bot/tools/update_artifact.rb
@@ -39,7 +39,7 @@ module DiscourseAi
        def self.inject_prompt(prompt:, context:, persona:)
          return if persona.options["do_not_echo_artifact"].to_s == "true"
          # we inject the current artifact content into the last user message
-          if topic_id = context[:topic_id]
+          if topic_id = context.topic_id
            posts = Post.where(topic_id: topic_id)
            artifact = AiArtifact.order("id desc").where(post: posts).first
            if artifact
@@ -113,7 +113,7 @@ module DiscourseAi
        end

        def invoke
-          post = Post.find_by(id: context[:post_id])
+          post = Post.find_by(id: context.post_id)
          return error_response("No post context found") unless post

          artifact = AiArtifact.find_by(id: parameters[:artifact_id])
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index 8de1bcf8..c6716aff 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -188,9 +188,10 @@ module DiscourseAi
        messages: [
          {
            type: :user,
-            content:
+            content: [
              "Describe this image in a single sentence#{custom_locale_instructions(user)}",
-            upload_ids: [upload.id],
+              { upload_id: upload.id },
+            ],
          },
        ],
      )
diff --git a/lib/ai_moderation/spam_scanner.rb b/lib/ai_moderation/spam_scanner.rb
index 8569c8e8..e4ed348d 100644
--- a/lib/ai_moderation/spam_scanner.rb
+++ b/lib/ai_moderation/spam_scanner.rb
@@ -187,7 +187,10 @@ module DiscourseAi
      prompt = DiscourseAi::Completions::Prompt.new(system_prompt)
      args = { type: :user, content: context }
      upload_ids = post.upload_ids
-      args[:upload_ids] = upload_ids.take(3) if upload_ids.present?
+      if upload_ids.present?
+        args[:content] = [args[:content]]
+        upload_ids.take(3).each { |upload_id| args[:content] << { upload_id: upload_id } }
+      end
      prompt.push(**args)
      prompt
    end
diff --git a/lib/automation/llm_tool_triage.rb b/lib/automation/llm_tool_triage.rb
index 978119a2..7d7622ad 100644
--- a/lib/automation/llm_tool_triage.rb
+++ b/lib/automation/llm_tool_triage.rb
@@ -7,11 +7,7 @@ module DiscourseAi
      return if !tool
      return if !tool.parameters.blank?
- context = { - post_id: post.id, - automation_id: automation&.id, - automation_name: automation&.name, - } + context = DiscourseAi::AiBot::BotContext.new(post: post) runner = tool.runner({}, llm: nil, bot_user: Discourse.system_user, context: context) runner.invoke diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb index c3d0d5c1..640e7046 100644 --- a/lib/automation/llm_triage.rb +++ b/lib/automation/llm_triage.rb @@ -42,7 +42,12 @@ module DiscourseAi content = llm.tokenizer.truncate(content, max_post_tokens) if max_post_tokens.present? - prompt.push(type: :user, content: content, upload_ids: post.upload_ids) + if post.upload_ids.present? + content = [content] + content.concat(post.upload_ids.map { |upload_id| { upload_id: upload_id } }) + end + + prompt.push(type: :user, content: content) result = nil diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 578cbd2b..d2d7245c 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -17,13 +17,13 @@ module DiscourseAi llm_model.provider == "open_ai" || llm_model.provider == "azure" end - def translate + def embed_user_ids? + return @embed_user_ids if defined?(@embed_user_ids) + @embed_user_ids = prompt.messages.any? do |m| m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX) end - - super end def max_prompt_tokens @@ -102,35 +102,47 @@ module DiscourseAi end def user_msg(msg) - user_message = { role: "user", content: msg[:content] } + content_array = [] + + user_message = { role: "user" } if msg[:id] - if @embed_user_ids - user_message[:content] = "#{msg[:id]}: #{msg[:content]}" + if embed_user_ids? + content_array << "#{msg[:id]}: " else user_message[:name] = msg[:id] end end - user_message[:content] = inline_images(user_message[:content], msg) if vision_support? + content_array << msg[:content] + + content_array = + to_encoded_content_array( + content: content_array.flatten, + image_encoder: ->(details) { image_node(details) }, + text_encoder: ->(text) { { type: "text", text: text } }, + allow_vision: vision_support?, + ) + + user_message[:content] = no_array_if_only_text(content_array) user_message end - def inline_images(content, message) - encoded_uploads = prompt.encoded_uploads(message) - return content if encoded_uploads.blank? 
+ def no_array_if_only_text(content_array) + if content_array.size == 1 && content_array.first[:type] == "text" + content_array.first[:text] + else + content_array + end + end - content_w_imgs = - encoded_uploads.reduce([]) do |memo, details| - memo << { - type: "image_url", - image_url: { - url: "data:#{details[:mime_type]};base64,#{details[:base64]}", - }, - } - end - - content_w_imgs << { type: "text", text: message[:content] } + def image_node(details) + { + type: "image_url", + image_url: { + url: "data:#{details[:mime_type]};base64,#{details[:base64]}", + }, + } end def per_message_overhead diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 06fbe102..1cea2215 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -87,9 +87,9 @@ module DiscourseAi end def model_msg(msg) - if msg[:thinking] || msg[:redacted_thinking_signature] - content_array = [] + content_array = [] + if msg[:thinking] || msg[:redacted_thinking_signature] if msg[:thinking] content_array << { type: "thinking", @@ -104,13 +104,19 @@ module DiscourseAi data: msg[:redacted_thinking_signature], } end - - content_array << { type: "text", text: msg[:content] } - - { role: "assistant", content: content_array } - else - { role: "assistant", content: msg[:content] } end + + # other encoder is used to pass through thinking + content_array = + to_encoded_content_array( + content: [content_array, msg[:content]].flatten, + image_encoder: ->(details) {}, + text_encoder: ->(text) { { type: "text", text: text } }, + other_encoder: ->(details) { details }, + allow_vision: false, + ) + + { role: "assistant", content: no_array_if_only_text(content_array) } end def system_msg(msg) @@ -124,31 +130,39 @@ module DiscourseAi end def user_msg(msg) - content = +"" - content << "#{msg[:id]}: " if msg[:id] - content << msg[:content] - content = inline_images(content, msg) if vision_support? + content_array = [] + content_array << "#{msg[:id]}: " if msg[:id] + content_array.concat([msg[:content]].flatten) - { role: "user", content: content } + content_array = + to_encoded_content_array( + content: content_array, + image_encoder: ->(details) { image_node(details) }, + text_encoder: ->(text) { { type: "text", text: text } }, + allow_vision: vision_support?, + ) + + { role: "user", content: no_array_if_only_text(content_array) } end - def inline_images(content, message) - encoded_uploads = prompt.encoded_uploads(message) - return content if encoded_uploads.blank? 
+ # keeping our payload as backward compatible as possible + def no_array_if_only_text(content_array) + if content_array.length == 1 && content_array.first[:type] == "text" + content_array.first[:text] + else + content_array + end + end - content_w_imgs = - encoded_uploads.reduce([]) do |memo, details| - memo << { - source: { - type: "base64", - data: details[:base64], - media_type: details[:mime_type], - }, - type: "image", - } - end - - content_w_imgs << { type: "text", text: content } + def image_node(details) + { + source: { + type: "base64", + data: details[:base64], + media_type: details[:mime_type], + }, + type: "image", + } end end end diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb index 25561433..6af4b10f 100644 --- a/lib/completions/dialects/command.rb +++ b/lib/completions/dialects/command.rb @@ -110,9 +110,9 @@ module DiscourseAi end def user_msg(msg) - user_message = { role: "USER", message: msg[:content] } - user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id] - + content = prompt.text_only(msg) + user_message = { role: "USER", message: content } + user_message[:message] = "#{msg[:id]}: #{content}" if msg[:id] user_message end end diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 2a335a12..5d84f6d5 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -227,6 +227,38 @@ module DiscourseAi msg = msg.merge(content: new_content) user_msg(msg) end + + def to_encoded_content_array( + content:, + image_encoder:, + text_encoder:, + other_encoder: nil, + allow_vision: + ) + content = [content] if !content.is_a?(Array) + + current_string = +"" + result = [] + + content.each do |c| + if c.is_a?(String) + current_string << c + elsif c.is_a?(Hash) && c.key?(:upload_id) && allow_vision + if !current_string.empty? + result << text_encoder.call(current_string) + current_string = +"" + end + encoded = prompt.encode_upload(c[:upload_id]) + result << image_encoder.call(encoded) if encoded + elsif other_encoder + encoded = other_encoder.call(c) + result << encoded if encoded + end + end + + result << text_encoder.call(current_string) if !current_string.empty? + result + end end end end diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 563ed88b..050fabd2 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -106,28 +106,29 @@ module DiscourseAi end def user_msg(msg) - if beta_api? - # support new format with multiple parts - result = { role: "user", parts: [{ text: msg[:content] }] } - return result unless vision_support? + content_array = [] + content_array << "#{msg[:id]}: " if msg[:id] - upload_parts = uploaded_parts(msg) - result[:parts].concat(upload_parts) if upload_parts.present? - result + content_array << msg[:content] + content_array.flatten! + + content_array = + to_encoded_content_array( + content: content_array, + image_encoder: ->(details) { image_node(details) }, + text_encoder: ->(text) { { text: text } }, + allow_vision: vision_support? && beta_api?, + ) + + if beta_api? + { role: "user", parts: content_array } else - { role: "user", parts: { text: msg[:content] } } + { role: "user", parts: content_array.first } end end - def uploaded_parts(message) - encoded_uploads = prompt.encoded_uploads(message) - result = [] - if encoded_uploads.present? 
- encoded_uploads.each do |details| - result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } } - end - end - result + def image_node(details) + { inlineData: { mimeType: details[:mime_type], data: details[:base64] } } end def tool_call_msg(msg) diff --git a/lib/completions/dialects/nova.rb b/lib/completions/dialects/nova.rb index aa184a7a..b078e79d 100644 --- a/lib/completions/dialects/nova.rb +++ b/lib/completions/dialects/nova.rb @@ -155,7 +155,7 @@ module DiscourseAi end end - { role: "user", content: msg[:content], images: images } + { role: "user", content: prompt.text_only(msg), images: images } end def model_msg(msg) diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb index 4546c827..fe31bc1d 100644 --- a/lib/completions/dialects/ollama.rb +++ b/lib/completions/dialects/ollama.rb @@ -69,7 +69,7 @@ module DiscourseAi end def user_msg(msg) - user_message = { role: "user", content: msg[:content] } + user_message = { role: "user", content: prompt.text_only(msg) } encoded_uploads = prompt.encoded_uploads(msg) if encoded_uploads.present? diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb index 2d648ac1..d6519822 100644 --- a/lib/completions/dialects/open_ai_compatible.rb +++ b/lib/completions/dialects/open_ai_compatible.rb @@ -3,7 +3,7 @@ module DiscourseAi module Completions module Dialects - class OpenAiCompatible < Dialect + class OpenAiCompatible < ChatGpt class << self def can_translate?(_llm_model) # fallback dialect @@ -43,58 +43,6 @@ module DiscourseAi translated.unshift(user_msg) end - - private - - def system_msg(msg) - msg = { role: "system", content: msg[:content] } - - if tools_dialect.instructions.present? - msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}" - end - - msg - end - - def model_msg(msg) - { role: "assistant", content: msg[:content] } - end - - def tool_call_msg(msg) - translated = tools_dialect.from_raw_tool_call(msg) - { role: "assistant", content: translated } - end - - def tool_msg(msg) - translated = tools_dialect.from_raw_tool(msg) - { role: "user", content: translated } - end - - def user_msg(msg) - content = +"" - content << "#{msg[:id]}: " if msg[:id] - content << msg[:content] - - message = { role: "user", content: content } - - message[:content] = inline_images(message[:content], msg) if vision_support? - - message - end - - def inline_images(content, message) - encoded_uploads = prompt.encoded_uploads(message) - return content if encoded_uploads.blank? - - encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details| - memo << { - type: "image_url", - image_url: { - url: "data:#{details[:mime_type]};base64,#{details[:base64]}", - }, - } - end - end end end end diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 79e71c06..660b1ca5 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -89,7 +89,6 @@ module DiscourseAi content:, id: nil, name: nil, - upload_ids: nil, thinking: nil, thinking_signature: nil, redacted_thinking_signature: nil @@ -98,7 +97,6 @@ module DiscourseAi new_message = { type: type, content: content } new_message[:name] = name.to_s if name new_message[:id] = id.to_s if id - new_message[:upload_ids] = upload_ids if upload_ids new_message[:thinking] = thinking if thinking new_message[:thinking_signature] = thinking_signature if thinking_signature new_message[ @@ -115,11 +113,44 @@ module DiscourseAi tools.present? 
end - # helper method to get base64 encoded uploads - # at the correct dimentions def encoded_uploads(message) - return [] if message[:upload_ids].blank? - UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels) + if message[:content].is_a?(Array) + upload_ids = + message[:content] + .map do |content| + content[:upload_id] if content.is_a?(Hash) && content.key?(:upload_id) + end + .compact + if !upload_ids.empty? + return UploadEncoder.encode(upload_ids: upload_ids, max_pixels: max_pixels) + end + end + + [] + end + + def text_only(message) + if message[:content].is_a?(Array) + message[:content].map { |element| element if element.is_a?(String) }.compact.join + else + message[:content] + end + end + + def encode_upload(upload_id) + UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first + end + + def content_with_encoded_uploads(content) + return [content] unless content.is_a?(Array) + + content.map do |c| + if c.is_a?(Hash) && c.key?(:upload_id) + encode_upload(c[:upload_id]) + else + c + end + end end def ==(other) @@ -150,7 +181,6 @@ module DiscourseAi content id name - upload_ids thinking thinking_signature redacted_thinking_signature @@ -159,15 +189,17 @@ module DiscourseAi raise ArgumentError, "message contains invalid keys: #{invalid_keys}" end - if message[:type] == :upload_ids && !message[:upload_ids].is_a?(Array) - raise ArgumentError, "upload_ids must be an array of ids" + if message[:content].is_a?(Array) + message[:content].each do |content| + if !content.is_a?(String) && !(content.is_a?(Hash) && content.keys == [:upload_id]) + raise ArgumentError, "Array message content must be a string or {upload_id: ...} " + end + end + else + if !message[:content].is_a?(String) + raise ArgumentError, "Message content must be a string or an array" + end end - - if message[:upload_ids].present? && message[:type] != :user - raise ArgumentError, "upload_ids are only supported for users" - end - - raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String) end def validate_turn(last_turn, new_turn) diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb index 26fb24fa..b50b5b67 100644 --- a/lib/completions/prompt_messages_builder.rb +++ b/lib/completions/prompt_messages_builder.rb @@ -9,6 +9,154 @@ module DiscourseAi attr_reader :chat_context_post_upload_ids attr_accessor :topic + def self.messages_from_chat( + message, + channel:, + context_post_ids:, + max_messages:, + include_uploads:, + bot_user_ids:, + instruction_message: nil + ) + include_thread_titles = !channel.direct_message_channel? && !message.thread_id + + current_id = message.id + messages = nil + + if !message.thread_id && channel.direct_message_channel? + messages = [message] + elsif !channel.direct_message_channel? 
&& !message.thread_id + messages = + Chat::Message + .joins("left join chat_threads on chat_threads.id = chat_messages.thread_id") + .where(chat_channel_id: channel.id) + .where( + "chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id", + ) + .order(id: :desc) + .limit(max_messages) + .to_a + .reverse + end + + messages ||= + ChatSDK::Thread.last_messages( + thread_id: message.thread_id, + guardian: Discourse.system_user.guardian, + page_size: max_messages, + ) + + builder = new + + guardian = Guardian.new(message.user) + if context_post_ids + builder.set_chat_context_posts( + context_post_ids, + guardian, + include_uploads: include_uploads, + ) + end + + messages.each do |m| + # restore stripped message + m.message = instruction_message if m.id == current_id && instruction_message + + if bot_user_ids.include?(m.user_id) + builder.push(type: :model, content: m.message) + else + upload_ids = nil + upload_ids = m.uploads.map(&:id) if include_uploads && m.uploads.present? + mapped_message = m.message + + thread_title = nil + thread_title = m.thread&.title if include_thread_titles && m.thread_id + mapped_message = "(#{thread_title})\n#{m.message}" if thread_title + + builder.push( + type: :user, + content: mapped_message, + name: m.user.username, + upload_ids: upload_ids, + ) + end + end + + builder.to_a( + limit: max_messages, + style: channel.direct_message_channel? ? :chat_with_context : :chat, + ) + end + + def self.messages_from_post(post, style: nil, max_posts:, bot_usernames:, include_uploads:) + # Pay attention to the `post_number <= ?` here. + # We want to inject the last post as context because they are translated differently. + + post_types = [Post.types[:regular]] + post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper] + + context = + post + .topic + .posts + .joins(:user) + .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id") + .where("post_number <= ?", post.post_number) + .order("post_number desc") + .where("post_type in (?)", post_types) + .limit(max_posts) + .pluck( + "posts.raw", + "users.username", + "post_custom_prompts.custom_prompt", + "( + SELECT array_agg(ref.upload_id) + FROM upload_references ref + WHERE ref.target_type = 'Post' AND ref.target_id = posts.id + ) as upload_ids", + ) + + builder = new + builder.topic = post.topic + + context.reverse_each do |raw, username, custom_prompt, upload_ids| + custom_prompt_translation = + Proc.new do |message| + # We can't keep backwards-compatibility for stored functions. + # Tool syntax requires a tool_call_id which we don't have. + if message[2] != "function" + custom_context = { + content: message[0], + type: message[2].present? ? message[2].to_sym : :model, + } + + custom_context[:id] = message[1] if custom_context[:type] != :model + custom_context[:name] = message[3] if message[3] + + thinking = message[4] + custom_context[:thinking] = thinking if thinking + + builder.push(**custom_context) + end + end + + if custom_prompt.present? + custom_prompt.each(&custom_prompt_translation) + else + context = { content: raw, type: (bot_usernames.include?(username) ? :model : :user) } + + context[:id] = username if context[:type] == :user + + if upload_ids.present? && context[:type] == :user && include_uploads + context[:upload_ids] = upload_ids.compact + end + + builder.push(**context) + end + end + + builder.to_a(style: style || (post.topic.private_message? ? 
:bot : :topic)) + end + def initialize @raw_messages = [] end @@ -68,12 +216,19 @@ module DiscourseAi if message[:type] == :user old_name = last_message.delete(:name) - last_message[:content] = "#{old_name}: #{last_message[:content]}" if old_name + last_message[:content] = ["#{old_name}: ", last_message[:content]].flatten if old_name new_content = message[:content] - new_content = "#{message[:name]}: #{new_content}" if message[:name] + new_content = ["#{message[:name]}: ", new_content].flatten if message[:name] - last_message[:content] += "\n#{new_content}" + if !last_message[:content].is_a?(Array) + last_message[:content] = [last_message[:content]] + end + last_message[:content].concat(["\n", new_content].flatten) + + compressed = + compress_messages_buffer(last_message[:content], max_uploads: MAX_TOPIC_UPLOADS) + last_message[:content] = compressed else last_message[:content] = message[:content] end @@ -111,9 +266,9 @@ module DiscourseAi end raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array) + content = [content, *upload_ids.map { |upload_id| { upload_id: upload_id } }] if upload_ids message = { type: type, content: content } message[:name] = name.to_s if name - message[:upload_ids] = upload_ids if upload_ids message[:id] = id.to_s if id if thinking message[:thinking] = thinking["thinking"] if thinking["thinking"] @@ -132,67 +287,62 @@ module DiscourseAi def topic_array raw_messages = @raw_messages.dup - user_content = +"You are operating in a Discourse forum.\n\n" + content_array = [] + content_array << "You are operating in a Discourse forum.\n\n" if @topic if @topic.private_message? - user_content << "Private message info.\n" + content_array << "Private message info.\n" else - user_content << "Topic information:\n" + content_array << "Topic information:\n" end - user_content << "- URL: #{@topic.url}\n" - user_content << "- Title: #{@topic.title}\n" + content_array << "- URL: #{@topic.url}\n" + content_array << "- Title: #{@topic.title}\n" if SiteSetting.tagging_enabled tags = @topic.tags.pluck(:name) tags -= DiscourseTagging.hidden_tag_names if tags.present? - user_content << "- Tags: #{tags.join(", ")}\n" if tags.present? + content_array << "- Tags: #{tags.join(", ")}\n" if tags.present? end if !@topic.private_message? - user_content << "- Category: #{@topic.category.name}\n" if @topic.category + content_array << "- Category: #{@topic.category.name}\n" if @topic.category end - user_content << "- Number of replies: #{@topic.posts_count - 1}\n\n" + content_array << "- Number of replies: #{@topic.posts_count - 1}\n\n" end last_user_message = raw_messages.pop - upload_ids = [] if raw_messages.present? - user_content << "Here is the conversation so far:\n" + content_array << "Here is the conversation so far:\n" raw_messages.each do |message| - user_content << "#{message[:name] || "User"}: #{message[:content]}\n" - upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present? + content_array << "#{message[:name] || "User"}: " + content_array << message[:content] + content_array << "\n\n" end end if last_user_message - user_content << "You are responding to #{last_user_message[:name] || "User"} who just said:\n #{last_user_message[:content]}" - if last_user_message[:upload_ids].present? 
-            upload_ids.concat(last_user_message[:upload_ids])
-          end
+          content_array << "You are responding to #{last_user_message[:name] || "User"} who just said:\n"
+          content_array << last_user_message[:content]
         end

-        user_message = { type: :user, content: user_content }
+        content_array =
+          compress_messages_buffer(content_array.flatten, max_uploads: MAX_TOPIC_UPLOADS)

-        if upload_ids.present?
-          user_message[:upload_ids] = upload_ids[-MAX_TOPIC_UPLOADS..-1] || upload_ids
-        end
+        user_message = { type: :user, content: content_array }

         [user_message]
       end

       def chat_array(limit:)
         if @raw_messages.length > 1
-          buffer =
-            +"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{"
-
-          upload_ids = []
+          buffer = [
+            +"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{",
+          ]

           @raw_messages[0..-2].each do |message|
             buffer << "\n"

-            upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
-
             if message[:type] == :user
               buffer << "#{message[:name] || "User"}: "
             else
@@ -209,16 +359,44 @@ module DiscourseAi
           end

           last_message = @raw_messages[-1]
-          buffer << "#{last_message[:name] || "User"}: #{last_message[:content]} "
+          buffer << "#{last_message[:name] || "User"}: "
+          buffer << last_message[:content]
+
+          buffer = compress_messages_buffer(buffer.flatten, max_uploads: MAX_CHAT_UPLOADS)

           message = { type: :user, content: buffer }

-          upload_ids.concat(last_message[:upload_ids]) if last_message[:upload_ids].present?
-
-          message[:upload_ids] = upload_ids[-MAX_CHAT_UPLOADS..-1] ||
-            upload_ids if upload_ids.present?
-
           [message]
         end
+
+      # caps uploads to maximum uploads allowed in message stream
+      # and concats string elements
+      def compress_messages_buffer(buffer, max_uploads:)
+        compressed = []
+        current_text = +""
+        upload_count = 0
+
+        buffer.each do |item|
+          if item.is_a?(String)
+            current_text << item
+          elsif item.is_a?(Hash)
+            compressed << current_text if current_text.present?
+            compressed << item
+            current_text = +""
+            upload_count += 1
+          end
+        end
+
+        compressed << current_text if current_text.present?
+
+        if upload_count > max_uploads
+          counter = max_uploads - upload_count
+          compressed.delete_if { |item| item.is_a?(Hash) && (counter += 1) <= 0 }
+        end
+
+        compressed = compressed[0] if compressed.length == 1 && compressed[0].is_a?(String)
+
+        compressed
+      end
     end
   end
 end
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
index a94ae0e6..860fe653 100644
--- a/spec/lib/completions/dialects/gemini_spec.rb
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
     expect(context.image_generation_scenario).to eq(
       {
         messages: [
-          { role: "user", parts: [{ text: "draw a cat" }] },
+          { role: "user", parts: [{ text: "user1: draw a cat" }] },
           {
             role: "model",
             parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
           },
@@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
             ],
           },
           { role: "model", parts: { text: "Ok."
} }, - { role: "user", parts: [{ text: "draw another cat" }] }, + { role: "user", parts: [{ text: "user1: draw another cat" }] }, ], system_instruction: context.system_insts, }, @@ -52,12 +52,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do expect(context.multi_turn_scenario).to eq( { messages: [ - { role: "user", parts: [{ text: "This is a message by a user" }] }, + { role: "user", parts: [{ text: "user1: This is a message by a user" }] }, { role: "model", parts: [{ text: "I'm a previous bot reply, that's why there's no user" }], }, - { role: "user", parts: [{ text: "This is a new message by a user" }] }, + { role: "user", parts: [{ text: "user1: This is a new message by a user" }] }, { role: "model", parts: [ diff --git a/spec/lib/completions/dialects/mistral_spec.rb b/spec/lib/completions/dialects/mistral_spec.rb index 2e373bc5..a929768b 100644 --- a/spec/lib/completions/dialects/mistral_spec.rb +++ b/spec/lib/completions/dialects/mistral_spec.rb @@ -29,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do prompt = DiscourseAi::Completions::Prompt.new( "You are image bot", - messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]], + messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]], ) encoded = prompt.encoded_uploads(prompt.messages.last) @@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do content = dialect.translate[1][:content] expect(content).to eq( - [{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }], + [{ type: "text", text: "user1: hello" }, { type: "image_url", image_url: { url: image } }], ) end diff --git a/spec/lib/completions/dialects/nova_spec.rb b/spec/lib/completions/dialects/nova_spec.rb index 865426e2..36b0fcc7 100644 --- a/spec/lib/completions/dialects/nova_spec.rb +++ b/spec/lib/completions/dialects/nova_spec.rb @@ -37,7 +37,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Nova do it "properly formats messages with images" do messages = [ - { type: :user, id: "user1", content: "What's in this image?", upload_ids: [upload.id] }, + { + type: :user, + id: "user1", + content: ["What's in this image?", { upload_id: upload.id }], + }, ] prompt = DiscourseAi::Completions::Prompt.new(messages: messages) diff --git a/spec/lib/completions/dialects/open_ai_compatible_spec.rb b/spec/lib/completions/dialects/open_ai_compatible_spec.rb index 00dea4b0..185a97e3 100644 --- a/spec/lib/completions/dialects/open_ai_compatible_spec.rb +++ b/spec/lib/completions/dialects/open_ai_compatible_spec.rb @@ -34,8 +34,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do messages: [ { type: :user, - content: "Describe this image in a single sentence.", - upload_ids: [upload.id], + content: ["Describe this image in a single sentence.", { upload_id: upload.id }], }, ], ) @@ -49,10 +48,15 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do expect(translated_messages.length).to eq(1) + # no system message support here expected_user_message = { role: "user", content: [ - { type: "text", text: prompt.messages.map { |m| m[:content] }.join("\n") }, + { + type: "text", + text: + "You are a bot specializing in image captioning.\nDescribe this image in a single sentence.", + }, { type: "image_url", image_url: { diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index f2f79c5f..ca3bcb46 100644 --- 
a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -271,7 +271,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do prompt = DiscourseAi::Completions::Prompt.new( "You are image bot", - messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]], + messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]], ) encoded = prompt.encoded_uploads(prompt.messages.last) @@ -283,6 +283,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do { role: "user", content: [ + { type: "text", text: "user1: hello" }, { type: "image", source: { @@ -291,7 +292,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do data: encoded[0][:base64], }, }, - { type: "text", text: "user1: hello" }, ], }, ], diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index fe7f4eb6..05a87b41 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -211,7 +211,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do prompt = DiscourseAi::Completions::Prompt.new( "You are image bot", - messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]], + messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]], ) encoded = prompt.encoded_uploads(prompt.messages.last) @@ -248,7 +248,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do { "role" => "user", "parts" => [ - { "text" => "hello" }, + { "text" => "user1: hello" }, { "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } }, ], }, diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index d48bffb5..fc3b8b48 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -492,7 +492,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do prompt = DiscourseAi::Completions::Prompt.new( "You are image bot", - messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]], + messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]], ) encoded = prompt.encoded_uploads(prompt.messages.last) @@ -517,13 +517,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do { role: "user", content: [ + { type: "text", text: "hello" }, { type: "image_url", image_url: { url: "data:#{encoded[0][:mime_type]};base64,#{encoded[0][:base64]}", }, }, - { type: "text", text: "hello" }, ], name: "user1", }, diff --git a/spec/lib/completions/prompt_messages_builder_spec.rb b/spec/lib/completions/prompt_messages_builder_spec.rb index b162e39c..899b70e4 100644 --- a/spec/lib/completions/prompt_messages_builder_spec.rb +++ b/spec/lib/completions/prompt_messages_builder_spec.rb @@ -2,6 +2,36 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do let(:builder) { DiscourseAi::Completions::PromptMessagesBuilder.new } + fab!(:user) + fab!(:admin) + fab!(:bot_user) { Fabricate(:user) } + fab!(:other_user) { Fabricate(:user) } + + fab!(:image_upload1) do + Fabricate(:upload, user: user, original_filename: "image.png", extension: "png") + end + fab!(:image_upload2) do + Fabricate(:upload, user: user, original_filename: "image.png", extension: "png") + end + + it "correctly merges user messages with uploads" do + builder.push(type: :user, content: "Hello", 
name: "Alice", upload_ids: [1]) + builder.push(type: :user, content: "World", name: "Bob", upload_ids: [2]) + + messages = builder.to_a + + # Check the structure of the merged message + expect(messages.length).to eq(1) + expect(messages[0][:type]).to eq(:user) + + # The content should contain the text and both uploads + content = messages[0][:content] + expect(content).to be_an(Array) + expect(content[0]).to eq("Alice: Hello") + expect(content[1]).to eq({ upload_id: 1 }) + expect(content[2]).to eq("\nBob: World") + expect(content[3]).to eq({ upload_id: 2 }) + end it "should allow merging user messages" do builder.push(type: :user, content: "Hello", name: "Alice") @@ -14,7 +44,7 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1, 2]) expect(builder.to_a).to eq( - [{ type: :user, name: "Alice", content: "Hello", upload_ids: [1, 2] }], + [{ type: :user, content: ["Hello", { upload_id: 1 }, { upload_id: 2 }], name: "Alice" }], ) end @@ -64,4 +94,319 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do expect(content).to include("Alice") expect(content).to include("How do I solve this") end + + describe ".messages_from_chat" do + fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, bot_user]) } + fab!(:dm_message1) do + Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "Hello bot") + end + fab!(:dm_message2) do + Fabricate(:chat_message, chat_channel: dm_channel, user: bot_user, message: "Hello human") + end + fab!(:dm_message3) do + Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "How are you?") + end + + fab!(:public_channel) { Fabricate(:category_channel) } + fab!(:public_message1) do + Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Hello everyone") + end + fab!(:public_message2) do + Fabricate(:chat_message, chat_channel: public_channel, user: bot_user, message: "Hi there") + end + + fab!(:thread_original) do + Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Thread starter") + end + fab!(:thread) do + Fabricate(:chat_thread, channel: public_channel, original_message: thread_original) + end + fab!(:thread_reply1) do + Fabricate( + :chat_message, + chat_channel: public_channel, + user: other_user, + message: "Thread reply", + thread: thread, + ) + end + + fab!(:upload) { Fabricate(:upload, user: user) } + fab!(:message_with_upload) do + Fabricate( + :chat_message, + chat_channel: dm_channel, + user: user, + message: "Check this image", + upload_ids: [upload.id], + ) + end + + it "processes messages from direct message channels" do + context = + described_class.messages_from_chat( + dm_message3, + channel: dm_channel, + context_post_ids: nil, + max_messages: 10, + include_uploads: false, + bot_user_ids: [bot_user.id], + instruction_message: nil, + ) + + # this is all we got cause it is assuming threading + expect(context).to eq([{ type: :user, content: "How are you?", name: user.username }]) + end + + it "includes uploads when include_uploads is true" do + message_with_upload.reload + expect(message_with_upload.uploads).to include(upload) + + context = + described_class.messages_from_chat( + message_with_upload, + channel: dm_channel, + context_post_ids: nil, + max_messages: 10, + include_uploads: true, + bot_user_ids: [bot_user.id], + instruction_message: nil, + ) + + # Find the message with upload + message = context.find { |m| m[:content] == ["Check this image", { upload_id: 
upload.id }] } + expect(message).to be_present + end + + it "doesn't include uploads when include_uploads is false" do + # Make sure the upload is associated with the message + message_with_upload.reload + expect(message_with_upload.uploads).to include(upload) + + context = + described_class.messages_from_chat( + message_with_upload, + channel: dm_channel, + context_post_ids: nil, + max_messages: 10, + include_uploads: false, + bot_user_ids: [bot_user.id], + instruction_message: nil, + ) + + # Find the message with upload + message = context.find { |m| m[:content] == "Check this image" } + expect(message).to be_present + expect(message[:upload_ids]).to be_nil + end + + it "properly handles uploads in public channels with multiple users" do + _first_message = + Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "First message") + + _message_with_upload = + Fabricate( + :chat_message, + chat_channel: public_channel, + user: other_user, + message: "Message with image", + upload_ids: [upload.id], + ) + + last_message = + Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Final message") + + context = + described_class.messages_from_chat( + last_message, + channel: public_channel, + context_post_ids: nil, + max_messages: 3, + include_uploads: true, + bot_user_ids: [bot_user.id], + instruction_message: nil, + ) + + expect(context.length).to eq(1) + content = context.first[:content] + + expect(content.length).to eq(3) + expect(content[0]).to include("First message") + expect(content[0]).to include("Message with image") + expect(content[1]).to include({ upload_id: upload.id }) + expect(content[2]).to include("Final message") + end + end + + describe ".messages_from_post" do + fab!(:pm) do + Fabricate( + :private_message_topic, + title: "This is my special PM", + user: user, + topic_allowed_users: [ + Fabricate.build(:topic_allowed_user, user: user), + Fabricate.build(:topic_allowed_user, user: bot_user), + ], + ) + end + fab!(:first_post) do + Fabricate(:post, topic: pm, user: user, post_number: 1, raw: "This is a reply by the user") + end + fab!(:second_post) do + Fabricate(:post, topic: pm, user: bot_user, post_number: 2, raw: "This is a bot reply") + end + fab!(:third_post) do + Fabricate( + :post, + topic: pm, + user: user, + post_number: 3, + raw: "This is a second reply by the user", + ) + end + + it "handles uploads correctly in topic style messages" do + # Use Discourse's upload format in the post raw content + upload_markdown = "![test|658x372](#{image_upload1.short_url})" + + post_with_upload = + Fabricate( + :post, + topic: pm, + user: admin, + raw: "This is the original #{upload_markdown} I just added", + ) + + UploadReference.create!(target: post_with_upload, upload: image_upload1) + + upload2_markdown = "![test|658x372](#{image_upload2.short_url})" + + post2_with_upload = + Fabricate( + :post, + topic: pm, + user: admin, + raw: "This post has a different image #{upload2_markdown} I just added", + ) + + UploadReference.create!(target: post2_with_upload, upload: image_upload2) + + messages = + described_class.messages_from_post( + post2_with_upload, + style: :topic, + max_posts: 3, + bot_usernames: [bot_user.username], + include_uploads: true, + ) + + # this is not quite ideal yet, images are attached at the end of the post + # long term we may want to extract them out using a regex and create N parts + # so people can talk about multiple images in a single post + # this is the initial ground work though + + expect(messages.length).to 
eq(1) + content = messages[0][:content] + + # first part + # first image + # second part + # second image + expect(content.length).to eq(4) + expect(content[0]).to include("This is the original") + expect(content[1]).to eq({ upload_id: image_upload1.id }) + expect(content[2]).to include("different image") + expect(content[3]).to eq({ upload_id: image_upload2.id }) + end + + context "with limited context" do + it "respects max_context_posts" do + context = + described_class.messages_from_post( + third_post, + max_posts: 1, + bot_usernames: [bot_user.username], + include_uploads: false, + ) + + expect(context).to contain_exactly( + *[{ type: :user, id: user.username, content: third_post.raw }], + ) + end + end + + it "includes previous posts ordered by post_number" do + context = + described_class.messages_from_post( + third_post, + max_posts: 10, + bot_usernames: [bot_user.username], + include_uploads: false, + ) + + expect(context).to eq( + [ + { type: :user, content: "This is a reply by the user", id: user.username }, + { type: :model, content: "This is a bot reply" }, + { type: :user, content: "This is a second reply by the user", id: user.username }, + ], + ) + end + + it "only include regular posts" do + first_post.update!(post_type: Post.types[:whisper]) + + context = + described_class.messages_from_post( + third_post, + max_posts: 10, + bot_usernames: [bot_user.username], + include_uploads: false, + ) + + # skips leading model reply which makes no sense cause first post was whisper + expect(context).to eq( + [{ type: :user, content: "This is a second reply by the user", id: user.username }], + ) + end + + context "with custom prompts" do + it "When post custom prompt is present, we use that instead of the post content" do + custom_prompt = [ + [ + { name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json, + "time", + "tool_call", + ], + [ + { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json, + "time", + "tool", + ], + ["I replied to the time command", bot_user.username], + ] + + PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt) + + context = + described_class.messages_from_post( + third_post, + max_posts: 10, + bot_usernames: [bot_user.username], + include_uploads: false, + ) + + expect(context).to eq( + [ + { type: :user, content: "This is a reply by the user", id: user.username }, + { type: :tool_call, content: custom_prompt.first.first, id: "time" }, + { type: :tool, id: "time", content: custom_prompt.second.first }, + { type: :model, content: custom_prompt.third.first }, + { type: :user, content: "This is a second reply by the user", id: user.username }, + ], + ) + end + end + end end diff --git a/spec/lib/completions/prompt_spec.rb b/spec/lib/completions/prompt_spec.rb index fe2d1fa1..dafae3bd 100644 --- a/spec/lib/completions/prompt_spec.rb +++ b/spec/lib/completions/prompt_spec.rb @@ -25,34 +25,21 @@ RSpec.describe DiscourseAi::Completions::Prompt do end describe "image support" do - it "allows adding uploads to messages" do + it "allows adding uploads inline in messages" do upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id) prompt.max_pixels = 300 - prompt.push(type: :user, content: "hello", upload_ids: [upload.id]) + prompt.push( + type: :user, + content: ["this is an image", { upload_id: upload.id }, "this was an image"], + ) - expect(prompt.messages.last[:upload_ids]).to eq([upload.id]) - expect(prompt.max_pixels).to eq(300) + encoded = 
prompt.content_with_encoded_uploads(prompt.messages.last[:content]) - encoded = prompt.encoded_uploads(prompt.messages.last) - - expect(encoded.length).to eq(1) - expect(encoded[0][:mime_type]).to eq("image/jpeg") - - old_base64 = encoded[0][:base64] - - prompt.max_pixels = 1_000_000 - - encoded = prompt.encoded_uploads(prompt.messages.last) - - expect(encoded.length).to eq(1) - expect(encoded[0][:mime_type]).to eq("image/jpeg") - - new_base64 = encoded[0][:base64] - - expect(new_base64.length).to be > old_base64.length - expect(new_base64.length).to be > 0 - expect(old_base64.length).to be > 0 + expect(encoded.length).to eq(3) + expect(encoded[0]).to eq("this is an image") + expect(encoded[1][:mime_type]).to eq("image/jpeg") + expect(encoded[2]).to eq("this was an image") end end diff --git a/spec/lib/modules/ai_bot/bot_spec.rb b/spec/lib/modules/ai_bot/bot_spec.rb index 3432ad28..025a37f9 100644 --- a/spec/lib/modules/ai_bot/bot_spec.rb +++ b/spec/lib/modules/ai_bot/bot_spec.rb @@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: personaClass.new) bot.reply( - { conversation_context: [{ type: :user, content: "test" }] }, + DiscourseAi::AiBot::BotContext.new(messages: [{ type: :user, content: "test" }]), ) do |_partial, _cancel, _placeholder| # we just need the block so bot has something to call with results end @@ -74,7 +74,10 @@ RSpec.describe DiscourseAi::AiBot::Bot do HTML - context = { conversation_context: [{ type: :user, content: "Does my site has tags?" }] } + context = + DiscourseAi::AiBot::BotContext.new( + messages: [{ type: :user, content: "Does my site has tags?" }], + ) DiscourseAi::Completions::Llm.with_prepared_responses(llm_responses) do bot.reply(context) do |_bot_reply_post, cancel, placeholder| diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 8949ae13..8bd6634d 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -36,13 +36,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do end let(:context) do - { + DiscourseAi::AiBot::BotContext.new( site_url: Discourse.base_url, site_title: "test site title", site_description: "test site description", time: Time.zone.now, participants: topic_with_users.allowed_users.map(&:username).join(", "), - } + ) end fab!(:admin) @@ -307,7 +307,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do let(:ai_persona) { DiscourseAi::AiBot::Personas::Persona.all(user: user).first.new } let(:with_cc) do - context.merge(conversation_context: [{ content: "Tell me the time", type: :user }]) + context.messages = [{ content: "Tell me the time", type: :user }] + context end context "when a persona has no uploads" do @@ -345,17 +346,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do DiscourseAi::AiBot::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new # this means that we will consolidate - ctx = - with_cc.merge( - conversation_context: [ - { content: "Tell me the time", type: :user }, - { content: "the time is 1", type: :model }, - { content: "in france?", type: :user }, - ], - ) + context.messages = [ + { content: "Tell me the time", type: :user }, + { content: "the time is 1", type: :model }, + { content: "in france?", type: :user }, + ] DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do - custom_persona.craft_prompt(ctx).messages.first[:content] + 
custom_persona.craft_prompt(context).messages.first[:content] end message = @@ -397,7 +395,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id]) EmbeddingsGenerationStubs.hugging_face_service( - with_cc.dig(:conversation_context, 0, :content), + with_cc.messages.dig(0, :content), prompt_cc_embeddings, ) end diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index d65842e0..4df2ae5d 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -267,7 +267,10 @@ RSpec.describe DiscourseAi::AiBot::Playground do prompts = inner_prompts end - expect(prompts[0].messages[1][:upload_ids]).to eq([upload.id]) + content = prompts[0].messages[1][:content] + + expect(content).to include({ upload_id: upload.id }) + expect(prompts[0].max_pixels).to eq(1000) post.topic.reload @@ -1154,79 +1157,4 @@ RSpec.describe DiscourseAi::AiBot::Playground do expect(playground.available_bot_usernames).to include(persona.user.username) end end - - describe "#conversation_context" do - context "with limited context" do - before do - @old_persona = playground.bot.persona - persona = Fabricate(:ai_persona, max_context_posts: 1) - playground.bot.persona = persona.class_instance.new - end - - after { playground.bot.persona = @old_persona } - - it "respects max_context_post" do - context = playground.conversation_context(third_post) - - expect(context).to contain_exactly( - *[{ type: :user, id: user.username, content: third_post.raw }], - ) - end - end - - xit "includes previous posts ordered by post_number" do - context = playground.conversation_context(third_post) - - expect(context).to contain_exactly( - *[ - { type: :user, id: user.username, content: third_post.raw }, - { type: :model, content: second_post.raw }, - { type: :user, id: user.username, content: first_post.raw }, - ], - ) - end - - xit "only include regular posts" do - first_post.update!(post_type: Post.types[:whisper]) - - context = playground.conversation_context(third_post) - - # skips leading model reply which makes no sense cause first post was whisper - expect(context).to contain_exactly( - *[{ type: :user, id: user.username, content: third_post.raw }], - ) - end - - context "with custom prompts" do - it "When post custom prompt is present, we use that instead of the post content" do - custom_prompt = [ - [ - { name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json, - "time", - "tool_call", - ], - [ - { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json, - "time", - "tool", - ], - ["I replied to the time command", bot_user.username], - ] - - PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt) - - context = playground.conversation_context(third_post) - - expect(context).to contain_exactly( - *[ - { type: :user, id: user.username, content: first_post.raw }, - { type: :tool_call, content: custom_prompt.first.first, id: "time" }, - { type: :tool, id: "time", content: custom_prompt.second.first }, - { type: :model, content: custom_prompt.third.first }, - { type: :user, id: user.username, content: third_post.raw }, - ], - ) - end - end - end end diff --git a/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb b/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb index 058e28d2..26d47528 100644 --- a/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb +++ 
b/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb @@ -34,9 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::CreateArtifact do { html_body: "hello" }, bot_user: Fabricate(:user), llm: llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(post: post), ) tool.parameters = { name: "hello", specification: "hello spec" } diff --git a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb index 3eb1af62..dfb4e9d3 100644 --- a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb +++ b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb @@ -15,9 +15,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") } let(:progress_blk) { Proc.new {} } - let(:dall_e) do - described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {}) - end + let(:dall_e) { described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user) } let(:base64_image) do "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" @@ -30,8 +28,6 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do { prompts: ["a cat"], aspect_ratio: "tall" }, llm: llm, bot_user: bot_user, - context: { - }, ) data = [{ b64_json: base64_image, revised_prompt: "a tall cat" }] diff --git a/spec/lib/modules/ai_bot/tools/image_spec.rb b/spec/lib/modules/ai_bot/tools/image_spec.rb index 6f14448f..aea98dd9 100644 --- a/spec/lib/modules/ai_bot/tools/image_spec.rb +++ b/spec/lib/modules/ai_bot/tools/image_spec.rb @@ -9,8 +9,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do { prompts: prompts, seeds: [99, 32] }, bot_user: bot_user, llm: llm, - context: { - }, + context: DiscourseAi::AiBot::BotContext.new, ) end diff --git a/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb b/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb index 252e7eec..8247c528 100644 --- a/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb +++ b/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb @@ -25,9 +25,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do { url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/#{artifact.id}" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post2.id, - }, + context: DiscourseAi::AiBot::BotContext.new(post: post), ) result = tool.invoke {} @@ -46,9 +44,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do { url: "invalid-url" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(post: post), ) result = tool.invoke {} @@ -62,9 +58,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do { url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/99999" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(post: post), ) result = tool.invoke {} @@ -97,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do { url: "https://example.com" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(post: post), ) result = tool.invoke {} @@ -128,9 +120,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do { url: "https://example.com" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(post: post), ) result = tool.invoke {} diff --git a/spec/lib/modules/ai_bot/tools/read_spec.rb 
b/spec/lib/modules/ai_bot/tools/read_spec.rb index d8ffe8ae..88dc7906 100644 --- a/spec/lib/modules/ai_bot/tools/read_spec.rb +++ b/spec/lib/modules/ai_bot/tools/read_spec.rb @@ -56,9 +56,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do persona_options: { "read_private" => true, }, - context: { - user: admin, - }, + context: DiscourseAi::AiBot::BotContext.new(user: admin), ) results = tool.invoke expect(results[:content]).to include("hello there") @@ -68,9 +66,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do { topic_id: topic_with_tags.id, post_numbers: [post1.post_number] }, bot_user: bot_user, llm: llm, - context: { - user: admin, - }, + context: DiscourseAi::AiBot::BotContext.new(user: admin), ) results = tool.invoke diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb index 94dba3c7..c7be7522 100644 --- a/spec/lib/modules/ai_bot/tools/search_spec.rb +++ b/spec/lib/modules/ai_bot/tools/search_spec.rb @@ -60,9 +60,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do persona_options: persona_options, bot_user: bot_user, llm: llm, - context: { - user: user, - }, + context: DiscourseAi::AiBot::BotContext.new(user: user), ) expect(search.options[:base_query]).to eq("#funny") diff --git a/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb b/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb index c9834d25..f3d2a28f 100644 --- a/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb +++ b/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb @@ -47,9 +47,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do persona_options: { "update_algorithm" => "full", }, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), ) result = tool.invoke {} @@ -93,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do persona_options: { "update_algorithm" => "full", }, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), ) result = tool.invoke {} @@ -119,9 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do { artifact_id: artifact.id, instructions: "Invalid update" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), ) result = tool.invoke {} @@ -135,9 +129,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do { artifact_id: -1, instructions: "Update something" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), ) result = tool.invoke {} @@ -163,9 +155,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do persona_options: { "update_algorithm" => "full", }, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), ) tool.invoke {} @@ -196,9 +186,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do persona_options: { "update_algorithm" => "full", }, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), ) .invoke {} end @@ -224,9 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do persona_options: { "update_algorithm" => "full", }, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), ) result = tool.invoke {} @@ -276,9 +262,7 @@ RSpec.describe 
DiscourseAi::AiBot::Tools::UpdateArtifact do { artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), persona_options: { "update_algorithm" => "diff", }, @@ -346,9 +330,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do { artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" }, bot_user: bot_user, llm: llm_model.to_llm, - context: { - post_id: post.id, - }, + context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post), persona_options: { "update_algorithm" => "diff", }, diff --git a/spec/lib/modules/ai_moderation/spam_scanner_spec.rb b/spec/lib/modules/ai_moderation/spam_scanner_spec.rb index 129d8bdb..97cd4dfa 100644 --- a/spec/lib/modules/ai_moderation/spam_scanner_spec.rb +++ b/spec/lib/modules/ai_moderation/spam_scanner_spec.rb @@ -255,11 +255,12 @@ RSpec.describe DiscourseAi::AiModeration::SpamScanner do prompt = _prompts.first end - content = prompt.messages[1][:content] + # its an array so lets just stringify it to make testing easier + content = prompt.messages[1][:content][0] expect(content).to include(post.topic.title) expect(content).to include(post.raw) - upload_ids = prompt.messages[1][:upload_ids] + upload_ids = prompt.messages[1][:content].map { |m| m[:upload_id] if m.is_a?(Hash) }.compact expect(upload_ids).to be_present expect(upload_ids).to eq(post.upload_ids) diff --git a/spec/lib/modules/automation/llm_triage_spec.rb b/spec/lib/modules/automation/llm_triage_spec.rb index c4b20bf3..e9c9979a 100644 --- a/spec/lib/modules/automation/llm_triage_spec.rb +++ b/spec/lib/modules/automation/llm_triage_spec.rb @@ -199,7 +199,7 @@ describe DiscourseAi::Automation::LlmTriage do triage_prompt = DiscourseAi::Completions::Llm.prompts.last - expect(triage_prompt.messages.last[:upload_ids]).to contain_exactly(post_upload.id) + expect(triage_prompt.messages.last[:content].last).to eq({ upload_id: post_upload.id }) end end diff --git a/spec/models/ai_tool_spec.rb b/spec/models/ai_tool_spec.rb index 0df68dea..45ec3175 100644 --- a/spec/models/ai_tool_spec.rb +++ b/spec/models/ai_tool_spec.rb @@ -5,6 +5,7 @@ RSpec.describe AiTool do let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") } fab!(:topic) fab!(:post) { Fabricate(:post, topic: topic, raw: "bananas are a tasty fruit") } + fab!(:bot_user) { Discourse.system_user } def create_tool( parameters: nil, @@ -16,7 +17,8 @@ RSpec.describe AiTool do name: "test #{SecureRandom.uuid}", tool_name: "test_#{SecureRandom.uuid.underscore}", description: "test", - parameters: parameters || [{ name: "query", type: "string", desciption: "perform a search" }], + parameters: + parameters || [{ name: "query", type: "string", description: "perform a search" }], script: script || "function invoke(params) { return params; }", created_by_id: 1, summary: "Test tool summary", @@ -32,11 +34,11 @@ RSpec.describe AiTool do { name: tool.tool_name, description: "test", - parameters: [{ name: "query", type: "string", desciption: "perform a search" }], + parameters: [{ name: "query", type: "string", description: "perform a search" }], }, ) - runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil) expect(runner.invoke).to eq("query" => "test") end @@ -57,7 +59,7 @@ RSpec.describe AiTool do JS tool = 
create_tool(script: script) - runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil) stub_request(verb, "https://example.com/api").with( body: "{\"data\":\"test data\"}", @@ -83,7 +85,7 @@ RSpec.describe AiTool do JS tool = create_tool(script: script) - runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil) stub_request(:get, "https://example.com/test").with( headers: { @@ -110,7 +112,7 @@ RSpec.describe AiTool do JS tool = create_tool(script: script) - runner = tool.runner({}, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({}, llm: nil, bot_user: nil) stub_request(:get, "https://example.com/").to_return( status: 200, @@ -134,7 +136,7 @@ RSpec.describe AiTool do JS tool = create_tool(script: script) - runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil) stub_request(:get, "https://example.com/test").with( headers: { @@ -160,13 +162,16 @@ RSpec.describe AiTool do } JS + tool = create_tool(script: script) + runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil) + stub_request(:get, "https://example.com/test").to_return do sleep 0.01 { status: 200, body: "Hello World", headers: {} } end tool = create_tool(script: script) - runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil) runner.timeout = 10 @@ -184,7 +189,7 @@ RSpec.describe AiTool do tool = create_tool(script: script) - runner = tool.runner({}, llm: llm, bot_user: nil, context: {}) + runner = tool.runner({}, llm: llm, bot_user: nil) result = runner.invoke expect(result).to eq("Hello") @@ -209,7 +214,7 @@ RSpec.describe AiTool do responses = ["Hello ", "World"] DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| - runner = tool.runner({}, llm: llm, bot_user: nil, context: {}) + runner = tool.runner({}, llm: llm, bot_user: nil) result = runner.invoke prompts = _prompts end @@ -232,7 +237,7 @@ RSpec.describe AiTool do JS tool = create_tool(script: script) - runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil) runner.timeout = 5 @@ -295,7 +300,7 @@ RSpec.describe AiTool do RagDocumentFragment.link_target_and_uploads(tool, [upload1.id, upload2.id]) - result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke + result = tool.runner({}, llm: nil, bot_user: nil).invoke expected = [ [{ "fragment" => "44 45 46 47 48 49 50", "metadata" => nil }], @@ -316,7 +321,7 @@ RSpec.describe AiTool do # this part of the API is a bit awkward, maybe we should do it # automatically RagDocumentFragment.update_target_uploads(tool, [upload1.id, upload2.id]) - result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke + result = tool.runner({}, llm: nil, bot_user: nil).invoke expected = [ [{ "fragment" => "48 49 50", "metadata" => nil }], @@ -340,7 +345,7 @@ RSpec.describe AiTool do JS tool = create_tool(script: script) - runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil) result = runner.invoke @@ -364,7 +369,7 @@ RSpec.describe AiTool do JS tool = create_tool(script: script) - 
runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil) result = runner.invoke post_hash = result["post"] @@ -393,7 +398,7 @@ RSpec.describe AiTool do JS tool = create_tool(script: script) - runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {}) + runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil) result = runner.invoke @@ -401,4 +406,158 @@ RSpec.describe AiTool do expect(result["rows"].first["title"]).to eq(topic.title) end end + + context "when using the chat API" do + before(:each) do + skip "Chat plugin tests skipped because Chat module is not defined." unless defined?(Chat) + SiteSetting.chat_enabled = true + end + + fab!(:chat_user) { Fabricate(:user) } + fab!(:chat_channel) do + Fabricate(:chat_channel).tap do |channel| + Fabricate( + :user_chat_channel_membership, + user: chat_user, + chat_channel: channel, + following: true, + ) + end + end + + it "can create a chat message" do + script = <<~JS + function invoke(params) { + return discourse.createChatMessage({ + channel_name: params.channel_name, + username: params.username, + message: params.message + }); + } + JS + + tool = create_tool(script: script) + runner = + tool.runner( + { + "channel_name" => chat_channel.name, + "username" => chat_user.username, + "message" => "Hello from the tool!", + }, + llm: nil, + bot_user: bot_user, # The user *running* the tool doesn't affect sender + ) + + initial_message_count = Chat::Message.count + result = runner.invoke + + expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}" + expect(result["message"]).to eq("Hello from the tool!") + expect(result["created_at"]).to be_present + expect(result).not_to have_key("error") + + # Verify message was actually created in the database + expect(Chat::Message.count).to eq(initial_message_count + 1) + created_message = Chat::Message.find_by(id: result["message_id"]) + + expect(created_message).not_to be_nil + expect(created_message.message).to eq("Hello from the tool!") + expect(created_message.user_id).to eq(chat_user.id) # Message is sent AS the specified user + expect(created_message.chat_channel_id).to eq(chat_channel.id) + end + + it "can create a chat message using channel slug" do + chat_channel.update!(name: "My Test Channel", slug: "my-test-channel") + expect(chat_channel.slug).to eq("my-test-channel") + + script = <<~JS + function invoke(params) { + return discourse.createChatMessage({ + channel_name: params.channel_slug, // Using slug here + username: params.username, + message: params.message + }); + } + JS + + tool = create_tool(script: script) + runner = + tool.runner( + { + "channel_slug" => chat_channel.slug, + "username" => chat_user.username, + "message" => "Hello via slug!", + }, + llm: nil, + bot_user: bot_user, + ) + + result = runner.invoke + + expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}" + # see: https://github.com/rubyjs/mini_racer/issues/348 + # expect(result["message_id"]).to be_a(Integer) + + created_message = Chat::Message.find_by(id: result["message_id"]) + expect(created_message).not_to be_nil + expect(created_message.message).to eq("Hello via slug!") + expect(created_message.chat_channel_id).to eq(chat_channel.id) + end + + it "returns an error if the channel is not found" do + script = <<~JS + function invoke(params) { + return discourse.createChatMessage({ + channel_name: "non_existent_channel", + 
username: params.username, + message: params.message + }); + } + JS + + tool = create_tool(script: script) + runner = + tool.runner( + { "username" => chat_user.username, "message" => "Test" }, + llm: nil, + bot_user: bot_user, + ) + + initial_message_count = Chat::Message.count + expect { runner.invoke }.to raise_error( + MiniRacer::RuntimeError, + /Channel not found: non_existent_channel/, + ) + + expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created + end + + it "returns an error if the user is not found" do + script = <<~JS + function invoke(params) { + return discourse.createChatMessage({ + channel_name: params.channel_name, + username: "non_existent_user", + message: params.message + }); + } + JS + + tool = create_tool(script: script) + runner = + tool.runner( + { "channel_name" => chat_channel.name, "message" => "Test" }, + llm: nil, + bot_user: bot_user, + ) + + initial_message_count = Chat::Message.count + expect { runner.invoke }.to raise_error( + MiniRacer::RuntimeError, + /User not found: non_existent_user/, + ) + + expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created + end + end end
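For reviewers: a minimal sketch, not part of the patch, of the calling conventions this change standardizes on, mirroring the specs above. The upload, post, and bot variables are assumed to be pre-existing records/instances.

# Uploads are now inlined into message content as { upload_id: ... } hashes,
# replacing the separate upload_ids: option on Prompt and PromptMessagesBuilder.
prompt = DiscourseAi::Completions::Prompt.new("You are image bot")
prompt.push(type: :user, content: ["What's in this image?", { upload_id: upload.id }])

# Bots and custom tools now receive a BotContext object instead of a raw context Hash.
context =
  DiscourseAi::AiBot::BotContext.new(
    messages: [{ type: :user, content: "What's in this image?" }],
    post: post,
  )
bot.reply(context) { |partial| print partial }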