FEATURE: flexible image handling within messages (#1214)

* DEV: refactor bot internals

This introduces a proper object for bot context, which makes
it simpler to improve context management as we go because we
have a well-defined object to work with.

It also starts the refactoring that allows a single message to
carry multiple uploads.

* transplant method to message builder

* chipping away at inline uploads

* image support is improved but not fully fixed yet

partially working in Anthropic; still quite a few dialects to go

* OpenAI and Claude are now working

* Gemini is now working as well

* fix nova

* more dialects...

* fix ollama

* fix specs

* update artifact fixed

* more tests

* spam scanner

* pass more specs

* bunch of specs improved

* more bug fixes.

* all the rest of the tests are working

* improve test coverage and ensure custom tools are aware of the new context object

* tests are working, but we need more tests

* resolve merge conflict

* new preamble and expanded specs on ai tool

* remove concept of "standalone tools"

This is no longer needed; we can set custom raw, and tool details are injected into tool calls.
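The calling convention this PR moves to can be sketched with a minimal, self-contained mimic. The classes below are simplified stand-ins, not the actual DiscourseAi implementations; only the shape of the API (a typed context object instead of a loose options Hash) is taken from the diff:

```ruby
# Simplified stand-ins for the pattern this PR introduces: bot.reply now
# takes a typed BotContext object instead of an options Hash.
class BotContext
  attr_accessor :messages, :user, :skip_tool_details

  def initialize(messages: [], user: nil, skip_tool_details: nil)
    @messages = messages
    @user = user
    @skip_tool_details = skip_tool_details
  end
end

class Bot
  def reply(context)
    # mirrors the type guard added to Bot#reply in this PR
    unless context.is_a?(BotContext)
      raise ArgumentError, "context must be an instance of BotContext"
    end
    # stand-in behavior: echo the user message contents
    context.messages.map { |m| m[:content] }.join(" ")
  end
end

ctx = BotContext.new(messages: [{ type: :user, content: "hello" }], skip_tool_details: true)
Bot.new.reply(ctx) # => "hello"
```

Passing the old-style Hash now fails fast with an `ArgumentError`, which is the point of the refactor: callers can no longer hand `reply` an arbitrary bag of keys.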
Sam 2025-04-01 02:39:07 +11:00 committed by GitHub
parent f3e78f0d80
commit 5b6d39a206
53 changed files with 1381 additions and 723 deletions


@@ -55,7 +55,7 @@ module DiscourseAi
      # we need an llm so we have a tokenizer
      # but will do without if none is available
      llm = LlmModel.first&.to_llm
-     runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user, context: {})
+     runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user)
      result = runner.invoke

      if result.is_a?(Hash) && result[:error]


@@ -30,9 +30,13 @@ module Jobs
      base = { query: query, model_used: llm_model.display_name }

-     bot.reply(
-       { conversation_context: [{ type: :user, content: query }], skip_tool_details: true },
-     ) do |partial|
+     context =
+       DiscourseAi::AiBot::BotContext.new(
+         messages: [{ type: :user, content: query }],
+         skip_tool_details: true,
+       )
+     bot.reply(context) do |partial|
        streamed_reply << partial

        # Throttle updates.


@@ -35,7 +35,7 @@ class AiTool < ActiveRecord::Base
    tool_name.presence || name
  end

- def runner(parameters, llm:, bot_user:, context: {})
+ def runner(parameters, llm:, bot_user:, context: nil)
    DiscourseAi::AiBot::ToolRunner.new(
      parameters: parameters,
      llm: llm,
@@ -59,86 +59,166 @@ class AiTool < ActiveRecord::Base
  def self.preamble
    <<~JS
-     /**
-      * Tool API Quick Reference
-      *
-      * Entry Functions
-      *
-      * invoke(parameters): Main function. Receives parameters (Object). Must return a JSON-serializable value.
-      *   Example:
-      *     function invoke(parameters) { return "result"; }
-      *
-      * details(): Optional. Returns a string describing the tool.
-      *   Example:
-      *     function details() { return "Tool description."; }
-      *
-      * Provided Objects
-      *
-      * 1. http
-      *   http.get(url, options?): Performs an HTTP GET request.
-      *     Parameters:
-      *       url (string): The request URL.
-      *       options (Object, optional):
-      *         headers (Object): Request headers.
-      *     Returns:
-      *       { status: number, body: string }
-      *
-      *   http.post(url, options?): Performs an HTTP POST request.
-      *     Parameters:
-      *       url (string): The request URL.
-      *       options (Object, optional):
-      *         headers (Object): Request headers.
-      *         body (string): Request body.
-      *     Returns:
-      *       { status: number, body: string }
-      *
-      *   (also available: http.put, http.patch, http.delete)
-      *
-      *   Note: Max 20 HTTP requests per execution.
-      *
-      * 2. llm
-      *   llm.truncate(text, length): Truncates text to a specified token length.
-      *     Parameters:
-      *       text (string): Text to truncate.
-      *       length (number): Max tokens.
-      *     Returns:
-      *       Truncated string.
-      *
-      * 3. index
-      *   index.search(query, options?): Searches indexed documents.
-      *     Parameters:
-      *       query (string): Search query.
-      *       options (Object, optional):
-      *         filenames (Array): Limit search to specific files.
-      *         limit (number): Max fragments (up to 200).
-      *     Returns:
-      *       Array of { fragment: string, metadata: string }
-      *
-      * 4. upload
-      *   upload.create(filename, base_64_content): Uploads a file.
-      *     Parameters:
-      *       filename (string): Name of the file.
-      *       base_64_content (string): Base64 encoded file content.
-      *     Returns:
-      *       { id: number, short_url: string }
-      *
-      * 5. chain
-      *   chain.setCustomRaw(raw): Sets the body of the post and exits the chain.
-      *     Parameters:
-      *       raw (string): raw content to add to post.
-      *
-      * Constraints
-      *
-      * Execution Time: 2000ms
-      * Memory: 10MB
-      * HTTP Requests: 20 per execution
-      * Exceeding limits will result in errors or termination.
-      *
-      * Security
-      *
-      * Sandboxed Environment: No access to system or global objects.
-      * No File System Access: Cannot read or write files.
-      */
+     /**
+      * Tool API Quick Reference
+      *
+      * Entry Functions
+      *
+      * invoke(parameters): Main function. Receives parameters defined in the tool's signature (Object).
+      *   Must return a JSON-serializable value (e.g., string, number, object, array).
+      *   Example:
+      *     function invoke(parameters) { return { result: "Data processed", input: parameters.query }; }
+      *
+      * details(): Optional function. Returns a string (can include basic HTML) describing
+      *   the tool's action after invocation, often using data from the invocation.
+      *   This is displayed in the chat interface.
+      *   Example:
+      *     let lastUrl;
+      *     function invoke(parameters) {
+      *       lastUrl = parameters.url;
+      *       // ... perform action ...
+      *       return { success: true, content: "..." };
+      *     }
+      *     function details() {
+      *       return `Browsed: <a href="${lastUrl}">${lastUrl}</a>`;
+      *     }
+      *
+      * Provided Objects & Functions
+      *
+      * 1. http
+      *    Performs HTTP requests. Max 20 requests per execution.
+      *
+      *    http.get(url, options?): Performs GET request.
+      *      Parameters:
+      *        url (string): The request URL.
+      *        options (Object, optional):
+      *          headers (Object): Request headers (e.g., { "Authorization": "Bearer key" }).
+      *      Returns: { status: number, body: string }
+      *
+      *    http.post(url, options?): Performs POST request.
+      *      Parameters:
+      *        url (string): The request URL.
+      *        options (Object, optional):
+      *          headers (Object): Request headers.
+      *          body (string | Object): Request body. If an object, it's stringified as JSON.
+      *      Returns: { status: number, body: string }
+      *
+      *    http.put(url, options?): Performs PUT request (similar to POST).
+      *    http.patch(url, options?): Performs PATCH request (similar to POST).
+      *    http.delete(url, options?): Performs DELETE request (similar to GET/POST).
+      *
+      * 2. llm
+      *    Interacts with the Language Model.
+      *
+      *    llm.truncate(text, length): Truncates text to a specified token length based on the configured LLM's tokenizer.
+      *      Parameters:
+      *        text (string): Text to truncate.
+      *        length (number): Maximum number of tokens.
+      *      Returns: string (truncated text)
+      *
+      *    llm.generate(prompt): Generates text using the configured LLM associated with the tool runner.
+      *      Parameters:
+      *        prompt (string | Object): The prompt. Can be a simple string or an object
+      *          like { messages: [{ type: "system", content: "..." }, { type: "user", content: "..." }] }.
+      *      Returns: string (generated text)
+      *
+      * 3. index
+      *    Searches attached RAG (Retrieval-Augmented Generation) documents linked to this tool.
+      *
+      *    index.search(query, options?): Searches indexed document fragments.
+      *      Parameters:
+      *        query (string): The search query used for semantic search.
+      *        options (Object, optional):
+      *          filenames (Array<string>): Filter search to fragments from specific uploaded filenames.
+      *          limit (number): Maximum number of fragments to return (default: 10, max: 200).
+      *      Returns: Array<{ fragment: string, metadata: string | null }> - Ordered by relevance.
+      *
+      * 4. upload
+      *    Handles file uploads within Discourse.
+      *
+      *    upload.create(filename, base_64_content): Uploads a file created by the tool, making it available in Discourse.
+      *      Parameters:
+      *        filename (string): The desired name for the file (basename is used for security).
+      *        base_64_content (string): Base64 encoded content of the file.
+      *      Returns: { id: number, url: string, short_url: string } - Details of the created upload record.
+      *
+      * 5. chain
+      *    Controls the execution flow.
+      *
+      *    chain.setCustomRaw(raw): Sets the final raw content of the bot's post and immediately
+      *      stops the tool execution chain. Useful for tools that directly
+      *      generate the full response content (e.g., image generation tools attaching the image markdown).
+      *      Parameters:
+      *        raw (string): The raw Markdown content for the post.
+      *      Returns: void
+      *
+      * 6. discourse
+      *    Interacts with Discourse specific features. Access is generally performed as the SystemUser.
+      *
+      *    discourse.search(params): Performs a Discourse search.
+      *      Parameters:
+      *        params (Object): Search parameters (e.g., { search_query: "keyword", with_private: true, max_results: 10 }).
+      *          `with_private: true` searches across all posts visible to the SystemUser. `result_style: 'detailed'` is used by default.
+      *      Returns: Object (Discourse search results structure, includes posts, topics, users etc.)
+      *
+      *    discourse.getPost(post_id): Retrieves details for a specific post.
+      *      Parameters:
+      *        post_id (number): The ID of the post.
+      *      Returns: Object (Post details including `raw`, nested `topic` object with ListableTopicSerializer structure) or null if not found/accessible.
+      *
+      *    discourse.getTopic(topic_id): Retrieves details for a specific topic.
+      *      Parameters:
+      *        topic_id (number): The ID of the topic.
+      *      Returns: Object (Topic details using ListableTopicSerializer structure) or null if not found/accessible.
+      *
+      *    discourse.getUser(user_id_or_username): Retrieves details for a specific user.
+      *      Parameters:
+      *        user_id_or_username (number | string): The ID or username of the user.
+      *      Returns: Object (User details using UserSerializer structure) or null if not found.
+      *
+      *    discourse.getPersona(name): Gets an object representing another AI Persona configured on the site.
+      *      Parameters:
+      *        name (string): The name of the target persona.
+      *      Returns: Object { respondTo: function(params) } or null if persona not found.
+      *        respondTo(params): Instructs the target persona to generate a response within the current context (e.g., replying to the same post or chat message).
+      *          Parameters:
+      *            params (Object, optional): { instructions: string, whisper: boolean }
+      *          Returns: { success: boolean, post_id?: number, post_number?: number, message_id?: number } or { error: string }
+      *
+      *    discourse.createChatMessage(params): Creates a new message in a Discourse Chat channel.
+      *      Parameters:
+      *        params (Object): { channel_name: string, username: string, message: string }
+      *          `channel_name` can be the channel name or slug.
+      *          `username` specifies the user who should appear as the sender. The user must exist.
+      *          The sending user must have permission to post in the channel.
+      *      Returns: { success: boolean, message_id?: number, message?: string, created_at?: string } or { error: string }
+      *
+      * 7. context
+      *    An object containing information about the environment where the tool is being run.
+      *    Available properties depend on the invocation context, but may include:
+      *      post_id (number): ID of the post triggering the tool (if in a Post context).
+      *      topic_id (number): ID of the topic (if in a Post context).
+      *      private_message (boolean): Whether the context is a private message (in Post context).
+      *      message_id (number): ID of the chat message triggering the tool (if in Chat context).
+      *      channel_id (number): ID of the chat channel (if in Chat context).
+      *      user (Object): Details of the user invoking the tool/persona (structure may vary, often null or SystemUser details unless explicitly passed).
+      *      participants (string): Comma-separated list of usernames in a PM (if applicable).
+      *      // ... other potential context-specific properties added by the calling environment.
+      *
+      * Constraints
+      *
+      * Execution Time: 2000ms (default timeout in milliseconds) - This timer *pauses* during external HTTP requests or LLM calls initiated via `http.*` or `llm.generate`, but applies to the script's own processing time.
+      * Memory: 10MB (V8 heap limit)
+      * Stack Depth: 20 (Marshal stack depth limit for Ruby interop)
+      * HTTP Requests: 20 per execution
+      * Exceeding limits will result in errors or termination (e.g., timeout error, out-of-memory error, TooManyRequestsError).
+      *
+      * Security
+      *
+      * Sandboxed Environment: The script runs in a restricted V8 JavaScript environment (via MiniRacer).
+      * No direct access to the browser environment, browser globals (like `window` or `document`), or the host system's file system.
+      * Network requests are proxied through the Discourse backend, not made directly from the sandbox.
+      */
    JS
  end


@@ -75,9 +75,9 @@ module DiscourseAi
      def force_tool_if_needed(prompt, context)
        return if prompt.tool_choice == :none

-       context[:chosen_tools] ||= []
+       context.chosen_tools ||= []
        forced_tools = persona.force_tool_use.map { |tool| tool.name }
-       force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
+       force_tool = forced_tools.find { |name| !context.chosen_tools.include?(name) }

        if force_tool && persona.forced_tool_count > 0
          user_turns = prompt.messages.select { |m| m[:type] == :user }.length
@@ -85,7 +85,7 @@ module DiscourseAi
        end

        if force_tool
-         context[:chosen_tools] << force_tool
+         context.chosen_tools << force_tool
          prompt.tool_choice = force_tool
        else
          prompt.tool_choice = nil
@@ -93,6 +93,9 @@ module DiscourseAi
      end

      def reply(context, &update_blk)
+       unless context.is_a?(BotContext)
+         raise ArgumentError, "context must be an instance of BotContext"
+       end
        llm = DiscourseAi::Completions::Llm.proxy(model)
        prompt = persona.craft_prompt(context, llm: llm)
@@ -100,7 +103,7 @@ module DiscourseAi
        ongoing_chain = true
        raw_context = []

-       user = context[:user]
+       user = context.user

        llm_kwargs = { user: user }
        llm_kwargs[:temperature] = persona.temperature if persona.temperature
@@ -277,27 +280,15 @@ module DiscourseAi
            name: tool.name,
          }

-         if tool.standalone?
-           standalone_context =
-             context.dup.merge(
-               conversation_context: [
-                 context[:conversation_context].last,
-                 tool_call_message,
-                 tool_message,
-               ],
-             )
-           prompt = persona.craft_prompt(standalone_context)
-         else
-           prompt.push(**tool_call_message)
-           prompt.push(**tool_message)
-         end
+         prompt.push(**tool_call_message)
+         prompt.push(**tool_message)

          raw_context << [tool_call_message[:content], tool_call_id, "tool_call", tool.name]
          raw_context << [invocation_result_json, tool_call_id, "tool", tool.name]
        end

        def invoke_tool(tool, llm, cancel, context, &update_blk)
-         show_placeholder = !context[:skip_tool_details] && !tool.class.allow_partial_tool_calls?
+         show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls?

          update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder

lib/ai_bot/bot_context.rb (new file, 107 lines)

@@ -0,0 +1,107 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
class BotContext
attr_accessor :messages,
:topic_id,
:post_id,
:private_message,
:custom_instructions,
:user,
:skip_tool_details,
:participants,
:chosen_tools,
:message_id,
:channel_id,
:context_post_ids
def initialize(
post: nil,
participants: nil,
user: nil,
skip_tool_details: nil,
messages: [],
custom_instructions: nil,
site_url: nil,
site_title: nil,
site_description: nil,
time: nil,
message_id: nil,
channel_id: nil,
context_post_ids: nil
)
@participants = participants
@user = user
@skip_tool_details = skip_tool_details
@messages = messages
@custom_instructions = custom_instructions
@message_id = message_id
@channel_id = channel_id
@context_post_ids = context_post_ids
@site_url = site_url
@site_title = site_title
@site_description = site_description
@time = time
if post
@post_id = post.id
@topic_id = post.topic_id
@private_message = post.topic.private_message?
@participants ||= post.topic.allowed_users.map(&:username).join(", ") if @private_message
@user = post.user
end
end
# these are strings that can be safely interpolated into templates
TEMPLATE_PARAMS = %w[time site_url site_title site_description participants]
def lookup_template_param(key)
public_send(key.to_sym) if TEMPLATE_PARAMS.include?(key)
end
def time
@time ||= Time.zone.now
end
def site_url
@site_url ||= Discourse.base_url
end
def site_title
@site_title ||= SiteSetting.title
end
def site_description
@site_description ||= SiteSetting.site_description
end
def private_message?
@private_message
end
def to_json
{
messages: @messages,
topic_id: @topic_id,
post_id: @post_id,
private_message: @private_message,
custom_instructions: @custom_instructions,
username: @user&.username,
user_id: @user&.id,
participants: @participants,
chosen_tools: @chosen_tools,
message_id: @message_id,
channel_id: @channel_id,
context_post_ids: @context_post_ids,
site_url: @site_url,
site_title: @site_title,
site_description: @site_description,
skip_tool_details: @skip_tool_details,
}
end
end
end
end
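The `TEMPLATE_PARAMS` whitelist above is the interesting design choice: persona templates can only interpolate a fixed set of safe string attributes, never arbitrary context state. A self-contained sketch of that behavior (`MiniBotContext` is a simplified stand-in; the real class pulls defaults from `Discourse` and `SiteSetting`):

```ruby
# Minimal mimic of BotContext#lookup_template_param: only whitelisted keys
# are resolvable, so a persona template cannot read arbitrary attributes.
class MiniBotContext
  TEMPLATE_PARAMS = %w[time site_title participants]

  attr_accessor :participants

  def initialize(participants: nil, site_title: nil, time: nil)
    @participants = participants
    @site_title = site_title
    @time = time
  end

  # lazily defaulted, mirroring the memoized getters above
  def time
    @time ||= Time.now
  end

  def site_title
    @site_title ||= "My Forum" # stand-in for SiteSetting.title
  end

  def lookup_template_param(key)
    public_send(key.to_sym) if TEMPLATE_PARAMS.include?(key)
  end
end

ctx = MiniBotContext.new(site_title: "Meta", participants: "sam, jane")
ctx.lookup_template_param("site_title") # => "Meta"
ctx.lookup_template_param("messages")   # => nil (not whitelisted)
```

Compared with the old Hash-based context, where `context[anything]` resolved, this makes the set of template-visible values explicit.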


@@ -163,7 +163,7 @@ module DiscourseAi
      def craft_prompt(context, llm: nil)
        system_insts =
          system_prompt.gsub(/\{(\w+)\}/) do |match|
-           found = context[match[1..-2].to_sym]
+           found = context.lookup_template_param(match[1..-2])
            found.nil? ? match : found.to_s
          end
@@ -180,16 +180,16 @@ module DiscourseAi
          )
        end

-       if context[:custom_instructions].present?
+       if context.custom_instructions.present?
          prompt_insts << "\n"
-         prompt_insts << context[:custom_instructions]
+         prompt_insts << context.custom_instructions
        end

        fragments_guidance =
          rag_fragments_prompt(
-           context[:conversation_context].to_a,
+           context.messages,
            llm: question_consolidator_llm,
-           user: context[:user],
+           user: context.user,
          )&.strip

        prompt_insts << fragments_guidance if fragments_guidance.present?
@@ -197,9 +197,9 @@ module DiscourseAi
        prompt =
          DiscourseAi::Completions::Prompt.new(
            prompt_insts,
-           messages: context[:conversation_context].to_a,
-           topic_id: context[:topic_id],
-           post_id: context[:post_id],
+           messages: context.messages,
+           topic_id: context.topic_id,
+           post_id: context.post_id,
          )

        prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled
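The `{param}` substitution in `craft_prompt` can be exercised standalone. This is a runnable sketch of the same `gsub` logic; `TemplateContext` is a hypothetical stand-in exposing `lookup_template_param` like the new `BotContext`:

```ruby
# Stand-in context exposing lookup_template_param like BotContext does.
class TemplateContext
  def lookup_template_param(key)
    { "site_title" => "Meta", "time" => "2025-04-01" }[key]
  end
end

context = TemplateContext.new
prompt = "You are a bot on {site_title}. Unknown: {foo}. Time: {time}."

resolved =
  prompt.gsub(/\{(\w+)\}/) do |match|
    # match is the whole "{param}"; match[1..-2] strips the braces
    found = context.lookup_template_param(match[1..-2])
    found.nil? ? match : found.to_s # unknown params are left in place verbatim
  end
# resolved == "You are a bot on Meta. Unknown: {foo}. Time: 2025-04-01."
```

Leaving unknown placeholders untouched (rather than blanking them) means a typo in a persona template stays visible instead of silently disappearing.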


@@ -227,90 +227,17 @@ module DiscourseAi
        schedule_bot_reply(post) if can_attach?(post)
      end

-     def conversation_context(post, style: nil)
-       # Pay attention to the `post_number <= ?` here.
-       # We want to inject the last post as context because they are translated differently.
-       # also setting default to 40, allowing huge contexts costs lots of tokens
-       max_posts = 40
-       if bot.persona.class.respond_to?(:max_context_posts)
-         max_posts = bot.persona.class.max_context_posts || 40
-       end
-
-       post_types = [Post.types[:regular]]
-       post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper]
-
-       context =
-         post
-           .topic
-           .posts
-           .joins(:user)
-           .joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
-           .where("post_number <= ?", post.post_number)
-           .order("post_number desc")
-           .where("post_type in (?)", post_types)
-           .limit(max_posts)
-           .pluck(
-             "posts.raw",
-             "users.username",
-             "post_custom_prompts.custom_prompt",
-             "(
-               SELECT array_agg(ref.upload_id)
-               FROM upload_references ref
-               WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
-             ) as upload_ids",
-           )
-
-       builder = DiscourseAi::Completions::PromptMessagesBuilder.new
-       builder.topic = post.topic
-
-       context.reverse_each do |raw, username, custom_prompt, upload_ids|
-         custom_prompt_translation =
-           Proc.new do |message|
-             # We can't keep backwards-compatibility for stored functions.
-             # Tool syntax requires a tool_call_id which we don't have.
-             if message[2] != "function"
-               custom_context = {
-                 content: message[0],
-                 type: message[2].present? ? message[2].to_sym : :model,
-               }
-
-               custom_context[:id] = message[1] if custom_context[:type] != :model
-               custom_context[:name] = message[3] if message[3]
-
-               thinking = message[4]
-               custom_context[:thinking] = thinking if thinking
-
-               builder.push(**custom_context)
-             end
-           end
-
-         if custom_prompt.present?
-           custom_prompt.each(&custom_prompt_translation)
-         else
-           context = {
-             content: raw,
-             type: (available_bot_usernames.include?(username) ? :model : :user),
-           }
-
-           context[:id] = username if context[:type] == :user
-
-           if upload_ids.present? && context[:type] == :user && bot.persona.class.vision_enabled
-             context[:upload_ids] = upload_ids.compact
-           end
-
-           builder.push(**context)
-         end
-       end
-
-       builder.to_a(style: style || (post.topic.private_message? ? :bot : :topic))
-     end
-
      def title_playground(post, user)
-       context = conversation_context(post)
+       messages =
+         DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post(
+           post,
+           max_posts: 5,
+           bot_usernames: available_bot_usernames,
+           include_uploads: bot.persona.class.vision_enabled,
+         )

        bot
-         .get_updated_title(context, post, user)
+         .get_updated_title(messages, post, user)
          .tap do |new_title|
            PostRevisor.new(post.topic.first_post, post.topic).revise!(
              bot.bot_user,
@@ -326,83 +253,6 @@ module DiscourseAi
        )
      end

-     def chat_context(message, channel, persona_user, context_post_ids)
-       has_vision = bot.persona.class.vision_enabled
-       include_thread_titles = !channel.direct_message_channel? && !message.thread_id
-
-       current_id = message.id
-       if !channel.direct_message_channel?
-         # we are interacting via mentions ... strip mention
-         instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
-       end
-
-       messages = nil
-
-       max_messages = 40
-       if bot.persona.class.respond_to?(:max_context_posts)
-         max_messages = bot.persona.class.max_context_posts || 40
-       end
-
-       if !message.thread_id && channel.direct_message_channel?
-         messages = [message]
-       elsif !channel.direct_message_channel? && !message.thread_id
-         messages =
-           Chat::Message
-             .joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
-             .where(chat_channel_id: channel.id)
-             .where(
-               "chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
-             )
-             .order(id: :desc)
-             .limit(max_messages)
-             .to_a
-             .reverse
-       end
-
-       messages ||=
-         ChatSDK::Thread.last_messages(
-           thread_id: message.thread_id,
-           guardian: Discourse.system_user.guardian,
-           page_size: max_messages,
-         )
-
-       builder = DiscourseAi::Completions::PromptMessagesBuilder.new
-
-       guardian = Guardian.new(message.user)
-       if context_post_ids
-         builder.set_chat_context_posts(context_post_ids, guardian, include_uploads: has_vision)
-       end
-
-       messages.each do |m|
-         # restore stripped message
-         m.message = instruction_message if m.id == current_id && instruction_message
-
-         if available_bot_user_ids.include?(m.user_id)
-           builder.push(type: :model, content: m.message)
-         else
-           upload_ids = nil
-           upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present?
-           mapped_message = m.message
-
-           thread_title = nil
-           thread_title = m.thread&.title if include_thread_titles && m.thread_id
-           mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
-           builder.push(
-             type: :user,
-             content: mapped_message,
-             name: m.user.username,
-             upload_ids: upload_ids,
-           )
-         end
-       end
-
-       builder.to_a(
-         limit: max_messages,
-         style: channel.direct_message_channel? ? :chat_with_context : :chat,
-       )
-     end
      def reply_to_chat_message(message, channel, context_post_ids)
        persona_user = User.find(bot.persona.class.user_id)

@@ -410,10 +260,32 @@ module DiscourseAi
        context_post_ids = nil if !channel.direct_message_channel?

+       max_chat_messages = 40
+       if bot.persona.class.respond_to?(:max_context_posts)
+         max_chat_messages = bot.persona.class.max_context_posts || 40
+       end
+
+       if !channel.direct_message_channel?
+         # we are interacting via mentions ... strip mention
+         instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
+       end
+
        context =
-         get_context(
-           participants: participants.join(", "),
-           conversation_context: chat_context(message, channel, persona_user, context_post_ids),
+         BotContext.new(
+           participants: participants,
+           message_id: message.id,
+           channel_id: channel.id,
+           context_post_ids: context_post_ids,
+           messages:
+             DiscourseAi::Completions::PromptMessagesBuilder.messages_from_chat(
+               message,
+               channel: channel,
+               context_post_ids: context_post_ids,
+               include_uploads: bot.persona.class.vision_enabled,
+               max_messages: max_chat_messages,
+               bot_user_ids: available_bot_user_ids,
+               instruction_message: instruction_message,
+             ),
           user: message.user,
           skip_tool_details: true,
         )
@@ -460,22 +332,6 @@ module DiscourseAi
          reply
        end

-     def get_context(participants:, conversation_context:, user:, skip_tool_details: nil)
-       result = {
-         site_url: Discourse.base_url,
-         site_title: SiteSetting.title,
-         site_description: SiteSetting.site_description,
-         time: Time.zone.now,
-         participants: participants,
-         conversation_context: conversation_context,
-         user: user,
-       }
-
-       result[:skip_tool_details] = true if skip_tool_details
-
-       result
-     end
      def reply_to(
        post,
        custom_instructions: nil,

@@ -509,16 +365,25 @@ module DiscourseAi
          end
        )

+       # safeguard
+       max_context_posts = 40
+       if bot.persona.class.respond_to?(:max_context_posts)
+         max_context_posts = bot.persona.class.max_context_posts || 40
+       end
+
        context =
-         get_context(
-           participants: post.topic.allowed_users.map(&:username).join(", "),
-           conversation_context: conversation_context(post, style: context_style),
-           user: post.user,
+         BotContext.new(
+           post: post,
+           custom_instructions: custom_instructions,
+           messages:
+             DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post(
+               post,
+               style: context_style,
+               max_posts: max_context_posts,
+               include_uploads: bot.persona.class.vision_enabled,
+               bot_usernames: available_bot_usernames,
+             ),
         )

-       context[:post_id] = post.id
-       context[:topic_id] = post.topic_id
-       context[:private_message] = post.topic.private_message?
-       context[:custom_instructions] = custom_instructions
-
        reply_user = bot.bot_user
        if bot.persona.class.respond_to?(:user_id)

@@ -562,7 +427,7 @@ module DiscourseAi
          Discourse.redis.setex(redis_stream_key, 60, 1)
        end

-       context[:skip_tool_details] ||= !bot.persona.class.tool_details
+       context.skip_tool_details ||= !bot.persona.class.tool_details

        post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply


@@ -13,7 +13,13 @@ module DiscourseAi
      MARSHAL_STACK_DEPTH = 20
      MAX_HTTP_REQUESTS = 20

-     def initialize(parameters:, llm:, bot_user:, context: {}, tool:, timeout: nil)
+     def initialize(parameters:, llm:, bot_user:, context: nil, tool:, timeout: nil)
+       if context && !context.is_a?(DiscourseAi::AiBot::BotContext)
+         raise ArgumentError, "context must be a BotContext object"
+       end
+
+       context ||= DiscourseAi::AiBot::BotContext.new
+
        @parameters = parameters
        @llm = llm
        @bot_user = bot_user
@@ -99,7 +105,7 @@ module DiscourseAi
          },
        };

-       const context = #{JSON.generate(@context)};
+       const context = #{JSON.generate(@context.to_json)};
        function details() { return ""; };
      JS
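Note that `BotContext#to_json` (despite the name) returns a Hash, which `JSON.generate` then serializes into the object literal embedded in the sandboxed JS source. A self-contained sketch of that serialization step, with a hypothetical context hash standing in for the real `@context.to_json` value:

```ruby
require "json"

# A stand-in for what BotContext#to_json returns: a plain Hash.
context_hash = { post_id: 42, username: "sam", messages: [] }

# JSON.generate turns symbol keys into string keys and emits a compact
# literal, which becomes a `const` in the tool's sandboxed JS source.
js_source = "const context = #{JSON.generate(context_hash)};"
# js_source == 'const context = {"post_id":42,"username":"sam","messages":[]};'
```

This keeps the sandbox boundary clean: the tool script only ever sees a JSON snapshot of the context, never live Ruby objects.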
@@ -240,13 +246,13 @@ module DiscourseAi
      def llm_user
        @llm_user ||=
          begin
-           @context[:llm_user] || post&.user || @bot_user
+           post&.user || @bot_user
          end
      end

      def post
        return @post if defined?(@post)
-       post_id = @context[:post_id]
+       post_id = @context.post_id
        @post = post_id && Post.find_by(id: post_id)
      end
@@ -336,8 +342,8 @@ module DiscourseAi
        bot = DiscourseAi::AiBot::Bot.as(@bot_user || persona.user, persona: persona)
        playground = DiscourseAi::AiBot::Playground.new(bot)

-       if @context[:post_id]
-         post = Post.find_by(id: @context[:post_id])
+       if @context.post_id
+         post = Post.find_by(id: @context.post_id)
          return { error: "Post not found" } if post.nil?

          reply_post =

@@ -354,13 +360,13 @@ module DiscourseAi
        else
          return { error: "Failed to create reply" }
        end
-       elsif @context[:message_id] && @context[:channel_id]
-         message = Chat::Message.find_by(id: @context[:message_id])
-         channel = Chat::Channel.find_by(id: @context[:channel_id])
+       elsif @context.message_id && @context.channel_id
+         message = Chat::Message.find_by(id: @context.message_id)
+         channel = Chat::Channel.find_by(id: @context.channel_id)
          return { error: "Message or channel not found" } if message.nil? || channel.nil?

          reply =
-           playground.reply_to_chat_message(message, channel, @context[:context_post_ids])
+           playground.reply_to_chat_message(message, channel, @context.context_post_ids)

          if reply
            return { success: true, message_id: reply.id }
@@ -457,7 +463,7 @@ module DiscourseAi
          UploadCreator.new(
            file,
            filename,
-           for_private_message: @context[:private_message],
+           for_private_message: @context.private_message,
          ).create_for(@bot_user.id)

        { id: upload.id, short_url: upload.short_url, url: upload.url }


@@ -108,7 +108,7 @@ module DiscourseAi
      end

      def invoke
-       post = Post.find_by(id: context[:post_id])
+       post = Post.find_by(id: context.post_id)
        return error_response("No post context found") unless post

        partial_response = +""


@@ -111,7 +111,7 @@ module DiscourseAi
          UploadCreator.new(
            file,
            "image.png",
-           for_private_message: context[:private_message],
+           for_private_message: context.private_message?,
          ).create_for(bot_user.id),
        }
      end


@@ -131,7 +131,7 @@ module DiscourseAi
          UploadCreator.new(
            file,
            "image.png",
-           for_private_message: context[:private_message],
+           for_private_message: context.private_message,
          ).create_for(bot_user.id),
        seed: image[:seed],
      }


@@ -48,7 +48,7 @@ module DiscourseAi
      def invoke
        not_found = { topic_id: topic_id, description: "Topic not found" }

-       guardian = Guardian.new(context[:user]) if options[:read_private] && context[:user]
+       guardian = Guardian.new(context.user) if options[:read_private] && context.user
        guardian ||= Guardian.new

        @title = ""


@@ -59,7 +59,7 @@ module DiscourseAi
end end
def post def post
@post ||= Post.find_by(id: context[:post_id]) @post ||= Post.find_by(id: context.post_id)
end end
def handle_discourse_artifact(uri) def handle_discourse_artifact(uri)


@@ -130,7 +130,7 @@ module DiscourseAi
after: parameters[:after], after: parameters[:after],
status: parameters[:status], status: parameters[:status],
max_results: max_results, max_results: max_results,
current_user: options[:search_private] ? context[:user] : nil, current_user: options[:search_private] ? context.user : nil,
) )
@last_num_results = results[:rows]&.length || 0 @last_num_results = results[:rows]&.length || 0


@@ -40,10 +40,6 @@ module DiscourseAi
false false
end end
def standalone?
true
end
def custom_raw def custom_raw
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found") @last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
end end


@@ -56,14 +56,17 @@ module DiscourseAi
persona_options: {}, persona_options: {},
bot_user:, bot_user:,
llm:, llm:,
context: {} context: nil
) )
@parameters = parameters @parameters = parameters
@tool_call_id = tool_call_id @tool_call_id = tool_call_id
@persona_options = persona_options @persona_options = persona_options
@bot_user = bot_user @bot_user = bot_user
@llm = llm @llm = llm
@context = context @context = context.nil? ? DiscourseAi::AiBot::BotContext.new(messages: []) : context
if !@context.is_a?(DiscourseAi::AiBot::BotContext)
raise ArgumentError, "context must be a DiscourseAi::AiBot::BotContext"
end
end end
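Since the runner now requires a `BotContext`, a minimal standalone sketch of the new guard may help; the `BotContext` stub and the `normalize_context` name here are illustrative, not the plugin's actual layout:

```ruby
# Stub standing in for DiscourseAi::AiBot::BotContext (assumption: it accepts messages:)
class BotContext
  def initialize(messages: [])
    @messages = messages
  end
end

# Mirrors the guard in the tool runner's initializer: default when nil, fail fast otherwise
def normalize_context(context)
  context = BotContext.new(messages: []) if context.nil?
  raise ArgumentError, "context must be a BotContext" if !context.is_a?(BotContext)
  context
end
```

Defaulting keeps older call sites that passed no context working, while the old hash-style contexts now raise immediately instead of failing deep inside a tool.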
def name def name
@@ -108,10 +111,6 @@ module DiscourseAi
true true
end end
def standalone?
false
end
protected protected
def fetch_default_branch(repo) def fetch_default_branch(repo)


@@ -39,7 +39,7 @@ module DiscourseAi
def self.inject_prompt(prompt:, context:, persona:) def self.inject_prompt(prompt:, context:, persona:)
return if persona.options["do_not_echo_artifact"].to_s == "true" return if persona.options["do_not_echo_artifact"].to_s == "true"
# we inject the current artifact content into the last user message # we inject the current artifact content into the last user message
if topic_id = context[:topic_id] if topic_id = context.topic_id
posts = Post.where(topic_id: topic_id) posts = Post.where(topic_id: topic_id)
artifact = AiArtifact.order("id desc").where(post: posts).first artifact = AiArtifact.order("id desc").where(post: posts).first
if artifact if artifact
@@ -113,7 +113,7 @@ module DiscourseAi
end end
def invoke def invoke
post = Post.find_by(id: context[:post_id]) post = Post.find_by(id: context.post_id)
return error_response("No post context found") unless post return error_response("No post context found") unless post
artifact = AiArtifact.find_by(id: parameters[:artifact_id]) artifact = AiArtifact.find_by(id: parameters[:artifact_id])


@@ -188,9 +188,10 @@ module DiscourseAi
messages: [ messages: [
{ {
type: :user, type: :user,
content: content: [
"Describe this image in a single sentence#{custom_locale_instructions(user)}", "Describe this image in a single sentence#{custom_locale_instructions(user)}",
upload_ids: [upload.id], { upload_id: upload.id },
],
}, },
], ],
) )


@@ -187,7 +187,10 @@ module DiscourseAi
prompt = DiscourseAi::Completions::Prompt.new(system_prompt) prompt = DiscourseAi::Completions::Prompt.new(system_prompt)
args = { type: :user, content: context } args = { type: :user, content: context }
upload_ids = post.upload_ids upload_ids = post.upload_ids
args[:upload_ids] = upload_ids.take(3) if upload_ids.present? if upload_ids.present?
args[:content] = [args[:content]]
upload_ids.take(3).each { |upload_id| args[:content] << { upload_id: upload_id } }
end
prompt.push(**args) prompt.push(**args)
prompt prompt
end end
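The change above folds post uploads into the message content itself instead of a separate `upload_ids` key. A self-contained sketch of that promotion (values hypothetical):

```ruby
content = "Describe the attached screenshots"
upload_ids = [11, 12, 13, 14]

# Wrap the string in an array and append up to three {upload_id: ...} hashes,
# matching what the prompt builder now does
if !upload_ids.empty?
  content = [content]
  upload_ids.take(3).each { |upload_id| content << { upload_id: upload_id } }
end
```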


@@ -7,11 +7,7 @@ module DiscourseAi
return if !tool return if !tool
return if !tool.parameters.blank? return if !tool.parameters.blank?
context = { context = DiscourseAi::AiBot::BotContext.new(post: post)
post_id: post.id,
automation_id: automation&.id,
automation_name: automation&.name,
}
runner = tool.runner({}, llm: nil, bot_user: Discourse.system_user, context: context) runner = tool.runner({}, llm: nil, bot_user: Discourse.system_user, context: context)
runner.invoke runner.invoke


@@ -42,7 +42,12 @@ module DiscourseAi
content = llm.tokenizer.truncate(content, max_post_tokens) if max_post_tokens.present? content = llm.tokenizer.truncate(content, max_post_tokens) if max_post_tokens.present?
prompt.push(type: :user, content: content, upload_ids: post.upload_ids) if post.upload_ids.present?
content = [content]
content.concat(post.upload_ids.map { |upload_id| { upload_id: upload_id } })
end
prompt.push(type: :user, content: content)
result = nil result = nil


@@ -17,13 +17,13 @@ module DiscourseAi
llm_model.provider == "open_ai" || llm_model.provider == "azure" llm_model.provider == "open_ai" || llm_model.provider == "azure"
end end
def translate def embed_user_ids?
return @embed_user_ids if defined?(@embed_user_ids)
@embed_user_ids = @embed_user_ids =
prompt.messages.any? do |m| prompt.messages.any? do |m|
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX) m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
end end
super
end end
def max_prompt_tokens def max_prompt_tokens
@@ -102,35 +102,47 @@ module DiscourseAi
end end
def user_msg(msg) def user_msg(msg)
user_message = { role: "user", content: msg[:content] } content_array = []
user_message = { role: "user" }
if msg[:id] if msg[:id]
if @embed_user_ids if embed_user_ids?
user_message[:content] = "#{msg[:id]}: #{msg[:content]}" content_array << "#{msg[:id]}: "
else else
user_message[:name] = msg[:id] user_message[:name] = msg[:id]
end end
end end
user_message[:content] = inline_images(user_message[:content], msg) if vision_support? content_array << msg[:content]
content_array =
to_encoded_content_array(
content: content_array.flatten,
image_encoder: ->(details) { image_node(details) },
text_encoder: ->(text) { { type: "text", text: text } },
allow_vision: vision_support?,
)
user_message[:content] = no_array_if_only_text(content_array)
user_message user_message
end end
def inline_images(content, message) def no_array_if_only_text(content_array)
encoded_uploads = prompt.encoded_uploads(message) if content_array.size == 1 && content_array.first[:type] == "text"
return content if encoded_uploads.blank? content_array.first[:text]
else
content_array
end
end
content_w_imgs = def image_node(details)
encoded_uploads.reduce([]) do |memo, details| {
memo << { type: "image_url",
type: "image_url", image_url: {
image_url: { url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
url: "data:#{details[:mime_type]};base64,#{details[:base64]}", },
}, }
}
end
content_w_imgs << { type: "text", text: message[:content] }
end end
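For reference, a vision-enabled user message built by this dialect should now land in the standard OpenAI content-parts shape; the base64 payload below is a placeholder:

```ruby
# Assumed resulting payload for content ["user1: ", "hello", {upload_id: ...}]
user_message = {
  role: "user",
  content: [
    { type: "text", text: "user1: hello" },
    { type: "image_url", image_url: { url: "data:image/png;base64,QUJD" } },
  ],
}
```

Note the text part now precedes the image, matching the order of the mixed content array; `no_array_if_only_text` collapses the array back to a bare string when no image is present.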
def per_message_overhead def per_message_overhead


@@ -87,9 +87,9 @@ module DiscourseAi
end end
def model_msg(msg) def model_msg(msg)
if msg[:thinking] || msg[:redacted_thinking_signature] content_array = []
content_array = []
if msg[:thinking] || msg[:redacted_thinking_signature]
if msg[:thinking] if msg[:thinking]
content_array << { content_array << {
type: "thinking", type: "thinking",
@@ -104,13 +104,19 @@ module DiscourseAi
data: msg[:redacted_thinking_signature], data: msg[:redacted_thinking_signature],
} }
end end
content_array << { type: "text", text: msg[:content] }
{ role: "assistant", content: content_array }
else
{ role: "assistant", content: msg[:content] }
end end
# other encoder is used to pass through thinking
content_array =
to_encoded_content_array(
content: [content_array, msg[:content]].flatten,
image_encoder: ->(details) {},
text_encoder: ->(text) { { type: "text", text: text } },
other_encoder: ->(details) { details },
allow_vision: false,
)
{ role: "assistant", content: no_array_if_only_text(content_array) }
end end
def system_msg(msg) def system_msg(msg)
@@ -124,31 +130,39 @@ module DiscourseAi
end end
def user_msg(msg) def user_msg(msg)
content = +"" content_array = []
content << "#{msg[:id]}: " if msg[:id] content_array << "#{msg[:id]}: " if msg[:id]
content << msg[:content] content_array.concat([msg[:content]].flatten)
content = inline_images(content, msg) if vision_support?
{ role: "user", content: content } content_array =
to_encoded_content_array(
content: content_array,
image_encoder: ->(details) { image_node(details) },
text_encoder: ->(text) { { type: "text", text: text } },
allow_vision: vision_support?,
)
{ role: "user", content: no_array_if_only_text(content_array) }
end end
def inline_images(content, message) # keeping our payload as backward compatible as possible
encoded_uploads = prompt.encoded_uploads(message) def no_array_if_only_text(content_array)
return content if encoded_uploads.blank? if content_array.length == 1 && content_array.first[:type] == "text"
content_array.first[:text]
else
content_array
end
end
content_w_imgs = def image_node(details)
encoded_uploads.reduce([]) do |memo, details| {
memo << { source: {
source: { type: "base64",
type: "base64", data: details[:base64],
data: details[:base64], media_type: details[:mime_type],
media_type: details[:mime_type], },
}, type: "image",
type: "image", }
}
end
content_w_imgs << { type: "text", text: content }
end end
end end
end end
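The Claude dialect produces the analogous Anthropic content-block shape; a hypothetical example with placeholder base64 data:

```ruby
# Assumed Anthropic-style payload for content ["user1: ", "hello", {upload_id: ...}]
claude_message = {
  role: "user",
  content: [
    { type: "text", text: "user1: hello" },
    { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "QUJD" } },
  ],
}
```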


@@ -110,9 +110,9 @@ module DiscourseAi
end end
def user_msg(msg) def user_msg(msg)
user_message = { role: "USER", message: msg[:content] } content = prompt.text_only(msg)
user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id] user_message = { role: "USER", message: content }
user_message[:message] = "#{msg[:id]}: #{content}" if msg[:id]
user_message user_message
end end
end end


@@ -227,6 +227,38 @@ module DiscourseAi
msg = msg.merge(content: new_content) msg = msg.merge(content: new_content)
user_msg(msg) user_msg(msg)
end end
def to_encoded_content_array(
content:,
image_encoder:,
text_encoder:,
other_encoder: nil,
allow_vision:
)
content = [content] if !content.is_a?(Array)
current_string = +""
result = []
content.each do |c|
if c.is_a?(String)
current_string << c
elsif c.is_a?(Hash) && c.key?(:upload_id) && allow_vision
if !current_string.empty?
result << text_encoder.call(current_string)
current_string = +""
end
encoded = prompt.encode_upload(c[:upload_id])
result << image_encoder.call(encoded) if encoded
elsif other_encoder
encoded = other_encoder.call(c)
result << encoded if encoded
end
end
result << text_encoder.call(current_string) if !current_string.empty?
result
end
end end
end end
end end
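A standalone sketch of the shared helper above, with `prompt.encode_upload` swapped for an injected lambda so it runs outside the plugin (the `other_encoder` branch is omitted for brevity):

```ruby
# Walks a mixed content array, merging adjacent strings and encoding
# {upload_id: ...} hashes via the per-dialect image encoder
def to_encoded_content_array(content:, image_encoder:, text_encoder:, allow_vision:, encode_upload:)
  content = [content] if !content.is_a?(Array)
  current_string = +""
  result = []

  content.each do |c|
    if c.is_a?(String)
      current_string << c
    elsif c.is_a?(Hash) && c.key?(:upload_id) && allow_vision
      if !current_string.empty?
        result << text_encoder.call(current_string)
        current_string = +""
      end
      encoded = encode_upload.call(c[:upload_id])
      result << image_encoder.call(encoded) if encoded
    end
  end

  result << text_encoder.call(current_string) if !current_string.empty?
  result
end

parts =
  to_encoded_content_array(
    content: ["user1: ", "hello", { upload_id: 42 }],
    image_encoder: ->(d) { { type: "image_url", image_url: { url: "data:#{d[:mime_type]};base64,#{d[:base64]}" } } },
    text_encoder: ->(t) { { type: "text", text: t } },
    allow_vision: true,
    encode_upload: ->(_id) { { mime_type: "image/png", base64: "QUJD" } }, # stubbed encoder
  )
```

Each dialect only supplies its own encoders, which is what lets one helper serve OpenAI, Claude, Gemini, and the rest.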


@@ -106,28 +106,29 @@ module DiscourseAi
end end
def user_msg(msg) def user_msg(msg)
if beta_api? content_array = []
# support new format with multiple parts content_array << "#{msg[:id]}: " if msg[:id]
result = { role: "user", parts: [{ text: msg[:content] }] }
return result unless vision_support?
upload_parts = uploaded_parts(msg) content_array << msg[:content]
result[:parts].concat(upload_parts) if upload_parts.present? content_array.flatten!
result
content_array =
to_encoded_content_array(
content: content_array,
image_encoder: ->(details) { image_node(details) },
text_encoder: ->(text) { { text: text } },
allow_vision: vision_support? && beta_api?,
)
if beta_api?
{ role: "user", parts: content_array }
else else
{ role: "user", parts: { text: msg[:content] } } { role: "user", parts: content_array.first }
end end
end end
def uploaded_parts(message) def image_node(details)
encoded_uploads = prompt.encoded_uploads(message) { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
result = []
if encoded_uploads.present?
encoded_uploads.each do |details|
result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
end
end
result
end end
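For the beta API path, the resulting Gemini message should look roughly like this (placeholder data):

```ruby
# Assumed Gemini-style payload for content ["user1: ", "hello", {upload_id: ...}]
gemini_message = {
  role: "user",
  parts: [
    { text: "user1: hello" },
    { inlineData: { mimeType: "image/jpeg", data: "QUJD" } },
  ],
}
```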
def tool_call_msg(msg) def tool_call_msg(msg)


@@ -155,7 +155,7 @@ module DiscourseAi
end end
end end
{ role: "user", content: msg[:content], images: images } { role: "user", content: prompt.text_only(msg), images: images }
end end
def model_msg(msg) def model_msg(msg)


@@ -69,7 +69,7 @@ module DiscourseAi
end end
def user_msg(msg) def user_msg(msg)
user_message = { role: "user", content: msg[:content] } user_message = { role: "user", content: prompt.text_only(msg) }
encoded_uploads = prompt.encoded_uploads(msg) encoded_uploads = prompt.encoded_uploads(msg)
if encoded_uploads.present? if encoded_uploads.present?


@@ -3,7 +3,7 @@
module DiscourseAi module DiscourseAi
module Completions module Completions
module Dialects module Dialects
class OpenAiCompatible < Dialect class OpenAiCompatible < ChatGpt
class << self class << self
def can_translate?(_llm_model) def can_translate?(_llm_model)
# fallback dialect # fallback dialect
@@ -43,58 +43,6 @@ module DiscourseAi
translated.unshift(user_msg) translated.unshift(user_msg)
end end
private
def system_msg(msg)
msg = { role: "system", content: msg[:content] }
if tools_dialect.instructions.present?
msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
end
msg
end
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def tool_call_msg(msg)
translated = tools_dialect.from_raw_tool_call(msg)
{ role: "assistant", content: translated }
end
def tool_msg(msg)
translated = tools_dialect.from_raw_tool(msg)
{ role: "user", content: translated }
end
def user_msg(msg)
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
message = { role: "user", content: content }
message[:content] = inline_images(message[:content], msg) if vision_support?
message
end
def inline_images(content, message)
encoded_uploads = prompt.encoded_uploads(message)
return content if encoded_uploads.blank?
encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
memo << {
type: "image_url",
image_url: {
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
},
}
end
end
end end
end end
end end


@@ -89,7 +89,6 @@ module DiscourseAi
content:, content:,
id: nil, id: nil,
name: nil, name: nil,
upload_ids: nil,
thinking: nil, thinking: nil,
thinking_signature: nil, thinking_signature: nil,
redacted_thinking_signature: nil redacted_thinking_signature: nil
@@ -98,7 +97,6 @@ module DiscourseAi
new_message = { type: type, content: content } new_message = { type: type, content: content }
new_message[:name] = name.to_s if name new_message[:name] = name.to_s if name
new_message[:id] = id.to_s if id new_message[:id] = id.to_s if id
new_message[:upload_ids] = upload_ids if upload_ids
new_message[:thinking] = thinking if thinking new_message[:thinking] = thinking if thinking
new_message[:thinking_signature] = thinking_signature if thinking_signature new_message[:thinking_signature] = thinking_signature if thinking_signature
new_message[ new_message[
@@ -115,11 +113,44 @@ module DiscourseAi
tools.present? tools.present?
end end
# helper method to get base64 encoded uploads
# at the correct dimensions # at the correct dimensions
def encoded_uploads(message) def encoded_uploads(message)
return [] if message[:upload_ids].blank? if message[:content].is_a?(Array)
UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels) upload_ids =
message[:content]
.map do |content|
content[:upload_id] if content.is_a?(Hash) && content.key?(:upload_id)
end
.compact
if !upload_ids.empty?
return UploadEncoder.encode(upload_ids: upload_ids, max_pixels: max_pixels)
end
end
[]
end
def text_only(message)
if message[:content].is_a?(Array)
message[:content].map { |element| element if element.is_a?(String) }.compact.join
else
message[:content]
end
end
def encode_upload(upload_id)
UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first
end
def content_with_encoded_uploads(content)
return [content] unless content.is_a?(Array)
content.map do |c|
if c.is_a?(Hash) && c.key?(:upload_id)
encode_upload(c[:upload_id])
else
c
end
end
end end
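`text_only` simply drops the upload hashes from a mixed content array; an equivalent standalone sketch:

```ruby
# Joins the string elements of a mixed content array; passes plain strings through
def text_only(content)
  content.is_a?(Array) ? content.grep(String).join : content
end
```

Dialects without vision support (Cohere's `USER` messages, Ollama) use this to keep their payloads text-only.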
def ==(other) def ==(other)
@@ -150,7 +181,6 @@ module DiscourseAi
content content
id id
name name
upload_ids
thinking thinking
thinking_signature thinking_signature
redacted_thinking_signature redacted_thinking_signature
@@ -159,15 +189,17 @@ module DiscourseAi
raise ArgumentError, "message contains invalid keys: #{invalid_keys}" raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
end end
if message[:type] == :upload_ids && !message[:upload_ids].is_a?(Array) if message[:content].is_a?(Array)
raise ArgumentError, "upload_ids must be an array of ids" message[:content].each do |content|
if !content.is_a?(String) && !(content.is_a?(Hash) && content.keys == [:upload_id])
raise ArgumentError, "Array message content elements must be strings or {upload_id: ...} hashes"
end
end
else
if !message[:content].is_a?(String)
raise ArgumentError, "Message content must be a string or an array"
end
end end
if message[:upload_ids].present? && message[:type] != :user
raise ArgumentError, "upload_ids are only supported for users"
end
raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
end end
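The validation rule above can be condensed into a predicate; a hedged sketch, not the plugin's actual method:

```ruby
# Content is either a plain string, or an array whose elements are
# strings or exactly {upload_id: ...} hashes
def valid_content?(content)
  if content.is_a?(Array)
    content.all? { |c| c.is_a?(String) || (c.is_a?(Hash) && c.keys == [:upload_id]) }
  else
    content.is_a?(String)
  end
end
```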
def validate_turn(last_turn, new_turn) def validate_turn(last_turn, new_turn)


@@ -9,6 +9,154 @@ module DiscourseAi
attr_reader :chat_context_post_upload_ids attr_reader :chat_context_post_upload_ids
attr_accessor :topic attr_accessor :topic
def self.messages_from_chat(
message,
channel:,
context_post_ids:,
max_messages:,
include_uploads:,
bot_user_ids:,
instruction_message: nil
)
include_thread_titles = !channel.direct_message_channel? && !message.thread_id
current_id = message.id
messages = nil
if !message.thread_id && channel.direct_message_channel?
messages = [message]
elsif !channel.direct_message_channel? && !message.thread_id
messages =
Chat::Message
.joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
.where(chat_channel_id: channel.id)
.where(
"chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
)
.order(id: :desc)
.limit(max_messages)
.to_a
.reverse
end
messages ||=
ChatSDK::Thread.last_messages(
thread_id: message.thread_id,
guardian: Discourse.system_user.guardian,
page_size: max_messages,
)
builder = new
guardian = Guardian.new(message.user)
if context_post_ids
builder.set_chat_context_posts(
context_post_ids,
guardian,
include_uploads: include_uploads,
)
end
messages.each do |m|
# restore stripped message
m.message = instruction_message if m.id == current_id && instruction_message
if bot_user_ids.include?(m.user_id)
builder.push(type: :model, content: m.message)
else
upload_ids = nil
upload_ids = m.uploads.map(&:id) if include_uploads && m.uploads.present?
mapped_message = m.message
thread_title = nil
thread_title = m.thread&.title if include_thread_titles && m.thread_id
mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
builder.push(
type: :user,
content: mapped_message,
name: m.user.username,
upload_ids: upload_ids,
)
end
end
builder.to_a(
limit: max_messages,
style: channel.direct_message_channel? ? :chat_with_context : :chat,
)
end
def self.messages_from_post(post, style: nil, max_posts:, bot_usernames:, include_uploads:)
# Pay attention to the `post_number <= ?` here.
# We want to inject the last post as context because they are translated differently.
post_types = [Post.types[:regular]]
post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper]
context =
post
.topic
.posts
.joins(:user)
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
.where("post_number <= ?", post.post_number)
.order("post_number desc")
.where("post_type in (?)", post_types)
.limit(max_posts)
.pluck(
"posts.raw",
"users.username",
"post_custom_prompts.custom_prompt",
"(
SELECT array_agg(ref.upload_id)
FROM upload_references ref
WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
) as upload_ids",
)
builder = new
builder.topic = post.topic
context.reverse_each do |raw, username, custom_prompt, upload_ids|
custom_prompt_translation =
Proc.new do |message|
# We can't keep backwards-compatibility for stored functions.
# Tool syntax requires a tool_call_id which we don't have.
if message[2] != "function"
custom_context = {
content: message[0],
type: message[2].present? ? message[2].to_sym : :model,
}
custom_context[:id] = message[1] if custom_context[:type] != :model
custom_context[:name] = message[3] if message[3]
thinking = message[4]
custom_context[:thinking] = thinking if thinking
builder.push(**custom_context)
end
end
if custom_prompt.present?
custom_prompt.each(&custom_prompt_translation)
else
context = { content: raw, type: (bot_usernames.include?(username) ? :model : :user) }
context[:id] = username if context[:type] == :user
if upload_ids.present? && context[:type] == :user && include_uploads
context[:upload_ids] = upload_ids.compact
end
builder.push(**context)
end
end
builder.to_a(style: style || (post.topic.private_message? ? :bot : :topic))
end
def initialize def initialize
@raw_messages = [] @raw_messages = []
end end
@@ -68,12 +216,19 @@ module DiscourseAi
if message[:type] == :user if message[:type] == :user
old_name = last_message.delete(:name) old_name = last_message.delete(:name)
last_message[:content] = "#{old_name}: #{last_message[:content]}" if old_name last_message[:content] = ["#{old_name}: ", last_message[:content]].flatten if old_name
new_content = message[:content] new_content = message[:content]
new_content = "#{message[:name]}: #{new_content}" if message[:name] new_content = ["#{message[:name]}: ", new_content].flatten if message[:name]
last_message[:content] += "\n#{new_content}" if !last_message[:content].is_a?(Array)
last_message[:content] = [last_message[:content]]
end
last_message[:content].concat(["\n", new_content].flatten)
compressed =
compress_messages_buffer(last_message[:content], max_uploads: MAX_TOPIC_UPLOADS)
last_message[:content] = compressed
else else
last_message[:content] = message[:content] last_message[:content] = message[:content]
end end
@@ -111,9 +266,9 @@ module DiscourseAi
end end
raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array) raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array)
content = [content, *upload_ids.map { |upload_id| { upload_id: upload_id } }] if upload_ids
message = { type: type, content: content } message = { type: type, content: content }
message[:name] = name.to_s if name message[:name] = name.to_s if name
message[:upload_ids] = upload_ids if upload_ids
message[:id] = id.to_s if id message[:id] = id.to_s if id
if thinking if thinking
message[:thinking] = thinking["thinking"] if thinking["thinking"] message[:thinking] = thinking["thinking"] if thinking["thinking"]
@@ -132,67 +287,62 @@ module DiscourseAi
def topic_array def topic_array
raw_messages = @raw_messages.dup raw_messages = @raw_messages.dup
user_content = +"You are operating in a Discourse forum.\n\n" content_array = []
content_array << "You are operating in a Discourse forum.\n\n"
if @topic if @topic
if @topic.private_message? if @topic.private_message?
user_content << "Private message info.\n" content_array << "Private message info.\n"
else else
user_content << "Topic information:\n" content_array << "Topic information:\n"
end end
user_content << "- URL: #{@topic.url}\n" content_array << "- URL: #{@topic.url}\n"
user_content << "- Title: #{@topic.title}\n" content_array << "- Title: #{@topic.title}\n"
if SiteSetting.tagging_enabled if SiteSetting.tagging_enabled
tags = @topic.tags.pluck(:name) tags = @topic.tags.pluck(:name)
tags -= DiscourseTagging.hidden_tag_names if tags.present? tags -= DiscourseTagging.hidden_tag_names if tags.present?
user_content << "- Tags: #{tags.join(", ")}\n" if tags.present? content_array << "- Tags: #{tags.join(", ")}\n" if tags.present?
end end
if !@topic.private_message? if !@topic.private_message?
user_content << "- Category: #{@topic.category.name}\n" if @topic.category content_array << "- Category: #{@topic.category.name}\n" if @topic.category
end end
user_content << "- Number of replies: #{@topic.posts_count - 1}\n\n" content_array << "- Number of replies: #{@topic.posts_count - 1}\n\n"
end end
last_user_message = raw_messages.pop last_user_message = raw_messages.pop
upload_ids = []
if raw_messages.present? if raw_messages.present?
user_content << "Here is the conversation so far:\n" content_array << "Here is the conversation so far:\n"
raw_messages.each do |message| raw_messages.each do |message|
user_content << "#{message[:name] || "User"}: #{message[:content]}\n" content_array << "#{message[:name] || "User"}: "
upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present? content_array << message[:content]
content_array << "\n\n"
end end
end end
if last_user_message if last_user_message
user_content << "You are responding to #{last_user_message[:name] || "User"} who just said:\n #{last_user_message[:content]}" content_array << "You are responding to #{last_user_message[:name] || "User"} who just said:\n"
if last_user_message[:upload_ids].present? content_array << last_user_message[:content]
upload_ids.concat(last_user_message[:upload_ids])
end
end end
user_message = { type: :user, content: user_content } content_array =
compress_messages_buffer(content_array.flatten, max_uploads: MAX_TOPIC_UPLOADS)
if upload_ids.present? user_message = { type: :user, content: content_array }
user_message[:upload_ids] = upload_ids[-MAX_TOPIC_UPLOADS..-1] || upload_ids
end
[user_message] [user_message]
end end
def chat_array(limit:) def chat_array(limit:)
if @raw_messages.length > 1 if @raw_messages.length > 1
buffer = buffer = [
+"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{" +"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{",
]
upload_ids = []
@raw_messages[0..-2].each do |message| @raw_messages[0..-2].each do |message|
buffer << "\n" buffer << "\n"
upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
if message[:type] == :user if message[:type] == :user
buffer << "#{message[:name] || "User"}: " buffer << "#{message[:name] || "User"}: "
else else
@@ -209,16 +359,44 @@ module DiscourseAi
end end
last_message = @raw_messages[-1] last_message = @raw_messages[-1]
buffer << "#{last_message[:name] || "User"}: #{last_message[:content]} " buffer << "#{last_message[:name] || "User"}: "
buffer << last_message[:content]
buffer = compress_messages_buffer(buffer.flatten, max_uploads: MAX_CHAT_UPLOADS)
message = { type: :user, content: buffer } message = { type: :user, content: buffer }
upload_ids.concat(last_message[:upload_ids]) if last_message[:upload_ids].present?
message[:upload_ids] = upload_ids[-MAX_CHAT_UPLOADS..-1] ||
upload_ids if upload_ids.present?
[message] [message]
end end
# caps uploads to maximum uploads allowed in message stream
# and concats string elements
def compress_messages_buffer(buffer, max_uploads:)
compressed = []
current_text = +""
upload_count = 0
buffer.each do |item|
if item.is_a?(String)
current_text << item
elsif item.is_a?(Hash)
compressed << current_text if current_text.present?
compressed << item
current_text = +""
upload_count += 1
end
end
compressed << current_text if current_text.present?
if upload_count > max_uploads
counter = max_uploads - upload_count
compressed.delete_if { |item| item.is_a?(Hash) && (counter += 1) > 0 }
end
compressed = compressed[0] if compressed.length == 1 && compressed[0].is_a?(String)
compressed
end
end end
end end
end end
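A from-scratch sketch of the intended capping behavior, keeping only the most recent `max_uploads` images; this illustrates the idea rather than copying the method above line for line:

```ruby
# Merge adjacent strings, then drop the oldest upload hashes over the cap
def cap_uploads(buffer, max_uploads:)
  compressed = []
  current_text = +""

  buffer.each do |item|
    if item.is_a?(String)
      current_text << item
    else
      compressed << current_text if !current_text.empty?
      compressed << item
      current_text = +""
    end
  end
  compressed << current_text if !current_text.empty?

  upload_count = compressed.count { |i| i.is_a?(Hash) }
  if upload_count > max_uploads
    to_drop = upload_count - max_uploads
    compressed.reject! { |i| i.is_a?(Hash) && (to_drop -= 1) >= 0 } # oldest first
  end

  # collapse back to a bare string when no uploads remain
  compressed.length == 1 && compressed[0].is_a?(String) ? compressed[0] : compressed
end
```

Dropping the oldest uploads first matches the old `upload_ids[-MAX..-1]` behavior, which kept the tail of the upload list.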


@@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(context.image_generation_scenario).to eq( expect(context.image_generation_scenario).to eq(
{ {
messages: [ messages: [
{ role: "user", parts: [{ text: "draw a cat" }] }, { role: "user", parts: [{ text: "user1: draw a cat" }] },
{ {
role: "model", role: "model",
parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }], parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
@@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
], ],
}, },
{ role: "model", parts: { text: "Ok." } }, { role: "model", parts: { text: "Ok." } },
{ role: "user", parts: [{ text: "draw another cat" }] }, { role: "user", parts: [{ text: "user1: draw another cat" }] },
], ],
system_instruction: context.system_insts, system_instruction: context.system_insts,
}, },
@@ -52,12 +52,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(context.multi_turn_scenario).to eq( expect(context.multi_turn_scenario).to eq(
{ {
messages: [ messages: [
{ role: "user", parts: [{ text: "This is a message by a user" }] }, { role: "user", parts: [{ text: "user1: This is a message by a user" }] },
{ {
role: "model", role: "model",
parts: [{ text: "I'm a previous bot reply, that's why there's no user" }], parts: [{ text: "I'm a previous bot reply, that's why there's no user" }],
}, },
{ role: "user", parts: [{ text: "This is a new message by a user" }] }, { role: "user", parts: [{ text: "user1: This is a new message by a user" }] },
{ {
role: "model", role: "model",
parts: [ parts: [


@@ -29,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
prompt = prompt =
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(
"You are image bot", "You are image bot",
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]], messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
) )
encoded = prompt.encoded_uploads(prompt.messages.last) encoded = prompt.encoded_uploads(prompt.messages.last)
@@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
content = dialect.translate[1][:content] content = dialect.translate[1][:content]
expect(content).to eq( expect(content).to eq(
[{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }], [{ type: "text", text: "user1: hello" }, { type: "image_url", image_url: { url: image } }],
) )
end end


@@ -37,7 +37,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Nova do
it "properly formats messages with images" do it "properly formats messages with images" do
messages = [ messages = [
{ type: :user, id: "user1", content: "What's in this image?", upload_ids: [upload.id] }, {
type: :user,
id: "user1",
content: ["What's in this image?", { upload_id: upload.id }],
},
] ]
prompt = DiscourseAi::Completions::Prompt.new(messages: messages) prompt = DiscourseAi::Completions::Prompt.new(messages: messages)


@@ -34,8 +34,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
messages: [ messages: [
{ {
type: :user, type: :user,
content: "Describe this image in a single sentence.", content: ["Describe this image in a single sentence.", { upload_id: upload.id }],
upload_ids: [upload.id],
}, },
], ],
) )
@@ -49,10 +48,15 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
expect(translated_messages.length).to eq(1) expect(translated_messages.length).to eq(1)
# no system message support here
expected_user_message = { expected_user_message = {
role: "user", role: "user",
content: [ content: [
{ type: "text", text: prompt.messages.map { |m| m[:content] }.join("\n") }, {
type: "text",
text:
"You are a bot specializing in image captioning.\nDescribe this image in a single sentence.",
},
{ {
type: "image_url", type: "image_url",
image_url: { image_url: {


@@ -271,7 +271,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
prompt = prompt =
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(
"You are image bot", "You are image bot",
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]], messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
) )
encoded = prompt.encoded_uploads(prompt.messages.last) encoded = prompt.encoded_uploads(prompt.messages.last)
@ -283,6 +283,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
{ {
role: "user", role: "user",
content: [ content: [
{ type: "text", text: "user1: hello" },
{ {
type: "image", type: "image",
source: { source: {
@ -291,7 +292,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
data: encoded[0][:base64], data: encoded[0][:base64],
}, },
}, },
{ type: "text", text: "user1: hello" },
], ],
}, },
], ],
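The hunks above move the text part before the image part: with inline content, each element of the content array maps to one API part in order, so a text that precedes its image in the message now precedes it in the payload too. A standalone sketch of that mapping (an assumed helper, not the actual dialect code):

```ruby
# Sketch: map inline content parts to Anthropic-style API parts in array order.
# The encoder is assumed to return { mime_type:, base64: } for an upload id.
def to_anthropic_parts(content, encoder)
  content.map do |part|
    if part.is_a?(Hash) && part[:upload_id]
      enc = encoder.call(part[:upload_id])
      { type: "image", source: { type: "base64", media_type: enc[:mime_type], data: enc[:base64] } }
    else
      { type: "text", text: part.to_s }
    end
  end
end

fake_encoder = ->(_id) { { mime_type: "image/jpeg", base64: "QUJD" } }
to_anthropic_parts(["user1: hello", { upload_id: 7 }], fake_encoder)
# => [{ type: "text", text: "user1: hello" },
#     { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "QUJD" } }]
```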

@@ -211,7 +211,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
-messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
@@ -248,7 +248,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
{
"role" => "user",
"parts" => [
-{ "text" => "hello" },
{ "text" => "user1: hello" },
{ "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
],
},

@@ -492,7 +492,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
-messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
@@ -517,13 +517,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
{
role: "user",
content: [
-{ type: "text", text: "hello" },
{
type: "image_url",
image_url: {
url: "data:#{encoded[0][:mime_type]};base64,#{encoded[0][:base64]}",
},
},
{ type: "text", text: "hello" },
],
name: "user1",
},

@@ -2,6 +2,36 @@
describe DiscourseAi::Completions::PromptMessagesBuilder do
let(:builder) { DiscourseAi::Completions::PromptMessagesBuilder.new }
fab!(:user)
fab!(:admin)
fab!(:bot_user) { Fabricate(:user) }
fab!(:other_user) { Fabricate(:user) }
fab!(:image_upload1) do
Fabricate(:upload, user: user, original_filename: "image.png", extension: "png")
end
fab!(:image_upload2) do
Fabricate(:upload, user: user, original_filename: "image.png", extension: "png")
end
it "correctly merges user messages with uploads" do
builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1])
builder.push(type: :user, content: "World", name: "Bob", upload_ids: [2])
messages = builder.to_a
# Check the structure of the merged message
expect(messages.length).to eq(1)
expect(messages[0][:type]).to eq(:user)
# The content should contain the text and both uploads
content = messages[0][:content]
expect(content).to be_an(Array)
expect(content[0]).to eq("Alice: Hello")
expect(content[1]).to eq({ upload_id: 1 })
expect(content[2]).to eq("\nBob: World")
expect(content[3]).to eq({ upload_id: 2 })
end
it "should allow merging user messages" do
builder.push(type: :user, content: "Hello", name: "Alice")
@@ -14,7 +44,7 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1, 2])
expect(builder.to_a).to eq(
-[{ type: :user, name: "Alice", content: "Hello", upload_ids: [1, 2] }],
[{ type: :user, content: ["Hello", { upload_id: 1 }, { upload_id: 2 }], name: "Alice" }],
)
end
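The merge behaviour these specs pin down can be sketched in isolation. This is a minimal stand-in, not the real `PromptMessagesBuilder`: consecutive user messages collapse into one message whose content array interleaves `"Name: text"` strings with `{ upload_id: }` hashes, with later messages prefixed by a newline.

```ruby
# Minimal sketch of the merge (assumed equivalent, not the plugin's implementation).
def merge_user_messages(messages)
  content = []
  messages.each_with_index do |msg, i|
    content << "#{i.zero? ? "" : "\n"}#{msg[:name]}: #{msg[:content]}"
    Array(msg[:upload_ids]).each { |id| content << { upload_id: id } }
  end
  [{ type: :user, content: content }]
end
```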
@@ -64,4 +94,319 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
expect(content).to include("Alice")
expect(content).to include("How do I solve this")
end
describe ".messages_from_chat" do
fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, bot_user]) }
fab!(:dm_message1) do
Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "Hello bot")
end
fab!(:dm_message2) do
Fabricate(:chat_message, chat_channel: dm_channel, user: bot_user, message: "Hello human")
end
fab!(:dm_message3) do
Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "How are you?")
end
fab!(:public_channel) { Fabricate(:category_channel) }
fab!(:public_message1) do
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Hello everyone")
end
fab!(:public_message2) do
Fabricate(:chat_message, chat_channel: public_channel, user: bot_user, message: "Hi there")
end
fab!(:thread_original) do
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Thread starter")
end
fab!(:thread) do
Fabricate(:chat_thread, channel: public_channel, original_message: thread_original)
end
fab!(:thread_reply1) do
Fabricate(
:chat_message,
chat_channel: public_channel,
user: other_user,
message: "Thread reply",
thread: thread,
)
end
fab!(:upload) { Fabricate(:upload, user: user) }
fab!(:message_with_upload) do
Fabricate(
:chat_message,
chat_channel: dm_channel,
user: user,
message: "Check this image",
upload_ids: [upload.id],
)
end
it "processes messages from direct message channels" do
context =
described_class.messages_from_chat(
dm_message3,
channel: dm_channel,
context_post_ids: nil,
max_messages: 10,
include_uploads: false,
bot_user_ids: [bot_user.id],
instruction_message: nil,
)
# this is all we got cause it is assuming threading
expect(context).to eq([{ type: :user, content: "How are you?", name: user.username }])
end
it "includes uploads when include_uploads is true" do
message_with_upload.reload
expect(message_with_upload.uploads).to include(upload)
context =
described_class.messages_from_chat(
message_with_upload,
channel: dm_channel,
context_post_ids: nil,
max_messages: 10,
include_uploads: true,
bot_user_ids: [bot_user.id],
instruction_message: nil,
)
# Find the message with upload
message = context.find { |m| m[:content] == ["Check this image", { upload_id: upload.id }] }
expect(message).to be_present
end
it "doesn't include uploads when include_uploads is false" do
# Make sure the upload is associated with the message
message_with_upload.reload
expect(message_with_upload.uploads).to include(upload)
context =
described_class.messages_from_chat(
message_with_upload,
channel: dm_channel,
context_post_ids: nil,
max_messages: 10,
include_uploads: false,
bot_user_ids: [bot_user.id],
instruction_message: nil,
)
# Find the message with upload
message = context.find { |m| m[:content] == "Check this image" }
expect(message).to be_present
expect(message[:upload_ids]).to be_nil
end
it "properly handles uploads in public channels with multiple users" do
_first_message =
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "First message")
_message_with_upload =
Fabricate(
:chat_message,
chat_channel: public_channel,
user: other_user,
message: "Message with image",
upload_ids: [upload.id],
)
last_message =
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Final message")
context =
described_class.messages_from_chat(
last_message,
channel: public_channel,
context_post_ids: nil,
max_messages: 3,
include_uploads: true,
bot_user_ids: [bot_user.id],
instruction_message: nil,
)
expect(context.length).to eq(1)
content = context.first[:content]
expect(content.length).to eq(3)
expect(content[0]).to include("First message")
expect(content[0]).to include("Message with image")
expect(content[1]).to include({ upload_id: upload.id })
expect(content[2]).to include("Final message")
end
end
describe ".messages_from_post" do
fab!(:pm) do
Fabricate(
:private_message_topic,
title: "This is my special PM",
user: user,
topic_allowed_users: [
Fabricate.build(:topic_allowed_user, user: user),
Fabricate.build(:topic_allowed_user, user: bot_user),
],
)
end
fab!(:first_post) do
Fabricate(:post, topic: pm, user: user, post_number: 1, raw: "This is a reply by the user")
end
fab!(:second_post) do
Fabricate(:post, topic: pm, user: bot_user, post_number: 2, raw: "This is a bot reply")
end
fab!(:third_post) do
Fabricate(
:post,
topic: pm,
user: user,
post_number: 3,
raw: "This is a second reply by the user",
)
end
it "handles uploads correctly in topic style messages" do
# Use Discourse's upload format in the post raw content
upload_markdown = "![test|658x372](#{image_upload1.short_url})"
post_with_upload =
Fabricate(
:post,
topic: pm,
user: admin,
raw: "This is the original #{upload_markdown} I just added",
)
UploadReference.create!(target: post_with_upload, upload: image_upload1)
upload2_markdown = "![test|658x372](#{image_upload2.short_url})"
post2_with_upload =
Fabricate(
:post,
topic: pm,
user: admin,
raw: "This post has a different image #{upload2_markdown} I just added",
)
UploadReference.create!(target: post2_with_upload, upload: image_upload2)
messages =
described_class.messages_from_post(
post2_with_upload,
style: :topic,
max_posts: 3,
bot_usernames: [bot_user.username],
include_uploads: true,
)
# this is not quite ideal yet, images are attached at the end of the post
# long term we may want to extract them out using a regex and create N parts
# so people can talk about multiple images in a single post
# this is the initial ground work though
expect(messages.length).to eq(1)
content = messages[0][:content]
# first part
# first image
# second part
# second image
expect(content.length).to eq(4)
expect(content[0]).to include("This is the original")
expect(content[1]).to eq({ upload_id: image_upload1.id })
expect(content[2]).to include("different image")
expect(content[3]).to eq({ upload_id: image_upload2.id })
end
context "with limited context" do
it "respects max_context_posts" do
context =
described_class.messages_from_post(
third_post,
max_posts: 1,
bot_usernames: [bot_user.username],
include_uploads: false,
)
expect(context).to contain_exactly(
*[{ type: :user, id: user.username, content: third_post.raw }],
)
end
end
it "includes previous posts ordered by post_number" do
context =
described_class.messages_from_post(
third_post,
max_posts: 10,
bot_usernames: [bot_user.username],
include_uploads: false,
)
expect(context).to eq(
[
{ type: :user, content: "This is a reply by the user", id: user.username },
{ type: :model, content: "This is a bot reply" },
{ type: :user, content: "This is a second reply by the user", id: user.username },
],
)
end
it "only include regular posts" do
first_post.update!(post_type: Post.types[:whisper])
context =
described_class.messages_from_post(
third_post,
max_posts: 10,
bot_usernames: [bot_user.username],
include_uploads: false,
)
# skips leading model reply which makes no sense cause first post was whisper
expect(context).to eq(
[{ type: :user, content: "This is a second reply by the user", id: user.username }],
)
end
context "with custom prompts" do
it "When post custom prompt is present, we use that instead of the post content" do
custom_prompt = [
[
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
"time",
"tool_call",
],
[
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
["I replied to the time command", bot_user.username],
]
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
context =
described_class.messages_from_post(
third_post,
max_posts: 10,
bot_usernames: [bot_user.username],
include_uploads: false,
)
expect(context).to eq(
[
{ type: :user, content: "This is a reply by the user", id: user.username },
{ type: :tool_call, content: custom_prompt.first.first, id: "time" },
{ type: :tool, id: "time", content: custom_prompt.second.first },
{ type: :model, content: custom_prompt.third.first },
{ type: :user, content: "This is a second reply by the user", id: user.username },
],
)
end
end
end
end

@@ -25,34 +25,21 @@ RSpec.describe DiscourseAi::Completions::Prompt do
end
describe "image support" do
-it "allows adding uploads to messages" do
it "allows adding uploads inline in messages" do
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
prompt.max_pixels = 300
-prompt.push(type: :user, content: "hello", upload_ids: [upload.id])
prompt.push(
type: :user,
content: ["this is an image", { upload_id: upload.id }, "this was an image"],
)
-expect(prompt.messages.last[:upload_ids]).to eq([upload.id])
-expect(prompt.max_pixels).to eq(300)
-encoded = prompt.encoded_uploads(prompt.messages.last)
-expect(encoded.length).to eq(1)
-expect(encoded[0][:mime_type]).to eq("image/jpeg")
-old_base64 = encoded[0][:base64]
-prompt.max_pixels = 1_000_000
-encoded = prompt.encoded_uploads(prompt.messages.last)
-expect(encoded.length).to eq(1)
-expect(encoded[0][:mime_type]).to eq("image/jpeg")
-new_base64 = encoded[0][:base64]
-expect(new_base64.length).to be > old_base64.length
-expect(new_base64.length).to be > 0
-expect(old_base64.length).to be > 0
encoded = prompt.content_with_encoded_uploads(prompt.messages.last[:content])
expect(encoded.length).to eq(3)
expect(encoded[0]).to eq("this is an image")
expect(encoded[1][:mime_type]).to eq("image/jpeg")
expect(encoded[2]).to eq("this was an image")
end
end
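Judging from this spec, `content_with_encoded_uploads` walks the content array, passes strings through untouched, and swaps each `{ upload_id: }` part for an encoded `{ mime_type:, base64: }` hash. A sketch of that traversal with the upload lookup stubbed out (assumed behaviour, not the plugin's code):

```ruby
# Sketch: strings pass through, { upload_id: } parts are replaced by the
# encoder's { mime_type:, base64: } result.
def content_with_encoded_uploads(content, encoder)
  content.map do |part|
    part.is_a?(Hash) && part[:upload_id] ? encoder.call(part[:upload_id]) : part
  end
end

stub_encoder = ->(_id) { { mime_type: "image/jpeg", base64: "QUJD" } }
content_with_encoded_uploads(["this is an image", { upload_id: 5 }, "this was an image"], stub_encoder)
# => ["this is an image", { mime_type: "image/jpeg", base64: "QUJD" }, "this was an image"]
```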

@@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: personaClass.new)
bot.reply(
-{ conversation_context: [{ type: :user, content: "test" }] },
DiscourseAi::AiBot::BotContext.new(messages: [{ type: :user, content: "test" }]),
) do |_partial, _cancel, _placeholder|
# we just need the block so bot has something to call with results
end
@@ -74,7 +74,10 @@ RSpec.describe DiscourseAi::AiBot::Bot do
HTML
-context = { conversation_context: [{ type: :user, content: "Does my site has tags?" }] }
context =
DiscourseAi::AiBot::BotContext.new(
messages: [{ type: :user, content: "Does my site has tags?" }],
)
DiscourseAi::Completions::Llm.with_prepared_responses(llm_responses) do
bot.reply(context) do |_bot_reply_post, cancel, placeholder|

@@ -36,13 +36,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
let(:context) do
-{
DiscourseAi::AiBot::BotContext.new(
site_url: Discourse.base_url,
site_title: "test site title",
site_description: "test site description",
time: Time.zone.now,
participants: topic_with_users.allowed_users.map(&:username).join(", "),
-}
)
end
fab!(:admin)
@@ -307,7 +307,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
let(:ai_persona) { DiscourseAi::AiBot::Personas::Persona.all(user: user).first.new }
let(:with_cc) do
-context.merge(conversation_context: [{ content: "Tell me the time", type: :user }])
context.messages = [{ content: "Tell me the time", type: :user }]
context
end
context "when a persona has no uploads" do
@@ -345,17 +346,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
DiscourseAi::AiBot::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new
# this means that we will consolidate
-ctx =
-with_cc.merge(
-conversation_context: [
-{ content: "Tell me the time", type: :user },
-{ content: "the time is 1", type: :model },
-{ content: "in france?", type: :user },
-],
-)
context.messages = [
{ content: "Tell me the time", type: :user },
{ content: "the time is 1", type: :model },
{ content: "in france?", type: :user },
]
DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do
-custom_persona.craft_prompt(ctx).messages.first[:content]
custom_persona.craft_prompt(context).messages.first[:content]
end
message =
@@ -397,7 +395,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
EmbeddingsGenerationStubs.hugging_face_service(
-with_cc.dig(:conversation_context, 0, :content),
with_cc.messages.dig(0, :content),
prompt_cc_embeddings,
)
end

@@ -267,7 +267,10 @@ RSpec.describe DiscourseAi::AiBot::Playground do
prompts = inner_prompts
end
-expect(prompts[0].messages[1][:upload_ids]).to eq([upload.id])
content = prompts[0].messages[1][:content]
expect(content).to include({ upload_id: upload.id })
expect(prompts[0].max_pixels).to eq(1000)
post.topic.reload
@@ -1154,79 +1157,4 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(playground.available_bot_usernames).to include(persona.user.username)
end
end
-describe "#conversation_context" do
-context "with limited context" do
-before do
-@old_persona = playground.bot.persona
-persona = Fabricate(:ai_persona, max_context_posts: 1)
-playground.bot.persona = persona.class_instance.new
-end
-after { playground.bot.persona = @old_persona }
-it "respects max_context_post" do
-context = playground.conversation_context(third_post)
-expect(context).to contain_exactly(
-*[{ type: :user, id: user.username, content: third_post.raw }],
-)
-end
-end
-xit "includes previous posts ordered by post_number" do
-context = playground.conversation_context(third_post)
-expect(context).to contain_exactly(
-*[
-{ type: :user, id: user.username, content: third_post.raw },
-{ type: :model, content: second_post.raw },
-{ type: :user, id: user.username, content: first_post.raw },
-],
-)
-end
-xit "only include regular posts" do
-first_post.update!(post_type: Post.types[:whisper])
-context = playground.conversation_context(third_post)
-# skips leading model reply which makes no sense cause first post was whisper
-expect(context).to contain_exactly(
-*[{ type: :user, id: user.username, content: third_post.raw }],
-)
-end
-context "with custom prompts" do
-it "When post custom prompt is present, we use that instead of the post content" do
-custom_prompt = [
-[
-{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
-"time",
-"tool_call",
-],
-[
-{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
-"time",
-"tool",
-],
-["I replied to the time command", bot_user.username],
-]
-PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
-context = playground.conversation_context(third_post)
-expect(context).to contain_exactly(
-*[
-{ type: :user, id: user.username, content: first_post.raw },
-{ type: :tool_call, content: custom_prompt.first.first, id: "time" },
-{ type: :tool, id: "time", content: custom_prompt.second.first },
-{ type: :model, content: custom_prompt.third.first },
-{ type: :user, id: user.username, content: third_post.raw },
-],
-)
-end
-end
-end
end

@@ -34,9 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::CreateArtifact do
{ html_body: "hello" },
bot_user: Fabricate(:user),
llm: llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(post: post),
)
tool.parameters = { name: "hello", specification: "hello spec" }

@@ -15,9 +15,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
let(:progress_blk) { Proc.new {} }
-let(:dall_e) do
-described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
-end
let(:dall_e) { described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user) }
let(:base64_image) do
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
@@ -30,8 +28,6 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
{ prompts: ["a cat"], aspect_ratio: "tall" },
llm: llm,
bot_user: bot_user,
-context: {
-},
)
data = [{ b64_json: base64_image, revised_prompt: "a tall cat" }]

@@ -9,8 +9,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
{ prompts: prompts, seeds: [99, 32] },
bot_user: bot_user,
llm: llm,
-context: {
-},
context: DiscourseAi::AiBot::BotContext.new,
)
end

@@ -25,9 +25,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/#{artifact.id}" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post2.id,
-},
context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -46,9 +44,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "invalid-url" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -62,9 +58,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/99999" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -97,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "https://example.com" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}
@@ -128,9 +120,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
{ url: "https://example.com" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(post: post),
)
result = tool.invoke {}

@@ -56,9 +56,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
persona_options: {
"read_private" => true,
},
-context: {
-user: admin,
-},
context: DiscourseAi::AiBot::BotContext.new(user: admin),
)
results = tool.invoke
expect(results[:content]).to include("hello there")
@@ -68,9 +66,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
{ topic_id: topic_with_tags.id, post_numbers: [post1.post_number] },
bot_user: bot_user,
llm: llm,
-context: {
-user: admin,
-},
context: DiscourseAi::AiBot::BotContext.new(user: admin),
)
results = tool.invoke

@@ -60,9 +60,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
persona_options: persona_options,
bot_user: bot_user,
llm: llm,
-context: {
-user: user,
-},
context: DiscourseAi::AiBot::BotContext.new(user: user),
)
expect(search.options[:base_query]).to eq("#funny")

@@ -47,9 +47,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -93,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -119,9 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: artifact.id, instructions: "Invalid update" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -135,9 +129,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: -1, instructions: "Update something" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -163,9 +155,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
tool.invoke {}
@@ -196,9 +186,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
.invoke {}
end
@@ -224,9 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
persona_options: {
"update_algorithm" => "full",
},
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
)
result = tool.invoke {}
@@ -276,9 +262,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
persona_options: {
"update_algorithm" => "diff",
},
@@ -346,9 +330,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
{ artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" },
bot_user: bot_user,
llm: llm_model.to_llm,
-context: {
-post_id: post.id,
-},
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
persona_options: {
"update_algorithm" => "diff",
},

@@ -255,11 +255,12 @@ RSpec.describe DiscourseAi::AiModeration::SpamScanner do
prompt = _prompts.first
end
-content = prompt.messages[1][:content]
# its an array so lets just stringify it to make testing easier
content = prompt.messages[1][:content][0]
expect(content).to include(post.topic.title)
expect(content).to include(post.raw)
-upload_ids = prompt.messages[1][:upload_ids]
upload_ids = prompt.messages[1][:content].map { |m| m[:upload_id] if m.is_a?(Hash) }.compact
expect(upload_ids).to be_present
expect(upload_ids).to eq(post.upload_ids)
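The assertion above recovers upload ids by filtering hash parts out of the mixed content array; the same pattern works as a small standalone helper:

```ruby
# Extract upload ids from a content array that mixes strings and { upload_id: } hashes.
def upload_ids_from(content)
  return [] unless content.is_a?(Array)
  content.filter_map { |part| part[:upload_id] if part.is_a?(Hash) }
end

upload_ids_from(["title and raw text", { upload_id: 11 }, { upload_id: 12 }])
# => [11, 12]
```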

@@ -199,7 +199,7 @@ describe DiscourseAi::Automation::LlmTriage do
triage_prompt = DiscourseAi::Completions::Llm.prompts.last
-expect(triage_prompt.messages.last[:upload_ids]).to contain_exactly(post_upload.id)
expect(triage_prompt.messages.last[:content].last).to eq({ upload_id: post_upload.id })
end
end

@ -5,6 +5,7 @@ RSpec.describe AiTool do
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
fab!(:topic) fab!(:topic)
fab!(:post) { Fabricate(:post, topic: topic, raw: "bananas are a tasty fruit") } fab!(:post) { Fabricate(:post, topic: topic, raw: "bananas are a tasty fruit") }
fab!(:bot_user) { Discourse.system_user }
def create_tool( def create_tool(
parameters: nil, parameters: nil,
@ -16,7 +17,8 @@ RSpec.describe AiTool do
name: "test #{SecureRandom.uuid}", name: "test #{SecureRandom.uuid}",
tool_name: "test_#{SecureRandom.uuid.underscore}", tool_name: "test_#{SecureRandom.uuid.underscore}",
description: "test", description: "test",
parameters: parameters || [{ name: "query", type: "string", desciption: "perform a search" }], parameters:
parameters || [{ name: "query", type: "string", description: "perform a search" }],
script: script || "function invoke(params) { return params; }", script: script || "function invoke(params) { return params; }",
created_by_id: 1, created_by_id: 1,
summary: "Test tool summary", summary: "Test tool summary",
@ -32,11 +34,11 @@ RSpec.describe AiTool do
{ {
name: tool.tool_name, name: tool.tool_name,
description: "test", description: "test",
parameters: [{ name: "query", type: "string", desciption: "perform a search" }], parameters: [{ name: "query", type: "string", description: "perform a search" }],
}, },
) )
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {}) runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
expect(runner.invoke).to eq("query" => "test") expect(runner.invoke).to eq("query" => "test")
end end
@@ -57,7 +59,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil)
stub_request(verb, "https://example.com/api").with(
body: "{\"data\":\"test data\"}",
@@ -83,7 +85,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
stub_request(:get, "https://example.com/test").with(
headers: {
@@ -110,7 +112,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({}, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({}, llm: nil, bot_user: nil)
stub_request(:get, "https://example.com/").to_return(
status: 200,
@@ -134,7 +136,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
stub_request(:get, "https://example.com/test").with(
headers: {
@@ -160,13 +162,16 @@ RSpec.describe AiTool do
}
JS
+tool = create_tool(script: script)
+runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
stub_request(:get, "https://example.com/test").to_return do
sleep 0.01
{ status: 200, body: "Hello World", headers: {} }
end
tool = create_tool(script: script)
-runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
runner.timeout = 10
@@ -184,7 +189,7 @@ RSpec.describe AiTool do
tool = create_tool(script: script)
-runner = tool.runner({}, llm: llm, bot_user: nil, context: {})
+runner = tool.runner({}, llm: llm, bot_user: nil)
result = runner.invoke
expect(result).to eq("Hello")
@@ -209,7 +214,7 @@ RSpec.describe AiTool do
responses = ["Hello ", "World"]
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
-runner = tool.runner({}, llm: llm, bot_user: nil, context: {})
+runner = tool.runner({}, llm: llm, bot_user: nil)
result = runner.invoke
prompts = _prompts
end
@@ -232,7 +237,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
runner.timeout = 5
@@ -295,7 +300,7 @@ RSpec.describe AiTool do
RagDocumentFragment.link_target_and_uploads(tool, [upload1.id, upload2.id])
-result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke
+result = tool.runner({}, llm: nil, bot_user: nil).invoke
expected = [
[{ "fragment" => "44 45 46 47 48 49 50", "metadata" => nil }],
@@ -316,7 +321,7 @@ RSpec.describe AiTool do
# this part of the API is a bit awkward, maybe we should do it
# automatically
RagDocumentFragment.update_target_uploads(tool, [upload1.id, upload2.id])
-result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke
+result = tool.runner({}, llm: nil, bot_user: nil).invoke
expected = [
[{ "fragment" => "48 49 50", "metadata" => nil }],
@@ -340,7 +345,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil)
result = runner.invoke
@@ -364,7 +369,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil)
result = runner.invoke
post_hash = result["post"]
@@ -393,7 +398,7 @@ RSpec.describe AiTool do
JS
tool = create_tool(script: script)
-runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {})
+runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil)
result = runner.invoke
@@ -401,4 +406,158 @@ RSpec.describe AiTool do
expect(result["rows"].first["title"]).to eq(topic.title)
end
end
context "when using the chat API" do
before(:each) do
skip "Chat plugin tests skipped because Chat module is not defined." unless defined?(Chat)
SiteSetting.chat_enabled = true
end
fab!(:chat_user) { Fabricate(:user) }
fab!(:chat_channel) do
Fabricate(:chat_channel).tap do |channel|
Fabricate(
:user_chat_channel_membership,
user: chat_user,
chat_channel: channel,
following: true,
)
end
end
it "can create a chat message" do
script = <<~JS
function invoke(params) {
return discourse.createChatMessage({
channel_name: params.channel_name,
username: params.username,
message: params.message
});
}
JS
tool = create_tool(script: script)
runner =
tool.runner(
{
"channel_name" => chat_channel.name,
"username" => chat_user.username,
"message" => "Hello from the tool!",
},
llm: nil,
bot_user: bot_user, # The user *running* the tool doesn't affect sender
)
initial_message_count = Chat::Message.count
result = runner.invoke
expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}"
expect(result["message"]).to eq("Hello from the tool!")
expect(result["created_at"]).to be_present
expect(result).not_to have_key("error")
# Verify message was actually created in the database
expect(Chat::Message.count).to eq(initial_message_count + 1)
created_message = Chat::Message.find_by(id: result["message_id"])
expect(created_message).not_to be_nil
expect(created_message.message).to eq("Hello from the tool!")
expect(created_message.user_id).to eq(chat_user.id) # Message is sent AS the specified user
expect(created_message.chat_channel_id).to eq(chat_channel.id)
end
it "can create a chat message using channel slug" do
chat_channel.update!(name: "My Test Channel", slug: "my-test-channel")
expect(chat_channel.slug).to eq("my-test-channel")
script = <<~JS
function invoke(params) {
return discourse.createChatMessage({
channel_name: params.channel_slug, // Using slug here
username: params.username,
message: params.message
});
}
JS
tool = create_tool(script: script)
runner =
tool.runner(
{
"channel_slug" => chat_channel.slug,
"username" => chat_user.username,
"message" => "Hello via slug!",
},
llm: nil,
bot_user: bot_user,
)
result = runner.invoke
expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}"
# see: https://github.com/rubyjs/mini_racer/issues/348
# expect(result["message_id"]).to be_a(Integer)
created_message = Chat::Message.find_by(id: result["message_id"])
expect(created_message).not_to be_nil
expect(created_message.message).to eq("Hello via slug!")
expect(created_message.chat_channel_id).to eq(chat_channel.id)
end
it "returns an error if the channel is not found" do
script = <<~JS
function invoke(params) {
return discourse.createChatMessage({
channel_name: "non_existent_channel",
username: params.username,
message: params.message
});
}
JS
tool = create_tool(script: script)
runner =
tool.runner(
{ "username" => chat_user.username, "message" => "Test" },
llm: nil,
bot_user: bot_user,
)
initial_message_count = Chat::Message.count
expect { runner.invoke }.to raise_error(
MiniRacer::RuntimeError,
/Channel not found: non_existent_channel/,
)
expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created
end
it "returns an error if the user is not found" do
script = <<~JS
function invoke(params) {
return discourse.createChatMessage({
channel_name: params.channel_name,
username: "non_existent_user",
message: params.message
});
}
JS
tool = create_tool(script: script)
runner =
tool.runner(
{ "channel_name" => chat_channel.name, "message" => "Test" },
llm: nil,
bot_user: bot_user,
)
initial_message_count = Chat::Message.count
expect { runner.invoke }.to raise_error(
MiniRacer::RuntimeError,
/User not found: non_existent_user/,
)
expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created
end
end
end
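The specs above all exercise custom tool scripts with a single `invoke(params)` entry point that calls into a sandboxed `discourse` API. A minimal JavaScript sketch of that contract, with a stand-in `discourse.createChatMessage` — the stub's behavior and return shape are assumptions inferred from the fields the specs assert (`success`, `message`, `message_id`, `created_at`) and the error messages they match, not the real plugin API:

```javascript
// Stand-in for the sandboxed `discourse` object exposed to tool scripts.
// Return shape and error messages are inferred from the spec assertions.
const discourse = {
  createChatMessage({ channel_name, username, message }) {
    if (channel_name === "non_existent_channel") {
      throw new Error(`Channel not found: ${channel_name}`);
    }
    if (username === "non_existent_user") {
      throw new Error(`User not found: ${username}`);
    }
    return {
      success: true,
      message_id: 42, // hypothetical id; the real value comes from the database
      message,
      created_at: new Date().toISOString(),
    };
  },
};

// A tool script exactly as the specs define it.
function invoke(params) {
  return discourse.createChatMessage({
    channel_name: params.channel_name,
    username: params.username,
    message: params.message,
  });
}

const result = invoke({
  channel_name: "my-test-channel",
  username: "alice",
  message: "Hello from the tool!",
});
console.log(result.success, result.message);
```

In the real runner the script executes inside MiniRacer, which is why a failing lookup surfaces to the Ruby side as `MiniRacer::RuntimeError` rather than a plain exception.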