diff --git a/app/controllers/discourse_ai/admin/ai_tools_controller.rb b/app/controllers/discourse_ai/admin/ai_tools_controller.rb
index caf7fae6..6c148f1a 100644
--- a/app/controllers/discourse_ai/admin/ai_tools_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_tools_controller.rb
@@ -55,7 +55,7 @@ module DiscourseAi
# we need an llm so we have a tokenizer
# but will do without if none is available
llm = LlmModel.first&.to_llm
- runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user, context: {})
+ runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user)
result = runner.invoke
if result.is_a?(Hash) && result[:error]
diff --git a/app/jobs/regular/stream_discover_reply.rb b/app/jobs/regular/stream_discover_reply.rb
index b242616e..48183f6e 100644
--- a/app/jobs/regular/stream_discover_reply.rb
+++ b/app/jobs/regular/stream_discover_reply.rb
@@ -30,9 +30,13 @@ module Jobs
base = { query: query, model_used: llm_model.display_name }
- bot.reply(
- { conversation_context: [{ type: :user, content: query }], skip_tool_details: true },
- ) do |partial|
+ context =
+ DiscourseAi::AiBot::BotContext.new(
+ messages: [{ type: :user, content: query }],
+ skip_tool_details: true,
+ )
+
+ bot.reply(context) do |partial|
streamed_reply << partial
# Throttle updates.
diff --git a/app/models/ai_tool.rb b/app/models/ai_tool.rb
index ba3b4098..d6fde8b2 100644
--- a/app/models/ai_tool.rb
+++ b/app/models/ai_tool.rb
@@ -35,7 +35,7 @@ class AiTool < ActiveRecord::Base
tool_name.presence || name
end
- def runner(parameters, llm:, bot_user:, context: {})
+ def runner(parameters, llm:, bot_user:, context: nil)
DiscourseAi::AiBot::ToolRunner.new(
parameters: parameters,
llm: llm,
@@ -59,86 +59,166 @@ class AiTool < ActiveRecord::Base
def self.preamble
<<~JS
- /**
- * Tool API Quick Reference
- *
- * Entry Functions
- *
- * invoke(parameters): Main function. Receives parameters (Object). Must return a JSON-serializable value.
- * Example:
- * function invoke(parameters) { return "result"; }
- *
- * details(): Optional. Returns a string describing the tool.
- * Example:
- * function details() { return "Tool description."; }
- *
- * Provided Objects
- *
- * 1. http
- * http.get(url, options?): Performs an HTTP GET request.
- * Parameters:
- * url (string): The request URL.
- * options (Object, optional):
- * headers (Object): Request headers.
- * Returns:
- * { status: number, body: string }
- *
- * http.post(url, options?): Performs an HTTP POST request.
- * Parameters:
- * url (string): The request URL.
- * options (Object, optional):
- * headers (Object): Request headers.
- * body (string): Request body.
- * Returns:
- * { status: number, body: string }
- *
- * (also available: http.put, http.patch, http.delete)
- *
- * Note: Max 20 HTTP requests per execution.
- *
- * 2. llm
- * llm.truncate(text, length): Truncates text to a specified token length.
- * Parameters:
- * text (string): Text to truncate.
- * length (number): Max tokens.
- * Returns:
- * Truncated string.
- *
- * 3. index
- * index.search(query, options?): Searches indexed documents.
- * Parameters:
- * query (string): Search query.
- * options (Object, optional):
- * filenames (Array): Limit search to specific files.
- * limit (number): Max fragments (up to 200).
- * Returns:
- * Array of { fragment: string, metadata: string }
- *
- * 4. upload
- * upload.create(filename, base_64_content): Uploads a file.
- * Parameters:
- * filename (string): Name of the file.
- * base_64_content (string): Base64 encoded file content.
- * Returns:
- * { id: number, short_url: string }
- *
- * 5. chain
- * chain.setCustomRaw(raw): Sets the body of the post and exist chain.
- * Parameters:
- * raw (string): raw content to add to post.
- *
- * Constraints
- *
- * Execution Time: ≤ 2000ms
- * Memory: ≤ 10MB
- * HTTP Requests: ≤ 20 per execution
- * Exceeding limits will result in errors or termination.
- *
- * Security
- *
- * Sandboxed Environment: No access to system or global objects.
- * No File System Access: Cannot read or write files.
- */
+ /**
+ * Tool API Quick Reference
+ *
+ * Entry Functions
+ *
+ * invoke(parameters): Main function. Receives parameters defined in the tool's signature (Object).
+ * Must return a JSON-serializable value (e.g., string, number, object, array).
+ * Example:
+ * function invoke(parameters) { return { result: "Data processed", input: parameters.query }; }
+ *
+ * details(): Optional function. Returns a string (can include basic HTML) describing
+ * the tool's action after invocation, often using data from the invocation.
+ * This is displayed in the chat interface.
+ * Example:
+ * let lastUrl;
+ * function invoke(parameters) {
+ * lastUrl = parameters.url;
+ * // ... perform action ...
+ * return { success: true, content: "..." };
+ * }
+ * function details() {
+ * return `Browsed: ${lastUrl}`;
+ * }
+ *
+ * Provided Objects & Functions
+ *
+ * 1. http
+ * Performs HTTP requests. Max 20 requests per execution.
+ *
+ * http.get(url, options?): Performs GET request.
+ * Parameters:
+ * url (string): The request URL.
+ * options (Object, optional):
+ * headers (Object): Request headers (e.g., { "Authorization": "Bearer key" }).
+ * Returns: { status: number, body: string }
+ *
+ * http.post(url, options?): Performs POST request.
+ * Parameters:
+ * url (string): The request URL.
+ * options (Object, optional):
+ * headers (Object): Request headers.
+ * body (string | Object): Request body. If an object, it's stringified as JSON.
+ * Returns: { status: number, body: string }
+ *
+ * http.put(url, options?): Performs PUT request (similar to POST).
+ * http.patch(url, options?): Performs PATCH request (similar to POST).
+ * http.delete(url, options?): Performs DELETE request (similar to GET/POST).
+ *
+ * 2. llm
+ * Interacts with the Language Model.
+ *
+ * llm.truncate(text, length): Truncates text to a specified token length based on the configured LLM's tokenizer.
+ * Parameters:
+ * text (string): Text to truncate.
+ * length (number): Maximum number of tokens.
+ * Returns: string (truncated text)
+ *
+ * llm.generate(prompt): Generates text using the configured LLM associated with the tool runner.
+ * Parameters:
+ * prompt (string | Object): The prompt. Can be a simple string or an object
+ * like { messages: [{ type: "system", content: "..." }, { type: "user", content: "..." }] }.
+ * Returns: string (generated text)
+ *
+ * 3. index
+ * Searches attached RAG (Retrieval-Augmented Generation) documents linked to this tool.
+ *
+ * index.search(query, options?): Searches indexed document fragments.
+ * Parameters:
+ * query (string): The search query used for semantic search.
+ * options (Object, optional):
+ * filenames (Array): Filter search to fragments from specific uploaded filenames.
+ * limit (number): Maximum number of fragments to return (default: 10, max: 200).
+ * Returns: Array<{ fragment: string, metadata: string | null }> - Ordered by relevance.
+ *
+ * 4. upload
+ * Handles file uploads within Discourse.
+ *
+ * upload.create(filename, base_64_content): Uploads a file created by the tool, making it available in Discourse.
+ * Parameters:
+ * filename (string): The desired name for the file (basename is used for security).
+ * base_64_content (string): Base64 encoded content of the file.
+ * Returns: { id: number, url: string, short_url: string } - Details of the created upload record.
+ *
+ * 5. chain
+ * Controls the execution flow.
+ *
+ * chain.setCustomRaw(raw): Sets the final raw content of the bot's post and immediately
+ * stops the tool execution chain. Useful for tools that directly
+ * generate the full response content (e.g., image generation tools attaching the image markdown).
+ * Parameters:
+ * raw (string): The raw Markdown content for the post.
+ * Returns: void
+ *
+ * 6. discourse
+ * Interacts with Discourse specific features. Access is generally performed as the SystemUser.
+ *
+ * discourse.search(params): Performs a Discourse search.
+ * Parameters:
+ * params (Object): Search parameters (e.g., { search_query: "keyword", with_private: true, max_results: 10 }).
+ * `with_private: true` searches across all posts visible to the SystemUser. `result_style: 'detailed'` is used by default.
+ * Returns: Object (Discourse search results structure, includes posts, topics, users etc.)
+ *
+ * discourse.getPost(post_id): Retrieves details for a specific post.
+ * Parameters:
+ * post_id (number): The ID of the post.
+ * Returns: Object (Post details including `raw`, nested `topic` object with ListableTopicSerializer structure) or null if not found/accessible.
+ *
+ * discourse.getTopic(topic_id): Retrieves details for a specific topic.
+ * Parameters:
+ * topic_id (number): The ID of the topic.
+ * Returns: Object (Topic details using ListableTopicSerializer structure) or null if not found/accessible.
+ *
+ * discourse.getUser(user_id_or_username): Retrieves details for a specific user.
+ * Parameters:
+ * user_id_or_username (number | string): The ID or username of the user.
+ * Returns: Object (User details using UserSerializer structure) or null if not found.
+ *
+ * discourse.getPersona(name): Gets an object representing another AI Persona configured on the site.
+ * Parameters:
+ * name (string): The name of the target persona.
+ * Returns: Object { respondTo: function(params) } or null if persona not found.
+ * respondTo(params): Instructs the target persona to generate a response within the current context (e.g., replying to the same post or chat message).
+ * Parameters:
+ * params (Object, optional): { instructions: string, whisper: boolean }
+ * Returns: { success: boolean, post_id?: number, post_number?: number, message_id?: number } or { error: string }
+ *
+ * discourse.createChatMessage(params): Creates a new message in a Discourse Chat channel.
+ * Parameters:
+ * params (Object): { channel_name: string, username: string, message: string }
+ * `channel_name` can be the channel name or slug.
+ * `username` specifies the user who should appear as the sender. The user must exist.
+ * The sending user must have permission to post in the channel.
+ * Returns: { success: boolean, message_id?: number, message?: string, created_at?: string } or { error: string }
+ *
+ * 7. context
+ * An object containing information about the environment where the tool is being run.
+ * Available properties depend on the invocation context, but may include:
+ * post_id (number): ID of the post triggering the tool (if in a Post context).
+ * topic_id (number): ID of the topic (if in a Post context).
+ * private_message (boolean): Whether the context is a private message (in Post context).
+ * message_id (number): ID of the chat message triggering the tool (if in Chat context).
+ * channel_id (number): ID of the chat channel (if in Chat context).
+ * user (Object): Details of the user invoking the tool/persona (structure may vary, often null or SystemUser details unless explicitly passed).
+ * participants (string): Comma-separated list of usernames in a PM (if applicable).
+ * // ... other potential context-specific properties added by the calling environment.
+ *
+ * Constraints
+ *
+ * Execution Time: ≤ 2000ms (default timeout in milliseconds) - This timer *pauses* during external HTTP requests or LLM calls initiated via `http.*` or `llm.generate`, but applies to the script's own processing time.
+ * Memory: ≤ 10MB (V8 heap limit)
+ * Stack Depth: ≤ 20 (Marshal stack depth limit for Ruby interop)
+ * HTTP Requests: ≤ 20 per execution
+ * Exceeding limits will result in errors or termination (e.g., timeout error, out-of-memory error, TooManyRequestsError).
+ *
+ * Security
+ *
+ * Sandboxed Environment: The script runs in a restricted V8 JavaScript environment (via MiniRacer).
+       * No direct access to the host environment or browser globals (like `window` or `document`), and no host file system access.
+ * Network requests are proxied through the Discourse backend, not made directly from the sandbox.
+ */
JS
end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index bb4bb07e..3248d031 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -75,9 +75,9 @@ module DiscourseAi
def force_tool_if_needed(prompt, context)
return if prompt.tool_choice == :none
- context[:chosen_tools] ||= []
+ context.chosen_tools ||= []
forced_tools = persona.force_tool_use.map { |tool| tool.name }
- force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
+ force_tool = forced_tools.find { |name| !context.chosen_tools.include?(name) }
if force_tool && persona.forced_tool_count > 0
user_turns = prompt.messages.select { |m| m[:type] == :user }.length
@@ -85,7 +85,7 @@ module DiscourseAi
end
if force_tool
- context[:chosen_tools] << force_tool
+ context.chosen_tools << force_tool
prompt.tool_choice = force_tool
else
prompt.tool_choice = nil
@@ -93,6 +93,9 @@ module DiscourseAi
end
def reply(context, &update_blk)
+ unless context.is_a?(BotContext)
+ raise ArgumentError, "context must be an instance of BotContext"
+ end
llm = DiscourseAi::Completions::Llm.proxy(model)
prompt = persona.craft_prompt(context, llm: llm)
@@ -100,7 +103,7 @@ module DiscourseAi
ongoing_chain = true
raw_context = []
- user = context[:user]
+ user = context.user
llm_kwargs = { user: user }
llm_kwargs[:temperature] = persona.temperature if persona.temperature
@@ -277,27 +280,15 @@ module DiscourseAi
name: tool.name,
}
- if tool.standalone?
- standalone_context =
- context.dup.merge(
- conversation_context: [
- context[:conversation_context].last,
- tool_call_message,
- tool_message,
- ],
- )
- prompt = persona.craft_prompt(standalone_context)
- else
- prompt.push(**tool_call_message)
- prompt.push(**tool_message)
- end
+ prompt.push(**tool_call_message)
+ prompt.push(**tool_message)
raw_context << [tool_call_message[:content], tool_call_id, "tool_call", tool.name]
raw_context << [invocation_result_json, tool_call_id, "tool", tool.name]
end
def invoke_tool(tool, llm, cancel, context, &update_blk)
- show_placeholder = !context[:skip_tool_details] && !tool.class.allow_partial_tool_calls?
+ show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls?
update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
diff --git a/lib/ai_bot/bot_context.rb b/lib/ai_bot/bot_context.rb
new file mode 100644
index 00000000..ca32262a
--- /dev/null
+++ b/lib/ai_bot/bot_context.rb
@@ -0,0 +1,107 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module AiBot
+ class BotContext
+ attr_accessor :messages,
+ :topic_id,
+ :post_id,
+ :private_message,
+ :custom_instructions,
+ :user,
+ :skip_tool_details,
+ :participants,
+ :chosen_tools,
+ :message_id,
+ :channel_id,
+ :context_post_ids
+
+ def initialize(
+ post: nil,
+ participants: nil,
+ user: nil,
+ skip_tool_details: nil,
+ messages: [],
+ custom_instructions: nil,
+ site_url: nil,
+ site_title: nil,
+ site_description: nil,
+ time: nil,
+ message_id: nil,
+ channel_id: nil,
+ context_post_ids: nil
+ )
+ @participants = participants
+ @user = user
+ @skip_tool_details = skip_tool_details
+ @messages = messages
+ @custom_instructions = custom_instructions
+
+ @message_id = message_id
+ @channel_id = channel_id
+ @context_post_ids = context_post_ids
+
+ @site_url = site_url
+ @site_title = site_title
+ @site_description = site_description
+ @time = time
+
+ if post
+ @post_id = post.id
+ @topic_id = post.topic_id
+ @private_message = post.topic.private_message?
+ @participants ||= post.topic.allowed_users.map(&:username).join(", ") if @private_message
+ @user = post.user
+ end
+ end
+
+ # these are strings that can be safely interpolated into templates
+ TEMPLATE_PARAMS = %w[time site_url site_title site_description participants]
+
+ def lookup_template_param(key)
+ public_send(key.to_sym) if TEMPLATE_PARAMS.include?(key)
+ end
+
+ def time
+ @time ||= Time.zone.now
+ end
+
+ def site_url
+ @site_url ||= Discourse.base_url
+ end
+
+ def site_title
+ @site_title ||= SiteSetting.title
+ end
+
+ def site_description
+ @site_description ||= SiteSetting.site_description
+ end
+
+ def private_message?
+ @private_message
+ end
+
+ def to_json
+ {
+ messages: @messages,
+ topic_id: @topic_id,
+ post_id: @post_id,
+ private_message: @private_message,
+ custom_instructions: @custom_instructions,
+ username: @user&.username,
+ user_id: @user&.id,
+ participants: @participants,
+ chosen_tools: @chosen_tools,
+ message_id: @message_id,
+ channel_id: @channel_id,
+ context_post_ids: @context_post_ids,
+ site_url: @site_url,
+ site_title: @site_title,
+ site_description: @site_description,
+ skip_tool_details: @skip_tool_details,
+ }
+ end
+ end
+ end
+end
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index 33c2907e..0d6745a3 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -163,7 +163,7 @@ module DiscourseAi
def craft_prompt(context, llm: nil)
system_insts =
system_prompt.gsub(/\{(\w+)\}/) do |match|
- found = context[match[1..-2].to_sym]
+ found = context.lookup_template_param(match[1..-2])
found.nil? ? match : found.to_s
end
@@ -180,16 +180,16 @@ module DiscourseAi
)
end
- if context[:custom_instructions].present?
+ if context.custom_instructions.present?
prompt_insts << "\n"
- prompt_insts << context[:custom_instructions]
+ prompt_insts << context.custom_instructions
end
fragments_guidance =
rag_fragments_prompt(
- context[:conversation_context].to_a,
+ context.messages,
llm: question_consolidator_llm,
- user: context[:user],
+ user: context.user,
)&.strip
prompt_insts << fragments_guidance if fragments_guidance.present?
@@ -197,9 +197,9 @@ module DiscourseAi
prompt =
DiscourseAi::Completions::Prompt.new(
prompt_insts,
- messages: context[:conversation_context].to_a,
- topic_id: context[:topic_id],
- post_id: context[:post_id],
+ messages: context.messages,
+ topic_id: context.topic_id,
+ post_id: context.post_id,
)
prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled
diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb
index 59533efd..685683fd 100644
--- a/lib/ai_bot/playground.rb
+++ b/lib/ai_bot/playground.rb
@@ -227,90 +227,17 @@ module DiscourseAi
schedule_bot_reply(post) if can_attach?(post)
end
- def conversation_context(post, style: nil)
- # Pay attention to the `post_number <= ?` here.
- # We want to inject the last post as context because they are translated differently.
-
- # also setting default to 40, allowing huge contexts costs lots of tokens
- max_posts = 40
- if bot.persona.class.respond_to?(:max_context_posts)
- max_posts = bot.persona.class.max_context_posts || 40
- end
-
- post_types = [Post.types[:regular]]
- post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper]
-
- context =
- post
- .topic
- .posts
- .joins(:user)
- .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
- .where("post_number <= ?", post.post_number)
- .order("post_number desc")
- .where("post_type in (?)", post_types)
- .limit(max_posts)
- .pluck(
- "posts.raw",
- "users.username",
- "post_custom_prompts.custom_prompt",
- "(
- SELECT array_agg(ref.upload_id)
- FROM upload_references ref
- WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
- ) as upload_ids",
- )
-
- builder = DiscourseAi::Completions::PromptMessagesBuilder.new
- builder.topic = post.topic
-
- context.reverse_each do |raw, username, custom_prompt, upload_ids|
- custom_prompt_translation =
- Proc.new do |message|
- # We can't keep backwards-compatibility for stored functions.
- # Tool syntax requires a tool_call_id which we don't have.
- if message[2] != "function"
- custom_context = {
- content: message[0],
- type: message[2].present? ? message[2].to_sym : :model,
- }
-
- custom_context[:id] = message[1] if custom_context[:type] != :model
- custom_context[:name] = message[3] if message[3]
-
- thinking = message[4]
- custom_context[:thinking] = thinking if thinking
-
- builder.push(**custom_context)
- end
- end
-
- if custom_prompt.present?
- custom_prompt.each(&custom_prompt_translation)
- else
- context = {
- content: raw,
- type: (available_bot_usernames.include?(username) ? :model : :user),
- }
-
- context[:id] = username if context[:type] == :user
-
- if upload_ids.present? && context[:type] == :user && bot.persona.class.vision_enabled
- context[:upload_ids] = upload_ids.compact
- end
-
- builder.push(**context)
- end
- end
-
- builder.to_a(style: style || (post.topic.private_message? ? :bot : :topic))
- end
-
def title_playground(post, user)
- context = conversation_context(post)
+ messages =
+ DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post(
+ post,
+ max_posts: 5,
+ bot_usernames: available_bot_usernames,
+ include_uploads: bot.persona.class.vision_enabled,
+ )
bot
- .get_updated_title(context, post, user)
+ .get_updated_title(messages, post, user)
.tap do |new_title|
PostRevisor.new(post.topic.first_post, post.topic).revise!(
bot.bot_user,
@@ -326,83 +253,6 @@ module DiscourseAi
)
end
- def chat_context(message, channel, persona_user, context_post_ids)
- has_vision = bot.persona.class.vision_enabled
- include_thread_titles = !channel.direct_message_channel? && !message.thread_id
-
- current_id = message.id
- if !channel.direct_message_channel?
- # we are interacting via mentions ... strip mention
- instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
- end
-
- messages = nil
-
- max_messages = 40
- if bot.persona.class.respond_to?(:max_context_posts)
- max_messages = bot.persona.class.max_context_posts || 40
- end
-
- if !message.thread_id && channel.direct_message_channel?
- messages = [message]
- elsif !channel.direct_message_channel? && !message.thread_id
- messages =
- Chat::Message
- .joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
- .where(chat_channel_id: channel.id)
- .where(
- "chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
- )
- .order(id: :desc)
- .limit(max_messages)
- .to_a
- .reverse
- end
-
- messages ||=
- ChatSDK::Thread.last_messages(
- thread_id: message.thread_id,
- guardian: Discourse.system_user.guardian,
- page_size: max_messages,
- )
-
- builder = DiscourseAi::Completions::PromptMessagesBuilder.new
-
- guardian = Guardian.new(message.user)
- if context_post_ids
- builder.set_chat_context_posts(context_post_ids, guardian, include_uploads: has_vision)
- end
-
- messages.each do |m|
- # restore stripped message
- m.message = instruction_message if m.id == current_id && instruction_message
-
- if available_bot_user_ids.include?(m.user_id)
- builder.push(type: :model, content: m.message)
- else
- upload_ids = nil
- upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present?
- mapped_message = m.message
-
- thread_title = nil
- thread_title = m.thread&.title if include_thread_titles && m.thread_id
- mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
-
- builder.push(
- type: :user,
- content: mapped_message,
- name: m.user.username,
- upload_ids: upload_ids,
- )
- end
- end
-
- builder.to_a(
- limit: max_messages,
- style: channel.direct_message_channel? ? :chat_with_context : :chat,
- )
- end
-
def reply_to_chat_message(message, channel, context_post_ids)
persona_user = User.find(bot.persona.class.user_id)
@@ -410,10 +260,32 @@ module DiscourseAi
context_post_ids = nil if !channel.direct_message_channel?
+ max_chat_messages = 40
+ if bot.persona.class.respond_to?(:max_context_posts)
+ max_chat_messages = bot.persona.class.max_context_posts || 40
+ end
+
+ if !channel.direct_message_channel?
+ # we are interacting via mentions ... strip mention
+ instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
+ end
+
context =
- get_context(
- participants: participants.join(", "),
- conversation_context: chat_context(message, channel, persona_user, context_post_ids),
+ BotContext.new(
+ participants: participants,
+ message_id: message.id,
+ channel_id: channel.id,
+ context_post_ids: context_post_ids,
+ messages:
+ DiscourseAi::Completions::PromptMessagesBuilder.messages_from_chat(
+ message,
+ channel: channel,
+ context_post_ids: context_post_ids,
+ include_uploads: bot.persona.class.vision_enabled,
+ max_messages: max_chat_messages,
+ bot_user_ids: available_bot_user_ids,
+ instruction_message: instruction_message,
+ ),
user: message.user,
skip_tool_details: true,
)
@@ -460,22 +332,6 @@ module DiscourseAi
reply
end
- def get_context(participants:, conversation_context:, user:, skip_tool_details: nil)
- result = {
- site_url: Discourse.base_url,
- site_title: SiteSetting.title,
- site_description: SiteSetting.site_description,
- time: Time.zone.now,
- participants: participants,
- conversation_context: conversation_context,
- user: user,
- }
-
- result[:skip_tool_details] = true if skip_tool_details
-
- result
- end
-
def reply_to(
post,
custom_instructions: nil,
@@ -509,16 +365,25 @@ module DiscourseAi
end
)
+ # safeguard
+ max_context_posts = 40
+ if bot.persona.class.respond_to?(:max_context_posts)
+ max_context_posts = bot.persona.class.max_context_posts || 40
+ end
+
context =
- get_context(
- participants: post.topic.allowed_users.map(&:username).join(", "),
- conversation_context: conversation_context(post, style: context_style),
- user: post.user,
+ BotContext.new(
+ post: post,
+ custom_instructions: custom_instructions,
+ messages:
+ DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post(
+ post,
+ style: context_style,
+ max_posts: max_context_posts,
+ include_uploads: bot.persona.class.vision_enabled,
+ bot_usernames: available_bot_usernames,
+ ),
)
- context[:post_id] = post.id
- context[:topic_id] = post.topic_id
- context[:private_message] = post.topic.private_message?
- context[:custom_instructions] = custom_instructions
reply_user = bot.bot_user
if bot.persona.class.respond_to?(:user_id)
@@ -562,7 +427,7 @@ module DiscourseAi
Discourse.redis.setex(redis_stream_key, 60, 1)
end
- context[:skip_tool_details] ||= !bot.persona.class.tool_details
+ context.skip_tool_details ||= !bot.persona.class.tool_details
post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply
diff --git a/lib/ai_bot/tool_runner.rb b/lib/ai_bot/tool_runner.rb
index df785749..8a42a494 100644
--- a/lib/ai_bot/tool_runner.rb
+++ b/lib/ai_bot/tool_runner.rb
@@ -13,7 +13,13 @@ module DiscourseAi
MARSHAL_STACK_DEPTH = 20
MAX_HTTP_REQUESTS = 20
- def initialize(parameters:, llm:, bot_user:, context: {}, tool:, timeout: nil)
+ def initialize(parameters:, llm:, bot_user:, context: nil, tool:, timeout: nil)
+ if context && !context.is_a?(DiscourseAi::AiBot::BotContext)
+ raise ArgumentError, "context must be a BotContext object"
+ end
+
+ context ||= DiscourseAi::AiBot::BotContext.new
+
@parameters = parameters
@llm = llm
@bot_user = bot_user
@@ -99,7 +105,7 @@ module DiscourseAi
},
};
- const context = #{JSON.generate(@context)};
+ const context = #{JSON.generate(@context.to_json)};
function details() { return ""; };
JS
@@ -240,13 +246,13 @@ module DiscourseAi
def llm_user
@llm_user ||=
begin
- @context[:llm_user] || post&.user || @bot_user
+ post&.user || @bot_user
end
end
def post
return @post if defined?(@post)
- post_id = @context[:post_id]
+ post_id = @context.post_id
@post = post_id && Post.find_by(id: post_id)
end
@@ -336,8 +342,8 @@ module DiscourseAi
bot = DiscourseAi::AiBot::Bot.as(@bot_user || persona.user, persona: persona)
playground = DiscourseAi::AiBot::Playground.new(bot)
- if @context[:post_id]
- post = Post.find_by(id: @context[:post_id])
+ if @context.post_id
+ post = Post.find_by(id: @context.post_id)
return { error: "Post not found" } if post.nil?
reply_post =
@@ -354,13 +360,13 @@ module DiscourseAi
else
return { error: "Failed to create reply" }
end
- elsif @context[:message_id] && @context[:channel_id]
- message = Chat::Message.find_by(id: @context[:message_id])
- channel = Chat::Channel.find_by(id: @context[:channel_id])
+ elsif @context.message_id && @context.channel_id
+ message = Chat::Message.find_by(id: @context.message_id)
+ channel = Chat::Channel.find_by(id: @context.channel_id)
return { error: "Message or channel not found" } if message.nil? || channel.nil?
reply =
- playground.reply_to_chat_message(message, channel, @context[:context_post_ids])
+ playground.reply_to_chat_message(message, channel, @context.context_post_ids)
if reply
return { success: true, message_id: reply.id }
@@ -457,7 +463,7 @@ module DiscourseAi
UploadCreator.new(
file,
filename,
- for_private_message: @context[:private_message],
+ for_private_message: @context.private_message,
).create_for(@bot_user.id)
{ id: upload.id, short_url: upload.short_url, url: upload.url }
diff --git a/lib/ai_bot/tools/create_artifact.rb b/lib/ai_bot/tools/create_artifact.rb
index b9175e8d..54221b40 100644
--- a/lib/ai_bot/tools/create_artifact.rb
+++ b/lib/ai_bot/tools/create_artifact.rb
@@ -108,7 +108,7 @@ module DiscourseAi
end
def invoke
- post = Post.find_by(id: context[:post_id])
+ post = Post.find_by(id: context.post_id)
return error_response("No post context found") unless post
partial_response = +""
diff --git a/lib/ai_bot/tools/dall_e.rb b/lib/ai_bot/tools/dall_e.rb
index 7389a3b5..4564cfbb 100644
--- a/lib/ai_bot/tools/dall_e.rb
+++ b/lib/ai_bot/tools/dall_e.rb
@@ -111,7 +111,7 @@ module DiscourseAi
UploadCreator.new(
file,
"image.png",
- for_private_message: context[:private_message],
+ for_private_message: context.private_message?,
).create_for(bot_user.id),
}
end
diff --git a/lib/ai_bot/tools/image.rb b/lib/ai_bot/tools/image.rb
index 72e575c3..34e5e3f2 100644
--- a/lib/ai_bot/tools/image.rb
+++ b/lib/ai_bot/tools/image.rb
@@ -131,7 +131,7 @@ module DiscourseAi
UploadCreator.new(
file,
"image.png",
- for_private_message: context[:private_message],
+ for_private_message: context.private_message,
).create_for(bot_user.id),
seed: image[:seed],
}
diff --git a/lib/ai_bot/tools/read.rb b/lib/ai_bot/tools/read.rb
index 8bb3afc5..d7a0186f 100644
--- a/lib/ai_bot/tools/read.rb
+++ b/lib/ai_bot/tools/read.rb
@@ -48,7 +48,7 @@ module DiscourseAi
def invoke
not_found = { topic_id: topic_id, description: "Topic not found" }
- guardian = Guardian.new(context[:user]) if options[:read_private] && context[:user]
+ guardian = Guardian.new(context.user) if options[:read_private] && context.user
guardian ||= Guardian.new
@title = ""
diff --git a/lib/ai_bot/tools/read_artifact.rb b/lib/ai_bot/tools/read_artifact.rb
index 7fb5646c..98b7944f 100644
--- a/lib/ai_bot/tools/read_artifact.rb
+++ b/lib/ai_bot/tools/read_artifact.rb
@@ -59,7 +59,7 @@ module DiscourseAi
end
def post
- @post ||= Post.find_by(id: context[:post_id])
+ @post ||= Post.find_by(id: context.post_id)
end
def handle_discourse_artifact(uri)
diff --git a/lib/ai_bot/tools/search.rb b/lib/ai_bot/tools/search.rb
index a14e77d6..402eca05 100644
--- a/lib/ai_bot/tools/search.rb
+++ b/lib/ai_bot/tools/search.rb
@@ -130,7 +130,7 @@ module DiscourseAi
after: parameters[:after],
status: parameters[:status],
max_results: max_results,
- current_user: options[:search_private] ? context[:user] : nil,
+ current_user: options[:search_private] ? context.user : nil,
)
@last_num_results = results[:rows]&.length || 0
diff --git a/lib/ai_bot/tools/summarize.rb b/lib/ai_bot/tools/summarize.rb
index 214d45fa..5076635f 100644
--- a/lib/ai_bot/tools/summarize.rb
+++ b/lib/ai_bot/tools/summarize.rb
@@ -40,10 +40,6 @@ module DiscourseAi
false
end
- def standalone?
- true
- end
-
def custom_raw
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
end
diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb
index ff95f16e..ba0bc69c 100644
--- a/lib/ai_bot/tools/tool.rb
+++ b/lib/ai_bot/tools/tool.rb
@@ -56,14 +56,17 @@ module DiscourseAi
persona_options: {},
bot_user:,
llm:,
- context: {}
+ context: nil
)
@parameters = parameters
@tool_call_id = tool_call_id
@persona_options = persona_options
@bot_user = bot_user
@llm = llm
- @context = context
+ @context = context.nil? ? DiscourseAi::AiBot::BotContext.new(messages: []) : context
+ if !@context.is_a?(DiscourseAi::AiBot::BotContext)
raise ArgumentError, "context must be a DiscourseAi::AiBot::BotContext"
+ end
end
def name
@@ -108,10 +111,6 @@ module DiscourseAi
true
end
- def standalone?
- false
- end
-
protected
def fetch_default_branch(repo)
diff --git a/lib/ai_bot/tools/update_artifact.rb b/lib/ai_bot/tools/update_artifact.rb
index 4b62bd01..2c9e544b 100644
--- a/lib/ai_bot/tools/update_artifact.rb
+++ b/lib/ai_bot/tools/update_artifact.rb
@@ -39,7 +39,7 @@ module DiscourseAi
def self.inject_prompt(prompt:, context:, persona:)
return if persona.options["do_not_echo_artifact"].to_s == "true"
# we inject the current artifact content into the last user message
- if topic_id = context[:topic_id]
+ if topic_id = context.topic_id
posts = Post.where(topic_id: topic_id)
artifact = AiArtifact.order("id desc").where(post: posts).first
if artifact
@@ -113,7 +113,7 @@ module DiscourseAi
end
def invoke
- post = Post.find_by(id: context[:post_id])
+ post = Post.find_by(id: context.post_id)
return error_response("No post context found") unless post
artifact = AiArtifact.find_by(id: parameters[:artifact_id])
diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb
index 8de1bcf8..c6716aff 100644
--- a/lib/ai_helper/assistant.rb
+++ b/lib/ai_helper/assistant.rb
@@ -188,9 +188,10 @@ module DiscourseAi
messages: [
{
type: :user,
- content:
+ content: [
"Describe this image in a single sentence#{custom_locale_instructions(user)}",
- upload_ids: [upload.id],
+ { upload_id: upload.id },
+ ],
},
],
)
diff --git a/lib/ai_moderation/spam_scanner.rb b/lib/ai_moderation/spam_scanner.rb
index 8569c8e8..e4ed348d 100644
--- a/lib/ai_moderation/spam_scanner.rb
+++ b/lib/ai_moderation/spam_scanner.rb
@@ -187,7 +187,10 @@ module DiscourseAi
prompt = DiscourseAi::Completions::Prompt.new(system_prompt)
args = { type: :user, content: context }
upload_ids = post.upload_ids
- args[:upload_ids] = upload_ids.take(3) if upload_ids.present?
+ if upload_ids.present?
+ args[:content] = [args[:content]]
+ upload_ids.take(3).each { |upload_id| args[:content] << { upload_id: upload_id } }
+ end
prompt.push(**args)
prompt
end
diff --git a/lib/automation/llm_tool_triage.rb b/lib/automation/llm_tool_triage.rb
index 978119a2..7d7622ad 100644
--- a/lib/automation/llm_tool_triage.rb
+++ b/lib/automation/llm_tool_triage.rb
@@ -7,11 +7,7 @@ module DiscourseAi
return if !tool
return if !tool.parameters.blank?
- context = {
- post_id: post.id,
- automation_id: automation&.id,
- automation_name: automation&.name,
- }
+ context = DiscourseAi::AiBot::BotContext.new(post: post)
runner = tool.runner({}, llm: nil, bot_user: Discourse.system_user, context: context)
runner.invoke
diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb
index c3d0d5c1..640e7046 100644
--- a/lib/automation/llm_triage.rb
+++ b/lib/automation/llm_triage.rb
@@ -42,7 +42,12 @@ module DiscourseAi
content = llm.tokenizer.truncate(content, max_post_tokens) if max_post_tokens.present?
- prompt.push(type: :user, content: content, upload_ids: post.upload_ids)
+ if post.upload_ids.present?
+ content = [content]
+ content.concat(post.upload_ids.map { |upload_id| { upload_id: upload_id } })
+ end
+
+ prompt.push(type: :user, content: content)
result = nil
diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 578cbd2b..d2d7245c 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -17,13 +17,13 @@ module DiscourseAi
llm_model.provider == "open_ai" || llm_model.provider == "azure"
end
- def translate
+ def embed_user_ids?
+ return @embed_user_ids if defined?(@embed_user_ids)
+
@embed_user_ids =
prompt.messages.any? do |m|
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
end
-
- super
end
def max_prompt_tokens
@@ -102,35 +102,47 @@ module DiscourseAi
end
def user_msg(msg)
- user_message = { role: "user", content: msg[:content] }
+ content_array = []
+
+ user_message = { role: "user" }
if msg[:id]
- if @embed_user_ids
- user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
+ if embed_user_ids?
+ content_array << "#{msg[:id]}: "
else
user_message[:name] = msg[:id]
end
end
- user_message[:content] = inline_images(user_message[:content], msg) if vision_support?
+ content_array << msg[:content]
+
+ content_array =
+ to_encoded_content_array(
+ content: content_array.flatten,
+ image_encoder: ->(details) { image_node(details) },
+ text_encoder: ->(text) { { type: "text", text: text } },
+ allow_vision: vision_support?,
+ )
+
+ user_message[:content] = no_array_if_only_text(content_array)
user_message
end
- def inline_images(content, message)
- encoded_uploads = prompt.encoded_uploads(message)
- return content if encoded_uploads.blank?
+ def no_array_if_only_text(content_array)
+ if content_array.size == 1 && content_array.first[:type] == "text"
+ content_array.first[:text]
+ else
+ content_array
+ end
+ end
- content_w_imgs =
- encoded_uploads.reduce([]) do |memo, details|
- memo << {
- type: "image_url",
- image_url: {
- url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
- },
- }
- end
-
- content_w_imgs << { type: "text", text: message[:content] }
+ def image_node(details)
+ {
+ type: "image_url",
+ image_url: {
+ url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
+ },
+ }
end
def per_message_overhead
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 06fbe102..1cea2215 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -87,9 +87,9 @@ module DiscourseAi
end
def model_msg(msg)
- if msg[:thinking] || msg[:redacted_thinking_signature]
- content_array = []
+ content_array = []
+ if msg[:thinking] || msg[:redacted_thinking_signature]
if msg[:thinking]
content_array << {
type: "thinking",
@@ -104,13 +104,19 @@ module DiscourseAi
data: msg[:redacted_thinking_signature],
}
end
-
- content_array << { type: "text", text: msg[:content] }
-
- { role: "assistant", content: content_array }
- else
- { role: "assistant", content: msg[:content] }
end
+
+ # other encoder is used to pass through thinking
+ content_array =
+ to_encoded_content_array(
+ content: [content_array, msg[:content]].flatten,
+ image_encoder: ->(details) {},
+ text_encoder: ->(text) { { type: "text", text: text } },
+ other_encoder: ->(details) { details },
+ allow_vision: false,
+ )
+
+ { role: "assistant", content: no_array_if_only_text(content_array) }
end
def system_msg(msg)
@@ -124,31 +130,39 @@ module DiscourseAi
end
def user_msg(msg)
- content = +""
- content << "#{msg[:id]}: " if msg[:id]
- content << msg[:content]
- content = inline_images(content, msg) if vision_support?
+ content_array = []
+ content_array << "#{msg[:id]}: " if msg[:id]
+ content_array.concat([msg[:content]].flatten)
- { role: "user", content: content }
+ content_array =
+ to_encoded_content_array(
+ content: content_array,
+ image_encoder: ->(details) { image_node(details) },
+ text_encoder: ->(text) { { type: "text", text: text } },
+ allow_vision: vision_support?,
+ )
+
+ { role: "user", content: no_array_if_only_text(content_array) }
end
- def inline_images(content, message)
- encoded_uploads = prompt.encoded_uploads(message)
- return content if encoded_uploads.blank?
+ # keeping our payload as backward compatible as possible
+ def no_array_if_only_text(content_array)
+ if content_array.length == 1 && content_array.first[:type] == "text"
+ content_array.first[:text]
+ else
+ content_array
+ end
+ end
- content_w_imgs =
- encoded_uploads.reduce([]) do |memo, details|
- memo << {
- source: {
- type: "base64",
- data: details[:base64],
- media_type: details[:mime_type],
- },
- type: "image",
- }
- end
-
- content_w_imgs << { type: "text", text: content }
+ def image_node(details)
+ {
+ source: {
+ type: "base64",
+ data: details[:base64],
+ media_type: details[:mime_type],
+ },
+ type: "image",
+ }
end
end
end
diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb
index 25561433..6af4b10f 100644
--- a/lib/completions/dialects/command.rb
+++ b/lib/completions/dialects/command.rb
@@ -110,9 +110,9 @@ module DiscourseAi
end
def user_msg(msg)
- user_message = { role: "USER", message: msg[:content] }
- user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
-
+ content = prompt.text_only(msg)
+ user_message = { role: "USER", message: content }
+ user_message[:message] = "#{msg[:id]}: #{content}" if msg[:id]
user_message
end
end
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 2a335a12..5d84f6d5 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -227,6 +227,38 @@ module DiscourseAi
msg = msg.merge(content: new_content)
user_msg(msg)
end
+
+ def to_encoded_content_array(
+ content:,
+ image_encoder:,
+ text_encoder:,
+ other_encoder: nil,
+ allow_vision:
+ )
+ content = [content] if !content.is_a?(Array)
+
+ current_string = +""
+ result = []
+
+ content.each do |c|
+ if c.is_a?(String)
+ current_string << c
+ elsif c.is_a?(Hash) && c.key?(:upload_id) && allow_vision
+ if !current_string.empty?
+ result << text_encoder.call(current_string)
+ current_string = +""
+ end
+ encoded = prompt.encode_upload(c[:upload_id])
+ result << image_encoder.call(encoded) if encoded
+ elsif other_encoder
+ encoded = other_encoder.call(c)
+ result << encoded if encoded
+ end
+ end
+
+ result << text_encoder.call(current_string) if !current_string.empty?
+ result
+ end
end
end
end
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index 563ed88b..050fabd2 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -106,28 +106,29 @@ module DiscourseAi
end
def user_msg(msg)
- if beta_api?
- # support new format with multiple parts
- result = { role: "user", parts: [{ text: msg[:content] }] }
- return result unless vision_support?
+ content_array = []
+ content_array << "#{msg[:id]}: " if msg[:id]
- upload_parts = uploaded_parts(msg)
- result[:parts].concat(upload_parts) if upload_parts.present?
- result
+ content_array << msg[:content]
+ content_array.flatten!
+
+ content_array =
+ to_encoded_content_array(
+ content: content_array,
+ image_encoder: ->(details) { image_node(details) },
+ text_encoder: ->(text) { { text: text } },
+ allow_vision: vision_support? && beta_api?,
+ )
+
+ if beta_api?
+ { role: "user", parts: content_array }
else
- { role: "user", parts: { text: msg[:content] } }
+ { role: "user", parts: content_array.first }
end
end
- def uploaded_parts(message)
- encoded_uploads = prompt.encoded_uploads(message)
- result = []
- if encoded_uploads.present?
- encoded_uploads.each do |details|
- result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
- end
- end
- result
+ def image_node(details)
+ { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
end
def tool_call_msg(msg)
diff --git a/lib/completions/dialects/nova.rb b/lib/completions/dialects/nova.rb
index aa184a7a..b078e79d 100644
--- a/lib/completions/dialects/nova.rb
+++ b/lib/completions/dialects/nova.rb
@@ -155,7 +155,7 @@ module DiscourseAi
end
end
- { role: "user", content: msg[:content], images: images }
+ { role: "user", content: prompt.text_only(msg), images: images }
end
def model_msg(msg)
diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb
index 4546c827..fe31bc1d 100644
--- a/lib/completions/dialects/ollama.rb
+++ b/lib/completions/dialects/ollama.rb
@@ -69,7 +69,7 @@ module DiscourseAi
end
def user_msg(msg)
- user_message = { role: "user", content: msg[:content] }
+ user_message = { role: "user", content: prompt.text_only(msg) }
encoded_uploads = prompt.encoded_uploads(msg)
if encoded_uploads.present?
diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb
index 2d648ac1..d6519822 100644
--- a/lib/completions/dialects/open_ai_compatible.rb
+++ b/lib/completions/dialects/open_ai_compatible.rb
@@ -3,7 +3,7 @@
module DiscourseAi
module Completions
module Dialects
- class OpenAiCompatible < Dialect
+ class OpenAiCompatible < ChatGpt
class << self
def can_translate?(_llm_model)
# fallback dialect
@@ -43,58 +43,6 @@ module DiscourseAi
translated.unshift(user_msg)
end
-
- private
-
- def system_msg(msg)
- msg = { role: "system", content: msg[:content] }
-
- if tools_dialect.instructions.present?
- msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
- end
-
- msg
- end
-
- def model_msg(msg)
- { role: "assistant", content: msg[:content] }
- end
-
- def tool_call_msg(msg)
- translated = tools_dialect.from_raw_tool_call(msg)
- { role: "assistant", content: translated }
- end
-
- def tool_msg(msg)
- translated = tools_dialect.from_raw_tool(msg)
- { role: "user", content: translated }
- end
-
- def user_msg(msg)
- content = +""
- content << "#{msg[:id]}: " if msg[:id]
- content << msg[:content]
-
- message = { role: "user", content: content }
-
- message[:content] = inline_images(message[:content], msg) if vision_support?
-
- message
- end
-
- def inline_images(content, message)
- encoded_uploads = prompt.encoded_uploads(message)
- return content if encoded_uploads.blank?
-
- encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
- memo << {
- type: "image_url",
- image_url: {
- url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
- },
- }
- end
- end
end
end
end
diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb
index 79e71c06..660b1ca5 100644
--- a/lib/completions/prompt.rb
+++ b/lib/completions/prompt.rb
@@ -89,7 +89,6 @@ module DiscourseAi
content:,
id: nil,
name: nil,
- upload_ids: nil,
thinking: nil,
thinking_signature: nil,
redacted_thinking_signature: nil
@@ -98,7 +97,6 @@ module DiscourseAi
new_message = { type: type, content: content }
new_message[:name] = name.to_s if name
new_message[:id] = id.to_s if id
- new_message[:upload_ids] = upload_ids if upload_ids
new_message[:thinking] = thinking if thinking
new_message[:thinking_signature] = thinking_signature if thinking_signature
new_message[
@@ -115,11 +113,44 @@ module DiscourseAi
tools.present?
end
- # helper method to get base64 encoded uploads
- # at the correct dimentions
def encoded_uploads(message)
- return [] if message[:upload_ids].blank?
- UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
+ if message[:content].is_a?(Array)
+ upload_ids =
+ message[:content]
+ .map do |content|
+ content[:upload_id] if content.is_a?(Hash) && content.key?(:upload_id)
+ end
+ .compact
+ if !upload_ids.empty?
+ return UploadEncoder.encode(upload_ids: upload_ids, max_pixels: max_pixels)
+ end
+ end
+
+ []
+ end
+
+ def text_only(message)
+ if message[:content].is_a?(Array)
+ message[:content].map { |element| element if element.is_a?(String) }.compact.join
+ else
+ message[:content]
+ end
+ end
+
+ def encode_upload(upload_id)
+ UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first
+ end
+
+ def content_with_encoded_uploads(content)
+ return [content] unless content.is_a?(Array)
+
+ content.map do |c|
+ if c.is_a?(Hash) && c.key?(:upload_id)
+ encode_upload(c[:upload_id])
+ else
+ c
+ end
+ end
end
def ==(other)
@@ -150,7 +181,6 @@ module DiscourseAi
content
id
name
- upload_ids
thinking
thinking_signature
redacted_thinking_signature
@@ -159,15 +189,17 @@ module DiscourseAi
raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
end
- if message[:type] == :upload_ids && !message[:upload_ids].is_a?(Array)
- raise ArgumentError, "upload_ids must be an array of ids"
+ if message[:content].is_a?(Array)
+ message[:content].each do |content|
+ if !content.is_a?(String) && !(content.is_a?(Hash) && content.keys == [:upload_id])
+ raise ArgumentError, "Array message content must be a string or {upload_id: ...} "
+ end
+ end
+ else
+ if !message[:content].is_a?(String)
+ raise ArgumentError, "Message content must be a string or an array"
+ end
end
-
- if message[:upload_ids].present? && message[:type] != :user
- raise ArgumentError, "upload_ids are only supported for users"
- end
-
- raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
end
def validate_turn(last_turn, new_turn)
diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb
index 26fb24fa..b50b5b67 100644
--- a/lib/completions/prompt_messages_builder.rb
+++ b/lib/completions/prompt_messages_builder.rb
@@ -9,6 +9,154 @@ module DiscourseAi
attr_reader :chat_context_post_upload_ids
attr_accessor :topic
+ def self.messages_from_chat(
+ message,
+ channel:,
+ context_post_ids:,
+ max_messages:,
+ include_uploads:,
+ bot_user_ids:,
+ instruction_message: nil
+ )
+ include_thread_titles = !channel.direct_message_channel? && !message.thread_id
+
+ current_id = message.id
+ messages = nil
+
+ if !message.thread_id && channel.direct_message_channel?
+ messages = [message]
+ elsif !channel.direct_message_channel? && !message.thread_id
+ messages =
+ Chat::Message
+ .joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
+ .where(chat_channel_id: channel.id)
+ .where(
+ "chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
+ )
+ .order(id: :desc)
+ .limit(max_messages)
+ .to_a
+ .reverse
+ end
+
+ messages ||=
+ ChatSDK::Thread.last_messages(
+ thread_id: message.thread_id,
+ guardian: Discourse.system_user.guardian,
+ page_size: max_messages,
+ )
+
+ builder = new
+
+ guardian = Guardian.new(message.user)
+ if context_post_ids
+ builder.set_chat_context_posts(
+ context_post_ids,
+ guardian,
+ include_uploads: include_uploads,
+ )
+ end
+
+ messages.each do |m|
+ # restore stripped message
+ m.message = instruction_message if m.id == current_id && instruction_message
+
+ if bot_user_ids.include?(m.user_id)
+ builder.push(type: :model, content: m.message)
+ else
+ upload_ids = nil
+ upload_ids = m.uploads.map(&:id) if include_uploads && m.uploads.present?
+ mapped_message = m.message
+
+ thread_title = nil
+ thread_title = m.thread&.title if include_thread_titles && m.thread_id
+ mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
+
+ builder.push(
+ type: :user,
+ content: mapped_message,
+ name: m.user.username,
+ upload_ids: upload_ids,
+ )
+ end
+ end
+
+ builder.to_a(
+ limit: max_messages,
+ style: channel.direct_message_channel? ? :chat_with_context : :chat,
+ )
+ end
+
+ def self.messages_from_post(post, style: nil, max_posts:, bot_usernames:, include_uploads:)
+ # Pay attention to the `post_number <= ?` here.
+ # We want to inject the last post as context because they are translated differently.
+
+ post_types = [Post.types[:regular]]
+ post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper]
+
+ context =
+ post
+ .topic
+ .posts
+ .joins(:user)
+ .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
+ .where("post_number <= ?", post.post_number)
+ .order("post_number desc")
+ .where("post_type in (?)", post_types)
+ .limit(max_posts)
+ .pluck(
+ "posts.raw",
+ "users.username",
+ "post_custom_prompts.custom_prompt",
+ "(
+ SELECT array_agg(ref.upload_id)
+ FROM upload_references ref
+ WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
+ ) as upload_ids",
+ )
+
+ builder = new
+ builder.topic = post.topic
+
+ context.reverse_each do |raw, username, custom_prompt, upload_ids|
+ custom_prompt_translation =
+ Proc.new do |message|
+ # We can't keep backwards-compatibility for stored functions.
+ # Tool syntax requires a tool_call_id which we don't have.
+ if message[2] != "function"
+ custom_context = {
+ content: message[0],
+ type: message[2].present? ? message[2].to_sym : :model,
+ }
+
+ custom_context[:id] = message[1] if custom_context[:type] != :model
+ custom_context[:name] = message[3] if message[3]
+
+ thinking = message[4]
+ custom_context[:thinking] = thinking if thinking
+
+ builder.push(**custom_context)
+ end
+ end
+
+ if custom_prompt.present?
+ custom_prompt.each(&custom_prompt_translation)
+ else
+ context = { content: raw, type: (bot_usernames.include?(username) ? :model : :user) }
+
+ context[:id] = username if context[:type] == :user
+
+ if upload_ids.present? && context[:type] == :user && include_uploads
+ context[:upload_ids] = upload_ids.compact
+ end
+
+ builder.push(**context)
+ end
+ end
+
+ builder.to_a(style: style || (post.topic.private_message? ? :bot : :topic))
+ end
+
def initialize
@raw_messages = []
end
@@ -68,12 +216,19 @@ module DiscourseAi
if message[:type] == :user
old_name = last_message.delete(:name)
- last_message[:content] = "#{old_name}: #{last_message[:content]}" if old_name
+ last_message[:content] = ["#{old_name}: ", last_message[:content]].flatten if old_name
new_content = message[:content]
- new_content = "#{message[:name]}: #{new_content}" if message[:name]
+ new_content = ["#{message[:name]}: ", new_content].flatten if message[:name]
- last_message[:content] += "\n#{new_content}"
+ if !last_message[:content].is_a?(Array)
+ last_message[:content] = [last_message[:content]]
+ end
+ last_message[:content].concat(["\n", new_content].flatten)
+
+ compressed =
+ compress_messages_buffer(last_message[:content], max_uploads: MAX_TOPIC_UPLOADS)
+ last_message[:content] = compressed
else
last_message[:content] = message[:content]
end
@@ -111,9 +266,9 @@ module DiscourseAi
end
raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array)
+ content = [content, *upload_ids.map { |upload_id| { upload_id: upload_id } }] if upload_ids
message = { type: type, content: content }
message[:name] = name.to_s if name
- message[:upload_ids] = upload_ids if upload_ids
message[:id] = id.to_s if id
if thinking
message[:thinking] = thinking["thinking"] if thinking["thinking"]
@@ -132,67 +287,62 @@ module DiscourseAi
def topic_array
raw_messages = @raw_messages.dup
- user_content = +"You are operating in a Discourse forum.\n\n"
+ content_array = []
+ content_array << "You are operating in a Discourse forum.\n\n"
if @topic
if @topic.private_message?
- user_content << "Private message info.\n"
+ content_array << "Private message info.\n"
else
- user_content << "Topic information:\n"
+ content_array << "Topic information:\n"
end
- user_content << "- URL: #{@topic.url}\n"
- user_content << "- Title: #{@topic.title}\n"
+ content_array << "- URL: #{@topic.url}\n"
+ content_array << "- Title: #{@topic.title}\n"
if SiteSetting.tagging_enabled
tags = @topic.tags.pluck(:name)
tags -= DiscourseTagging.hidden_tag_names if tags.present?
- user_content << "- Tags: #{tags.join(", ")}\n" if tags.present?
+ content_array << "- Tags: #{tags.join(", ")}\n" if tags.present?
end
if !@topic.private_message?
- user_content << "- Category: #{@topic.category.name}\n" if @topic.category
+ content_array << "- Category: #{@topic.category.name}\n" if @topic.category
end
- user_content << "- Number of replies: #{@topic.posts_count - 1}\n\n"
+ content_array << "- Number of replies: #{@topic.posts_count - 1}\n\n"
end
last_user_message = raw_messages.pop
- upload_ids = []
if raw_messages.present?
- user_content << "Here is the conversation so far:\n"
+ content_array << "Here is the conversation so far:\n"
raw_messages.each do |message|
- user_content << "#{message[:name] || "User"}: #{message[:content]}\n"
- upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
+ content_array << "#{message[:name] || "User"}: "
+ content_array << message[:content]
+ content_array << "\n\n"
end
end
if last_user_message
- user_content << "You are responding to #{last_user_message[:name] || "User"} who just said:\n #{last_user_message[:content]}"
- if last_user_message[:upload_ids].present?
- upload_ids.concat(last_user_message[:upload_ids])
- end
+ content_array << "You are responding to #{last_user_message[:name] || "User"} who just said:\n"
+ content_array << last_user_message[:content]
end
- user_message = { type: :user, content: user_content }
+ content_array =
+ compress_messages_buffer(content_array.flatten, max_uploads: MAX_TOPIC_UPLOADS)
- if upload_ids.present?
- user_message[:upload_ids] = upload_ids[-MAX_TOPIC_UPLOADS..-1] || upload_ids
- end
+ user_message = { type: :user, content: content_array }
[user_message]
end
def chat_array(limit:)
if @raw_messages.length > 1
- buffer =
- +"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{"
-
- upload_ids = []
+ buffer = [
+ +"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{",
+ ]
@raw_messages[0..-2].each do |message|
buffer << "\n"
- upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
-
if message[:type] == :user
buffer << "#{message[:name] || "User"}: "
else
@@ -209,16 +359,44 @@ module DiscourseAi
end
last_message = @raw_messages[-1]
- buffer << "#{last_message[:name] || "User"}: #{last_message[:content]} "
+ buffer << "#{last_message[:name] || "User"}: "
+ buffer << last_message[:content]
+
+ buffer = compress_messages_buffer(buffer.flatten, max_uploads: MAX_CHAT_UPLOADS)
message = { type: :user, content: buffer }
- upload_ids.concat(last_message[:upload_ids]) if last_message[:upload_ids].present?
-
- message[:upload_ids] = upload_ids[-MAX_CHAT_UPLOADS..-1] ||
- upload_ids if upload_ids.present?
-
[message]
end
+
+ # caps uploads to maximum uploads allowed in message stream
+ # and concats string elements
+ def compress_messages_buffer(buffer, max_uploads:)
+ compressed = []
+ current_text = +""
+ upload_count = 0
+
+ buffer.each do |item|
+ if item.is_a?(String)
+ current_text << item
+ elsif item.is_a?(Hash)
+ compressed << current_text if current_text.present?
+ compressed << item
+ current_text = +""
+ upload_count += 1
+ end
+ end
+
+ compressed << current_text if current_text.present?
+
+ if upload_count > max_uploads
+ counter = max_uploads - upload_count
+        compressed.delete_if { |item| item.is_a?(Hash) && (counter += 1) <= 0 }
+ end
+
+ compressed = compressed[0] if compressed.length == 1 && compressed[0].is_a?(String)
+
+ compressed
+ end
end
end
end
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
index a94ae0e6..860fe653 100644
--- a/spec/lib/completions/dialects/gemini_spec.rb
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(context.image_generation_scenario).to eq(
{
messages: [
- { role: "user", parts: [{ text: "draw a cat" }] },
+ { role: "user", parts: [{ text: "user1: draw a cat" }] },
{
role: "model",
parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
@@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
],
},
{ role: "model", parts: { text: "Ok." } },
- { role: "user", parts: [{ text: "draw another cat" }] },
+ { role: "user", parts: [{ text: "user1: draw another cat" }] },
],
system_instruction: context.system_insts,
},
@@ -52,12 +52,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(context.multi_turn_scenario).to eq(
{
messages: [
- { role: "user", parts: [{ text: "This is a message by a user" }] },
+ { role: "user", parts: [{ text: "user1: This is a message by a user" }] },
{
role: "model",
parts: [{ text: "I'm a previous bot reply, that's why there's no user" }],
},
- { role: "user", parts: [{ text: "This is a new message by a user" }] },
+ { role: "user", parts: [{ text: "user1: This is a new message by a user" }] },
{
role: "model",
parts: [
diff --git a/spec/lib/completions/dialects/mistral_spec.rb b/spec/lib/completions/dialects/mistral_spec.rb
index 2e373bc5..a929768b 100644
--- a/spec/lib/completions/dialects/mistral_spec.rb
+++ b/spec/lib/completions/dialects/mistral_spec.rb
@@ -29,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
- messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+ messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
@@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
content = dialect.translate[1][:content]
expect(content).to eq(
- [{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }],
+ [{ type: "text", text: "user1: hello" }, { type: "image_url", image_url: { url: image } }],
)
end
diff --git a/spec/lib/completions/dialects/nova_spec.rb b/spec/lib/completions/dialects/nova_spec.rb
index 865426e2..36b0fcc7 100644
--- a/spec/lib/completions/dialects/nova_spec.rb
+++ b/spec/lib/completions/dialects/nova_spec.rb
@@ -37,7 +37,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Nova do
it "properly formats messages with images" do
messages = [
- { type: :user, id: "user1", content: "What's in this image?", upload_ids: [upload.id] },
+ {
+ type: :user,
+ id: "user1",
+ content: ["What's in this image?", { upload_id: upload.id }],
+ },
]
prompt = DiscourseAi::Completions::Prompt.new(messages: messages)
diff --git a/spec/lib/completions/dialects/open_ai_compatible_spec.rb b/spec/lib/completions/dialects/open_ai_compatible_spec.rb
index 00dea4b0..185a97e3 100644
--- a/spec/lib/completions/dialects/open_ai_compatible_spec.rb
+++ b/spec/lib/completions/dialects/open_ai_compatible_spec.rb
@@ -34,8 +34,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
messages: [
{
type: :user,
- content: "Describe this image in a single sentence.",
- upload_ids: [upload.id],
+ content: ["Describe this image in a single sentence.", { upload_id: upload.id }],
},
],
)
@@ -49,10 +48,15 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
expect(translated_messages.length).to eq(1)
+ # no system message support here
expected_user_message = {
role: "user",
content: [
- { type: "text", text: prompt.messages.map { |m| m[:content] }.join("\n") },
+ {
+ type: "text",
+ text:
+ "You are a bot specializing in image captioning.\nDescribe this image in a single sentence.",
+ },
{
type: "image_url",
image_url: {
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index f2f79c5f..ca3bcb46 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -271,7 +271,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
- messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+ messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
@@ -283,6 +283,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
{
role: "user",
content: [
+ { type: "text", text: "user1: hello" },
{
type: "image",
source: {
@@ -291,7 +292,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
data: encoded[0][:base64],
},
},
- { type: "text", text: "user1: hello" },
],
},
],
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index fe7f4eb6..05a87b41 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -211,7 +211,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
- messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+ messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
@@ -248,7 +248,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
{
"role" => "user",
"parts" => [
- { "text" => "hello" },
+ { "text" => "user1: hello" },
{ "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
],
},
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index d48bffb5..fc3b8b48 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -492,7 +492,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
- messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+ messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
@@ -517,13 +517,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
{
role: "user",
content: [
+ { type: "text", text: "hello" },
{
type: "image_url",
image_url: {
url: "data:#{encoded[0][:mime_type]};base64,#{encoded[0][:base64]}",
},
},
- { type: "text", text: "hello" },
],
name: "user1",
},
diff --git a/spec/lib/completions/prompt_messages_builder_spec.rb b/spec/lib/completions/prompt_messages_builder_spec.rb
index b162e39c..899b70e4 100644
--- a/spec/lib/completions/prompt_messages_builder_spec.rb
+++ b/spec/lib/completions/prompt_messages_builder_spec.rb
@@ -2,6 +2,36 @@
describe DiscourseAi::Completions::PromptMessagesBuilder do
let(:builder) { DiscourseAi::Completions::PromptMessagesBuilder.new }
+ fab!(:user)
+ fab!(:admin)
+ fab!(:bot_user) { Fabricate(:user) }
+ fab!(:other_user) { Fabricate(:user) }
+
+ fab!(:image_upload1) do
+ Fabricate(:upload, user: user, original_filename: "image.png", extension: "png")
+ end
+ fab!(:image_upload2) do
+ Fabricate(:upload, user: user, original_filename: "image.png", extension: "png")
+ end
+
+ it "correctly merges user messages with uploads" do
+ builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1])
+ builder.push(type: :user, content: "World", name: "Bob", upload_ids: [2])
+
+ messages = builder.to_a
+
+ # Check the structure of the merged message
+ expect(messages.length).to eq(1)
+ expect(messages[0][:type]).to eq(:user)
+
+ # The content should contain the text and both uploads
+ content = messages[0][:content]
+ expect(content).to be_an(Array)
+ expect(content[0]).to eq("Alice: Hello")
+ expect(content[1]).to eq({ upload_id: 1 })
+ expect(content[2]).to eq("\nBob: World")
+ expect(content[3]).to eq({ upload_id: 2 })
+ end
it "should allow merging user messages" do
builder.push(type: :user, content: "Hello", name: "Alice")
@@ -14,7 +44,7 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1, 2])
expect(builder.to_a).to eq(
- [{ type: :user, name: "Alice", content: "Hello", upload_ids: [1, 2] }],
+ [{ type: :user, content: ["Hello", { upload_id: 1 }, { upload_id: 2 }], name: "Alice" }],
)
end
@@ -64,4 +94,319 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
expect(content).to include("Alice")
expect(content).to include("How do I solve this")
end
+
+ describe ".messages_from_chat" do
+ fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, bot_user]) }
+ fab!(:dm_message1) do
+ Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "Hello bot")
+ end
+ fab!(:dm_message2) do
+ Fabricate(:chat_message, chat_channel: dm_channel, user: bot_user, message: "Hello human")
+ end
+ fab!(:dm_message3) do
+ Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "How are you?")
+ end
+
+ fab!(:public_channel) { Fabricate(:category_channel) }
+ fab!(:public_message1) do
+ Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Hello everyone")
+ end
+ fab!(:public_message2) do
+ Fabricate(:chat_message, chat_channel: public_channel, user: bot_user, message: "Hi there")
+ end
+
+ fab!(:thread_original) do
+ Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Thread starter")
+ end
+ fab!(:thread) do
+ Fabricate(:chat_thread, channel: public_channel, original_message: thread_original)
+ end
+ fab!(:thread_reply1) do
+ Fabricate(
+ :chat_message,
+ chat_channel: public_channel,
+ user: other_user,
+ message: "Thread reply",
+ thread: thread,
+ )
+ end
+
+ fab!(:upload) { Fabricate(:upload, user: user) }
+ fab!(:message_with_upload) do
+ Fabricate(
+ :chat_message,
+ chat_channel: dm_channel,
+ user: user,
+ message: "Check this image",
+ upload_ids: [upload.id],
+ )
+ end
+
+ it "processes messages from direct message channels" do
+ context =
+ described_class.messages_from_chat(
+ dm_message3,
+ channel: dm_channel,
+ context_post_ids: nil,
+ max_messages: 10,
+ include_uploads: false,
+ bot_user_ids: [bot_user.id],
+ instruction_message: nil,
+ )
+
+ # this is all we get because threading is assumed
+ expect(context).to eq([{ type: :user, content: "How are you?", name: user.username }])
+ end
+
+ it "includes uploads when include_uploads is true" do
+ message_with_upload.reload
+ expect(message_with_upload.uploads).to include(upload)
+
+ context =
+ described_class.messages_from_chat(
+ message_with_upload,
+ channel: dm_channel,
+ context_post_ids: nil,
+ max_messages: 10,
+ include_uploads: true,
+ bot_user_ids: [bot_user.id],
+ instruction_message: nil,
+ )
+
+ # Find the message with upload
+ message = context.find { |m| m[:content] == ["Check this image", { upload_id: upload.id }] }
+ expect(message).to be_present
+ end
+
+ it "doesn't include uploads when include_uploads is false" do
+ # Make sure the upload is associated with the message
+ message_with_upload.reload
+ expect(message_with_upload.uploads).to include(upload)
+
+ context =
+ described_class.messages_from_chat(
+ message_with_upload,
+ channel: dm_channel,
+ context_post_ids: nil,
+ max_messages: 10,
+ include_uploads: false,
+ bot_user_ids: [bot_user.id],
+ instruction_message: nil,
+ )
+
+ # Find the message with upload
+ message = context.find { |m| m[:content] == "Check this image" }
+ expect(message).to be_present
+ expect(message[:upload_ids]).to be_nil
+ end
+
+ it "properly handles uploads in public channels with multiple users" do
+ _first_message =
+ Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "First message")
+
+ _message_with_upload =
+ Fabricate(
+ :chat_message,
+ chat_channel: public_channel,
+ user: other_user,
+ message: "Message with image",
+ upload_ids: [upload.id],
+ )
+
+ last_message =
+ Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Final message")
+
+ context =
+ described_class.messages_from_chat(
+ last_message,
+ channel: public_channel,
+ context_post_ids: nil,
+ max_messages: 3,
+ include_uploads: true,
+ bot_user_ids: [bot_user.id],
+ instruction_message: nil,
+ )
+
+ expect(context.length).to eq(1)
+ content = context.first[:content]
+
+ expect(content.length).to eq(3)
+ expect(content[0]).to include("First message")
+ expect(content[0]).to include("Message with image")
+ expect(content[1]).to include({ upload_id: upload.id })
+ expect(content[2]).to include("Final message")
+ end
+ end
+
+ describe ".messages_from_post" do
+ fab!(:pm) do
+ Fabricate(
+ :private_message_topic,
+ title: "This is my special PM",
+ user: user,
+ topic_allowed_users: [
+ Fabricate.build(:topic_allowed_user, user: user),
+ Fabricate.build(:topic_allowed_user, user: bot_user),
+ ],
+ )
+ end
+ fab!(:first_post) do
+ Fabricate(:post, topic: pm, user: user, post_number: 1, raw: "This is a reply by the user")
+ end
+ fab!(:second_post) do
+ Fabricate(:post, topic: pm, user: bot_user, post_number: 2, raw: "This is a bot reply")
+ end
+ fab!(:third_post) do
+ Fabricate(
+ :post,
+ topic: pm,
+ user: user,
+ post_number: 3,
+ raw: "This is a second reply by the user",
+ )
+ end
+
+ it "handles uploads correctly in topic style messages" do
+ # Use Discourse's upload format in the post raw content
+ upload_markdown = ""
+
+ post_with_upload =
+ Fabricate(
+ :post,
+ topic: pm,
+ user: admin,
+ raw: "This is the original #{upload_markdown} I just added",
+ )
+
+ UploadReference.create!(target: post_with_upload, upload: image_upload1)
+
+ upload2_markdown = ""
+
+ post2_with_upload =
+ Fabricate(
+ :post,
+ topic: pm,
+ user: admin,
+ raw: "This post has a different image #{upload2_markdown} I just added",
+ )
+
+ UploadReference.create!(target: post2_with_upload, upload: image_upload2)
+
+ messages =
+ described_class.messages_from_post(
+ post2_with_upload,
+ style: :topic,
+ max_posts: 3,
+ bot_usernames: [bot_user.username],
+ include_uploads: true,
+ )
+
+ # this is not quite ideal yet, images are attached at the end of the post
+ # long term we may want to extract them out using a regex and create N parts
+ # so people can talk about multiple images in a single post
+ # this is the initial ground work though
+
+ expect(messages.length).to eq(1)
+ content = messages[0][:content]
+
+ # first part
+ # first image
+ # second part
+ # second image
+ expect(content.length).to eq(4)
+ expect(content[0]).to include("This is the original")
+ expect(content[1]).to eq({ upload_id: image_upload1.id })
+ expect(content[2]).to include("different image")
+ expect(content[3]).to eq({ upload_id: image_upload2.id })
+ end
+
+ context "with limited context" do
+ it "respects max_context_posts" do
+ context =
+ described_class.messages_from_post(
+ third_post,
+ max_posts: 1,
+ bot_usernames: [bot_user.username],
+ include_uploads: false,
+ )
+
+ expect(context).to contain_exactly(
+ *[{ type: :user, id: user.username, content: third_post.raw }],
+ )
+ end
+ end
+
+ it "includes previous posts ordered by post_number" do
+ context =
+ described_class.messages_from_post(
+ third_post,
+ max_posts: 10,
+ bot_usernames: [bot_user.username],
+ include_uploads: false,
+ )
+
+ expect(context).to eq(
+ [
+ { type: :user, content: "This is a reply by the user", id: user.username },
+ { type: :model, content: "This is a bot reply" },
+ { type: :user, content: "This is a second reply by the user", id: user.username },
+ ],
+ )
+ end
+
+ it "only include regular posts" do
+ first_post.update!(post_type: Post.types[:whisper])
+
+ context =
+ described_class.messages_from_post(
+ third_post,
+ max_posts: 10,
+ bot_usernames: [bot_user.username],
+ include_uploads: false,
+ )
+
+ # skips the leading model reply, which makes no sense because the first post was a whisper
+ expect(context).to eq(
+ [{ type: :user, content: "This is a second reply by the user", id: user.username }],
+ )
+ end
+
+ context "with custom prompts" do
+ it "When post custom prompt is present, we use that instead of the post content" do
+ custom_prompt = [
+ [
+ { name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
+ "time",
+ "tool_call",
+ ],
+ [
+ { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
+ "time",
+ "tool",
+ ],
+ ["I replied to the time command", bot_user.username],
+ ]
+
+ PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
+
+ context =
+ described_class.messages_from_post(
+ third_post,
+ max_posts: 10,
+ bot_usernames: [bot_user.username],
+ include_uploads: false,
+ )
+
+ expect(context).to eq(
+ [
+ { type: :user, content: "This is a reply by the user", id: user.username },
+ { type: :tool_call, content: custom_prompt.first.first, id: "time" },
+ { type: :tool, id: "time", content: custom_prompt.second.first },
+ { type: :model, content: custom_prompt.third.first },
+ { type: :user, content: "This is a second reply by the user", id: user.username },
+ ],
+ )
+ end
+ end
+ end
end
diff --git a/spec/lib/completions/prompt_spec.rb b/spec/lib/completions/prompt_spec.rb
index fe2d1fa1..dafae3bd 100644
--- a/spec/lib/completions/prompt_spec.rb
+++ b/spec/lib/completions/prompt_spec.rb
@@ -25,34 +25,21 @@ RSpec.describe DiscourseAi::Completions::Prompt do
end
describe "image support" do
- it "allows adding uploads to messages" do
+ it "allows adding uploads inline in messages" do
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
prompt.max_pixels = 300
- prompt.push(type: :user, content: "hello", upload_ids: [upload.id])
+ prompt.push(
+ type: :user,
+ content: ["this is an image", { upload_id: upload.id }, "this was an image"],
+ )
- expect(prompt.messages.last[:upload_ids]).to eq([upload.id])
- expect(prompt.max_pixels).to eq(300)
+ encoded = prompt.content_with_encoded_uploads(prompt.messages.last[:content])
- encoded = prompt.encoded_uploads(prompt.messages.last)
-
- expect(encoded.length).to eq(1)
- expect(encoded[0][:mime_type]).to eq("image/jpeg")
-
- old_base64 = encoded[0][:base64]
-
- prompt.max_pixels = 1_000_000
-
- encoded = prompt.encoded_uploads(prompt.messages.last)
-
- expect(encoded.length).to eq(1)
- expect(encoded[0][:mime_type]).to eq("image/jpeg")
-
- new_base64 = encoded[0][:base64]
-
- expect(new_base64.length).to be > old_base64.length
- expect(new_base64.length).to be > 0
- expect(old_base64.length).to be > 0
+ expect(encoded.length).to eq(3)
+ expect(encoded[0]).to eq("this is an image")
+ expect(encoded[1][:mime_type]).to eq("image/jpeg")
+ expect(encoded[2]).to eq("this was an image")
end
end
diff --git a/spec/lib/modules/ai_bot/bot_spec.rb b/spec/lib/modules/ai_bot/bot_spec.rb
index 3432ad28..025a37f9 100644
--- a/spec/lib/modules/ai_bot/bot_spec.rb
+++ b/spec/lib/modules/ai_bot/bot_spec.rb
@@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: personaClass.new)
bot.reply(
- { conversation_context: [{ type: :user, content: "test" }] },
+ DiscourseAi::AiBot::BotContext.new(messages: [{ type: :user, content: "test" }]),
) do |_partial, _cancel, _placeholder|
# we just need the block so bot has something to call with results
end
@@ -74,7 +74,10 @@ RSpec.describe DiscourseAi::AiBot::Bot do
HTML
- context = { conversation_context: [{ type: :user, content: "Does my site has tags?" }] }
+ context =
+ DiscourseAi::AiBot::BotContext.new(
+ messages: [{ type: :user, content: "Does my site has tags?" }],
+ )
DiscourseAi::Completions::Llm.with_prepared_responses(llm_responses) do
bot.reply(context) do |_bot_reply_post, cancel, placeholder|
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index 8949ae13..8bd6634d 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -36,13 +36,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
let(:context) do
- {
+ DiscourseAi::AiBot::BotContext.new(
site_url: Discourse.base_url,
site_title: "test site title",
site_description: "test site description",
time: Time.zone.now,
participants: topic_with_users.allowed_users.map(&:username).join(", "),
- }
+ )
end
fab!(:admin)
@@ -307,7 +307,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
let(:ai_persona) { DiscourseAi::AiBot::Personas::Persona.all(user: user).first.new }
let(:with_cc) do
- context.merge(conversation_context: [{ content: "Tell me the time", type: :user }])
+ context.messages = [{ content: "Tell me the time", type: :user }]
+ context
end
context "when a persona has no uploads" do
@@ -345,17 +346,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
DiscourseAi::AiBot::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new
# this means that we will consolidate
- ctx =
- with_cc.merge(
- conversation_context: [
- { content: "Tell me the time", type: :user },
- { content: "the time is 1", type: :model },
- { content: "in france?", type: :user },
- ],
- )
+ context.messages = [
+ { content: "Tell me the time", type: :user },
+ { content: "the time is 1", type: :model },
+ { content: "in france?", type: :user },
+ ]
DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do
- custom_persona.craft_prompt(ctx).messages.first[:content]
+ custom_persona.craft_prompt(context).messages.first[:content]
end
message =
@@ -397,7 +395,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
EmbeddingsGenerationStubs.hugging_face_service(
- with_cc.dig(:conversation_context, 0, :content),
+ with_cc.messages.dig(0, :content),
prompt_cc_embeddings,
)
end
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index d65842e0..4df2ae5d 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -267,7 +267,10 @@ RSpec.describe DiscourseAi::AiBot::Playground do
prompts = inner_prompts
end
- expect(prompts[0].messages[1][:upload_ids]).to eq([upload.id])
+ content = prompts[0].messages[1][:content]
+
+ expect(content).to include({ upload_id: upload.id })
+
expect(prompts[0].max_pixels).to eq(1000)
post.topic.reload
@@ -1154,79 +1157,4 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(playground.available_bot_usernames).to include(persona.user.username)
end
end
-
- describe "#conversation_context" do
- context "with limited context" do
- before do
- @old_persona = playground.bot.persona
- persona = Fabricate(:ai_persona, max_context_posts: 1)
- playground.bot.persona = persona.class_instance.new
- end
-
- after { playground.bot.persona = @old_persona }
-
- it "respects max_context_post" do
- context = playground.conversation_context(third_post)
-
- expect(context).to contain_exactly(
- *[{ type: :user, id: user.username, content: third_post.raw }],
- )
- end
- end
-
- xit "includes previous posts ordered by post_number" do
- context = playground.conversation_context(third_post)
-
- expect(context).to contain_exactly(
- *[
- { type: :user, id: user.username, content: third_post.raw },
- { type: :model, content: second_post.raw },
- { type: :user, id: user.username, content: first_post.raw },
- ],
- )
- end
-
- xit "only include regular posts" do
- first_post.update!(post_type: Post.types[:whisper])
-
- context = playground.conversation_context(third_post)
-
- # skips leading model reply which makes no sense cause first post was whisper
- expect(context).to contain_exactly(
- *[{ type: :user, id: user.username, content: third_post.raw }],
- )
- end
-
- context "with custom prompts" do
- it "When post custom prompt is present, we use that instead of the post content" do
- custom_prompt = [
- [
- { name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
- "time",
- "tool_call",
- ],
- [
- { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
- "time",
- "tool",
- ],
- ["I replied to the time command", bot_user.username],
- ]
-
- PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
-
- context = playground.conversation_context(third_post)
-
- expect(context).to contain_exactly(
- *[
- { type: :user, id: user.username, content: first_post.raw },
- { type: :tool_call, content: custom_prompt.first.first, id: "time" },
- { type: :tool, id: "time", content: custom_prompt.second.first },
- { type: :model, content: custom_prompt.third.first },
- { type: :user, id: user.username, content: third_post.raw },
- ],
- )
- end
- end
- end
end
diff --git a/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb b/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb
index 058e28d2..26d47528 100644
--- a/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/create_artifact_spec.rb
@@ -34,9 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::CreateArtifact do
{ html_body: "hello" },
bot_user: Fabricate(:user),
llm: llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(post: post),
)
tool.parameters = { name: "hello", specification: "hello spec" }
diff --git a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
index 3eb1af62..dfb4e9d3 100644
--- a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
@@ -15,9 +15,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
let(:progress_blk) { Proc.new {} }
- let(:dall_e) do
- described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
- end
+ let(:dall_e) { described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user) }
let(:base64_image) do
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
@@ -30,8 +28,6 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
{ prompts: ["a cat"], aspect_ratio: "tall" },
llm: llm,
bot_user: bot_user,
- context: {
- },
)
data = [{ b64_json: base64_image, revised_prompt: "a tall cat" }]
diff --git a/spec/lib/modules/ai_bot/tools/image_spec.rb b/spec/lib/modules/ai_bot/tools/image_spec.rb
index 6f14448f..aea98dd9 100644
--- a/spec/lib/modules/ai_bot/tools/image_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/image_spec.rb
@@ -9,8 +9,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
{ prompts: prompts, seeds: [99, 32] },
bot_user: bot_user,
llm: llm,
- context: {
- },
+ context: DiscourseAi::AiBot::BotContext.new,
)
end
diff --git a/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb b/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb
index 252e7eec..8247c528 100644
--- a/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/read_artifact_spec.rb
@@ -25,9 +25,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/#{artifact.id}" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post2.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -46,9 +44,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "invalid-url" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -62,9 +58,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/99999" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -97,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "https://example.com" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -128,9 +120,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "https://example.com" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
diff --git a/spec/lib/modules/ai_bot/tools/read_spec.rb b/spec/lib/modules/ai_bot/tools/read_spec.rb
index d8ffe8ae..88dc7906 100644
--- a/spec/lib/modules/ai_bot/tools/read_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/read_spec.rb
@@ -56,9 +56,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
persona_options: {
"read_private" => true,
},
- context: {
- user: admin,
- },
+ context: DiscourseAi::AiBot::BotContext.new(user: admin),
)
results = tool.invoke
expect(results[:content]).to include("hello there")
@@ -68,9 +66,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
{ topic_id: topic_with_tags.id, post_numbers: [post1.post_number] },
bot_user: bot_user,
llm: llm,
- context: {
- user: admin,
- },
+ context: DiscourseAi::AiBot::BotContext.new(user: admin),
)
results = tool.invoke
diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb
index 94dba3c7..c7be7522 100644
--- a/spec/lib/modules/ai_bot/tools/search_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/search_spec.rb
@@ -60,9 +60,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
persona_options: persona_options,
bot_user: bot_user,
llm: llm,
- context: {
- user: user,
- },
+ context: DiscourseAi::AiBot::BotContext.new(user: user),
)
expect(search.options[:base_query]).to eq("#funny")
diff --git a/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb b/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb
index c9834d25..f3d2a28f 100644
--- a/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/update_artifact_spec.rb
@@ -47,9 +47,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -93,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -119,9 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: artifact.id, instructions: "Invalid update" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -135,9 +129,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: -1, instructions: "Update something" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -163,9 +155,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
tool.invoke {}
@@ -196,9 +186,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
.invoke {}
end
@@ -224,9 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -276,9 +262,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
persona_options: {
"update_algorithm" => "diff",
},
@@ -346,9 +330,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" },
bot_user: bot_user,
llm: llm_model.to_llm,
- context: {
- post_id: post.id,
- },
+ context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
persona_options: {
"update_algorithm" => "diff",
},
diff --git a/spec/lib/modules/ai_moderation/spam_scanner_spec.rb b/spec/lib/modules/ai_moderation/spam_scanner_spec.rb
index 129d8bdb..97cd4dfa 100644
--- a/spec/lib/modules/ai_moderation/spam_scanner_spec.rb
+++ b/spec/lib/modules/ai_moderation/spam_scanner_spec.rb
@@ -255,11 +255,12 @@ RSpec.describe DiscourseAi::AiModeration::SpamScanner do
prompt = _prompts.first
end
- content = prompt.messages[1][:content]
+ # content is an array, so grab the first (text) element to make testing easier
+ content = prompt.messages[1][:content][0]
expect(content).to include(post.topic.title)
expect(content).to include(post.raw)
- upload_ids = prompt.messages[1][:upload_ids]
+ upload_ids = prompt.messages[1][:content].map { |m| m[:upload_id] if m.is_a?(Hash) }.compact
expect(upload_ids).to be_present
expect(upload_ids).to eq(post.upload_ids)
diff --git a/spec/lib/modules/automation/llm_triage_spec.rb b/spec/lib/modules/automation/llm_triage_spec.rb
index c4b20bf3..e9c9979a 100644
--- a/spec/lib/modules/automation/llm_triage_spec.rb
+++ b/spec/lib/modules/automation/llm_triage_spec.rb
@@ -199,7 +199,7 @@ describe DiscourseAi::Automation::LlmTriage do
triage_prompt = DiscourseAi::Completions::Llm.prompts.last
- expect(triage_prompt.messages.last[:upload_ids]).to contain_exactly(post_upload.id)
+ expect(triage_prompt.messages.last[:content].last).to eq({ upload_id: post_upload.id })
end
end
diff --git a/spec/models/ai_tool_spec.rb b/spec/models/ai_tool_spec.rb
index 0df68dea..45ec3175 100644
--- a/spec/models/ai_tool_spec.rb
+++ b/spec/models/ai_tool_spec.rb
@@ -5,6 +5,7 @@ RSpec.describe AiTool do
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
fab!(:topic)
fab!(:post) { Fabricate(:post, topic: topic, raw: "bananas are a tasty fruit") }
+ fab!(:bot_user) { Discourse.system_user }
def create_tool(
parameters: nil,
@@ -16,7 +17,8 @@ RSpec.describe AiTool do
name: "test #{SecureRandom.uuid}",
tool_name: "test_#{SecureRandom.uuid.underscore}",
description: "test",
- parameters: parameters || [{ name: "query", type: "string", desciption: "perform a search" }],
+ parameters:
+ parameters || [{ name: "query", type: "string", description: "perform a search" }],
script: script || "function invoke(params) { return params; }",
created_by_id: 1,
summary: "Test tool summary",
@@ -32,11 +34,11 @@ RSpec.describe AiTool do
{
name: tool.tool_name,
description: "test",
- parameters: [{ name: "query", type: "string", desciption: "perform a search" }],
+ parameters: [{ name: "query", type: "string", description: "perform a search" }],
},
)
- runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
expect(runner.invoke).to eq("query" => "test")
end
@@ -57,7 +59,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil)
stub_request(verb, "https://example.com/api").with(
body: "{\"data\":\"test data\"}",
@@ -83,7 +85,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
stub_request(:get, "https://example.com/test").with(
headers: {
@@ -110,7 +112,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({}, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({}, llm: nil, bot_user: nil)
stub_request(:get, "https://example.com/").to_return(
status: 200,
@@ -134,7 +136,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
stub_request(:get, "https://example.com/test").with(
headers: {
@@ -160,13 +162,16 @@ RSpec.describe AiTool do
}
JS
+ tool = create_tool(script: script)
+ runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
+
stub_request(:get, "https://example.com/test").to_return do
sleep 0.01
{ status: 200, body: "Hello World", headers: {} }
end
tool = create_tool(script: script)
- runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
runner.timeout = 10
@@ -184,7 +189,7 @@ RSpec.describe AiTool do
tool = create_tool(script: script)
- runner = tool.runner({}, llm: llm, bot_user: nil, context: {})
+ runner = tool.runner({}, llm: llm, bot_user: nil)
result = runner.invoke
expect(result).to eq("Hello")
@@ -209,7 +214,7 @@ RSpec.describe AiTool do
responses = ["Hello ", "World"]
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
- runner = tool.runner({}, llm: llm, bot_user: nil, context: {})
+ runner = tool.runner({}, llm: llm, bot_user: nil)
result = runner.invoke
prompts = _prompts
end
@@ -232,7 +237,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
runner.timeout = 5
@@ -295,7 +300,7 @@ RSpec.describe AiTool do
RagDocumentFragment.link_target_and_uploads(tool, [upload1.id, upload2.id])
- result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke
+ result = tool.runner({}, llm: nil, bot_user: nil).invoke
expected = [
[{ "fragment" => "44 45 46 47 48 49 50", "metadata" => nil }],
@@ -316,7 +321,7 @@ RSpec.describe AiTool do
# this part of the API is a bit awkward, maybe we should do it
# automatically
RagDocumentFragment.update_target_uploads(tool, [upload1.id, upload2.id])
- result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke
+ result = tool.runner({}, llm: nil, bot_user: nil).invoke
expected = [
[{ "fragment" => "48 49 50", "metadata" => nil }],
@@ -340,7 +345,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil)
result = runner.invoke
@@ -364,7 +369,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil)
result = runner.invoke
post_hash = result["post"]
@@ -393,7 +398,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
- runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {})
+ runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil)
result = runner.invoke
@@ -401,4 +406,158 @@ RSpec.describe AiTool do
expect(result["rows"].first["title"]).to eq(topic.title)
end
end
+
+ context "when using the chat API" do
+ before(:each) do
+ skip "Chat plugin tests skipped because Chat module is not defined." unless defined?(Chat)
+ SiteSetting.chat_enabled = true
+ end
+
+ fab!(:chat_user) { Fabricate(:user) }
+ fab!(:chat_channel) do
+ Fabricate(:chat_channel).tap do |channel|
+ Fabricate(
+ :user_chat_channel_membership,
+ user: chat_user,
+ chat_channel: channel,
+ following: true,
+ )
+ end
+ end
+
+ it "can create a chat message" do
+ script = <<~JS
+ function invoke(params) {
+ return discourse.createChatMessage({
+ channel_name: params.channel_name,
+ username: params.username,
+ message: params.message
+ });
+ }
+ JS
+
+ tool = create_tool(script: script)
+ runner =
+ tool.runner(
+ {
+ "channel_name" => chat_channel.name,
+ "username" => chat_user.username,
+ "message" => "Hello from the tool!",
+ },
+ llm: nil,
+ bot_user: bot_user, # bot_user (the user running the tool) does not determine the message sender
+ )
+
+ initial_message_count = Chat::Message.count
+ result = runner.invoke
+
+ expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}"
+ expect(result["message"]).to eq("Hello from the tool!")
+ expect(result["created_at"]).to be_present
+ expect(result).not_to have_key("error")
+
+ # Verify message was actually created in the database
+ expect(Chat::Message.count).to eq(initial_message_count + 1)
+ created_message = Chat::Message.find_by(id: result["message_id"])
+
+ expect(created_message).not_to be_nil
+ expect(created_message.message).to eq("Hello from the tool!")
+ expect(created_message.user_id).to eq(chat_user.id) # Message is sent AS the specified user
+ expect(created_message.chat_channel_id).to eq(chat_channel.id)
+ end
+
+ it "can create a chat message using channel slug" do
+ chat_channel.update!(name: "My Test Channel", slug: "my-test-channel")
+ expect(chat_channel.slug).to eq("my-test-channel")
+
+ script = <<~JS
+ function invoke(params) {
+ return discourse.createChatMessage({
+ channel_name: params.channel_slug, // channel_name also accepts a channel slug
+ username: params.username,
+ message: params.message
+ });
+ }
+ JS
+
+ tool = create_tool(script: script)
+ runner =
+ tool.runner(
+ {
+ "channel_slug" => chat_channel.slug,
+ "username" => chat_user.username,
+ "message" => "Hello via slug!",
+ },
+ llm: nil,
+ bot_user: bot_user,
+ )
+
+ result = runner.invoke
+
+ expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}"
+ # see: https://github.com/rubyjs/mini_racer/issues/348
+ # expect(result["message_id"]).to be_a(Integer)
+
+ created_message = Chat::Message.find_by(id: result["message_id"])
+ expect(created_message).not_to be_nil
+ expect(created_message.message).to eq("Hello via slug!")
+ expect(created_message.chat_channel_id).to eq(chat_channel.id)
+ end
+
+ it "returns an error if the channel is not found" do
+ script = <<~JS
+ function invoke(params) {
+ return discourse.createChatMessage({
+ channel_name: "non_existent_channel",
+ username: params.username,
+ message: params.message
+ });
+ }
+ JS
+
+ tool = create_tool(script: script)
+ runner =
+ tool.runner(
+ { "username" => chat_user.username, "message" => "Test" },
+ llm: nil,
+ bot_user: bot_user,
+ )
+
+ initial_message_count = Chat::Message.count
+ expect { runner.invoke }.to raise_error(
+ MiniRacer::RuntimeError,
+ /Channel not found: non_existent_channel/,
+ )
+
+ expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created
+ end
+
+ it "returns an error if the user is not found" do
+ script = <<~JS
+ function invoke(params) {
+ return discourse.createChatMessage({
+ channel_name: params.channel_name,
+ username: "non_existent_user",
+ message: params.message
+ });
+ }
+ JS
+
+ tool = create_tool(script: script)
+ runner =
+ tool.runner(
+ { "channel_name" => chat_channel.name, "message" => "Test" },
+ llm: nil,
+ bot_user: bot_user,
+ )
+
+ initial_message_count = Chat::Message.count
+ expect { runner.invoke }.to raise_error(
+ MiniRacer::RuntimeError,
+ /User not found: non_existent_user/,
+ )
+
+ expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created
+ end
+ end
end