mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-07-13 09:33:28 +00:00
FEATURE: flexible image handling within messages (#1214)
* DEV: refactor bot internals This introduces a proper object for bot context, this makes it simpler to improve context management as we go cause we have a nice object to work with Starts refactoring allowing for a single message to have multiple uploads throughout * transplant method to message builder * chipping away at inline uploads * image support is improved but not fully fixed yet partially working in anthropic, still got quite a few dialects to go * open ai and claude are now working * Gemini is now working as well * fix nova * more dialects... * fix ollama * fix specs * update artifact fixed * more tests * spam scanner * pass more specs * bunch of specs improved * more bug fixes. * all the rest of the tests are working * improve tests coverage and ensure custom tools are aware of new context object * tests are working, but we need more tests * resolve merge conflict * new preamble and expanded specs on ai tool * remove concept of "standalone tools" This is no longer needed, we can set custom raw, tool details are injected into tool calls
This commit is contained in:
parent
f3e78f0d80
commit
5b6d39a206
@ -55,7 +55,7 @@ module DiscourseAi
|
||||
# we need an llm so we have a tokenizer
|
||||
# but will do without if none is available
|
||||
llm = LlmModel.first&.to_llm
|
||||
runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user, context: {})
|
||||
runner = @ai_tool.runner(parameters, llm: llm, bot_user: current_user)
|
||||
result = runner.invoke
|
||||
|
||||
if result.is_a?(Hash) && result[:error]
|
||||
|
@ -30,9 +30,13 @@ module Jobs
|
||||
|
||||
base = { query: query, model_used: llm_model.display_name }
|
||||
|
||||
bot.reply(
|
||||
{ conversation_context: [{ type: :user, content: query }], skip_tool_details: true },
|
||||
) do |partial|
|
||||
context =
|
||||
DiscourseAi::AiBot::BotContext.new(
|
||||
messages: [{ type: :user, content: query }],
|
||||
skip_tool_details: true,
|
||||
)
|
||||
|
||||
bot.reply(context) do |partial|
|
||||
streamed_reply << partial
|
||||
|
||||
# Throttle updates.
|
||||
|
@ -35,7 +35,7 @@ class AiTool < ActiveRecord::Base
|
||||
tool_name.presence || name
|
||||
end
|
||||
|
||||
def runner(parameters, llm:, bot_user:, context: {})
|
||||
def runner(parameters, llm:, bot_user:, context: nil)
|
||||
DiscourseAi::AiBot::ToolRunner.new(
|
||||
parameters: parameters,
|
||||
llm: llm,
|
||||
@ -59,86 +59,166 @@ class AiTool < ActiveRecord::Base
|
||||
|
||||
def self.preamble
|
||||
<<~JS
|
||||
/**
|
||||
* Tool API Quick Reference
|
||||
*
|
||||
* Entry Functions
|
||||
*
|
||||
* invoke(parameters): Main function. Receives parameters (Object). Must return a JSON-serializable value.
|
||||
* Example:
|
||||
* function invoke(parameters) { return "result"; }
|
||||
*
|
||||
* details(): Optional. Returns a string describing the tool.
|
||||
* Example:
|
||||
* function details() { return "Tool description."; }
|
||||
*
|
||||
* Provided Objects
|
||||
*
|
||||
* 1. http
|
||||
* http.get(url, options?): Performs an HTTP GET request.
|
||||
* Parameters:
|
||||
* url (string): The request URL.
|
||||
* options (Object, optional):
|
||||
* headers (Object): Request headers.
|
||||
* Returns:
|
||||
* { status: number, body: string }
|
||||
*
|
||||
* http.post(url, options?): Performs an HTTP POST request.
|
||||
* Parameters:
|
||||
* url (string): The request URL.
|
||||
* options (Object, optional):
|
||||
* headers (Object): Request headers.
|
||||
* body (string): Request body.
|
||||
* Returns:
|
||||
* { status: number, body: string }
|
||||
*
|
||||
* (also available: http.put, http.patch, http.delete)
|
||||
*
|
||||
* Note: Max 20 HTTP requests per execution.
|
||||
*
|
||||
* 2. llm
|
||||
* llm.truncate(text, length): Truncates text to a specified token length.
|
||||
* Parameters:
|
||||
* text (string): Text to truncate.
|
||||
* length (number): Max tokens.
|
||||
* Returns:
|
||||
* Truncated string.
|
||||
*
|
||||
* 3. index
|
||||
* index.search(query, options?): Searches indexed documents.
|
||||
* Parameters:
|
||||
* query (string): Search query.
|
||||
* options (Object, optional):
|
||||
* filenames (Array): Limit search to specific files.
|
||||
* limit (number): Max fragments (up to 200).
|
||||
* Returns:
|
||||
* Array of { fragment: string, metadata: string }
|
||||
*
|
||||
* 4. upload
|
||||
* upload.create(filename, base_64_content): Uploads a file.
|
||||
* Parameters:
|
||||
* filename (string): Name of the file.
|
||||
* base_64_content (string): Base64 encoded file content.
|
||||
* Returns:
|
||||
* { id: number, short_url: string }
|
||||
*
|
||||
* 5. chain
|
||||
* chain.setCustomRaw(raw): Sets the body of the post and exist chain.
|
||||
* Parameters:
|
||||
* raw (string): raw content to add to post.
|
||||
*
|
||||
* Constraints
|
||||
*
|
||||
* Execution Time: ≤ 2000ms
|
||||
* Memory: ≤ 10MB
|
||||
* HTTP Requests: ≤ 20 per execution
|
||||
* Exceeding limits will result in errors or termination.
|
||||
*
|
||||
* Security
|
||||
*
|
||||
* Sandboxed Environment: No access to system or global objects.
|
||||
* No File System Access: Cannot read or write files.
|
||||
*/
|
||||
/**
|
||||
* Tool API Quick Reference
|
||||
*
|
||||
* Entry Functions
|
||||
*
|
||||
* invoke(parameters): Main function. Receives parameters defined in the tool's signature (Object).
|
||||
* Must return a JSON-serializable value (e.g., string, number, object, array).
|
||||
* Example:
|
||||
* function invoke(parameters) { return { result: "Data processed", input: parameters.query }; }
|
||||
*
|
||||
* details(): Optional function. Returns a string (can include basic HTML) describing
|
||||
* the tool's action after invocation, often using data from the invocation.
|
||||
* This is displayed in the chat interface.
|
||||
* Example:
|
||||
* let lastUrl;
|
||||
* function invoke(parameters) {
|
||||
* lastUrl = parameters.url;
|
||||
* // ... perform action ...
|
||||
* return { success: true, content: "..." };
|
||||
* }
|
||||
* function details() {
|
||||
* return `Browsed: <a href="${lastUrl}">${lastUrl}</a>`;
|
||||
* }
|
||||
*
|
||||
* Provided Objects & Functions
|
||||
*
|
||||
* 1. http
|
||||
* Performs HTTP requests. Max 20 requests per execution.
|
||||
*
|
||||
* http.get(url, options?): Performs GET request.
|
||||
* Parameters:
|
||||
* url (string): The request URL.
|
||||
* options (Object, optional):
|
||||
* headers (Object): Request headers (e.g., { "Authorization": "Bearer key" }).
|
||||
* Returns: { status: number, body: string }
|
||||
*
|
||||
* http.post(url, options?): Performs POST request.
|
||||
* Parameters:
|
||||
* url (string): The request URL.
|
||||
* options (Object, optional):
|
||||
* headers (Object): Request headers.
|
||||
* body (string | Object): Request body. If an object, it's stringified as JSON.
|
||||
* Returns: { status: number, body: string }
|
||||
*
|
||||
* http.put(url, options?): Performs PUT request (similar to POST).
|
||||
* http.patch(url, options?): Performs PATCH request (similar to POST).
|
||||
* http.delete(url, options?): Performs DELETE request (similar to GET/POST).
|
||||
*
|
||||
* 2. llm
|
||||
* Interacts with the Language Model.
|
||||
*
|
||||
* llm.truncate(text, length): Truncates text to a specified token length based on the configured LLM's tokenizer.
|
||||
* Parameters:
|
||||
* text (string): Text to truncate.
|
||||
* length (number): Maximum number of tokens.
|
||||
* Returns: string (truncated text)
|
||||
*
|
||||
* llm.generate(prompt): Generates text using the configured LLM associated with the tool runner.
|
||||
* Parameters:
|
||||
* prompt (string | Object): The prompt. Can be a simple string or an object
|
||||
* like { messages: [{ type: "system", content: "..." }, { type: "user", content: "..." }] }.
|
||||
* Returns: string (generated text)
|
||||
*
|
||||
* 3. index
|
||||
* Searches attached RAG (Retrieval-Augmented Generation) documents linked to this tool.
|
||||
*
|
||||
* index.search(query, options?): Searches indexed document fragments.
|
||||
* Parameters:
|
||||
* query (string): The search query used for semantic search.
|
||||
* options (Object, optional):
|
||||
* filenames (Array<string>): Filter search to fragments from specific uploaded filenames.
|
||||
* limit (number): Maximum number of fragments to return (default: 10, max: 200).
|
||||
* Returns: Array<{ fragment: string, metadata: string | null }> - Ordered by relevance.
|
||||
*
|
||||
* 4. upload
|
||||
* Handles file uploads within Discourse.
|
||||
*
|
||||
* upload.create(filename, base_64_content): Uploads a file created by the tool, making it available in Discourse.
|
||||
* Parameters:
|
||||
* filename (string): The desired name for the file (basename is used for security).
|
||||
* base_64_content (string): Base64 encoded content of the file.
|
||||
* Returns: { id: number, url: string, short_url: string } - Details of the created upload record.
|
||||
*
|
||||
* 5. chain
|
||||
* Controls the execution flow.
|
||||
*
|
||||
* chain.setCustomRaw(raw): Sets the final raw content of the bot's post and immediately
|
||||
* stops the tool execution chain. Useful for tools that directly
|
||||
* generate the full response content (e.g., image generation tools attaching the image markdown).
|
||||
* Parameters:
|
||||
* raw (string): The raw Markdown content for the post.
|
||||
* Returns: void
|
||||
*
|
||||
* 6. discourse
|
||||
* Interacts with Discourse specific features. Access is generally performed as the SystemUser.
|
||||
*
|
||||
* discourse.search(params): Performs a Discourse search.
|
||||
* Parameters:
|
||||
* params (Object): Search parameters (e.g., { search_query: "keyword", with_private: true, max_results: 10 }).
|
||||
* `with_private: true` searches across all posts visible to the SystemUser. `result_style: 'detailed'` is used by default.
|
||||
* Returns: Object (Discourse search results structure, includes posts, topics, users etc.)
|
||||
*
|
||||
* discourse.getPost(post_id): Retrieves details for a specific post.
|
||||
* Parameters:
|
||||
* post_id (number): The ID of the post.
|
||||
* Returns: Object (Post details including `raw`, nested `topic` object with ListableTopicSerializer structure) or null if not found/accessible.
|
||||
*
|
||||
* discourse.getTopic(topic_id): Retrieves details for a specific topic.
|
||||
* Parameters:
|
||||
* topic_id (number): The ID of the topic.
|
||||
* Returns: Object (Topic details using ListableTopicSerializer structure) or null if not found/accessible.
|
||||
*
|
||||
* discourse.getUser(user_id_or_username): Retrieves details for a specific user.
|
||||
* Parameters:
|
||||
* user_id_or_username (number | string): The ID or username of the user.
|
||||
* Returns: Object (User details using UserSerializer structure) or null if not found.
|
||||
*
|
||||
* discourse.getPersona(name): Gets an object representing another AI Persona configured on the site.
|
||||
* Parameters:
|
||||
* name (string): The name of the target persona.
|
||||
* Returns: Object { respondTo: function(params) } or null if persona not found.
|
||||
* respondTo(params): Instructs the target persona to generate a response within the current context (e.g., replying to the same post or chat message).
|
||||
* Parameters:
|
||||
* params (Object, optional): { instructions: string, whisper: boolean }
|
||||
* Returns: { success: boolean, post_id?: number, post_number?: number, message_id?: number } or { error: string }
|
||||
*
|
||||
* discourse.createChatMessage(params): Creates a new message in a Discourse Chat channel.
|
||||
* Parameters:
|
||||
* params (Object): { channel_name: string, username: string, message: string }
|
||||
* `channel_name` can be the channel name or slug.
|
||||
* `username` specifies the user who should appear as the sender. The user must exist.
|
||||
* The sending user must have permission to post in the channel.
|
||||
* Returns: { success: boolean, message_id?: number, message?: string, created_at?: string } or { error: string }
|
||||
*
|
||||
* 7. context
|
||||
* An object containing information about the environment where the tool is being run.
|
||||
* Available properties depend on the invocation context, but may include:
|
||||
* post_id (number): ID of the post triggering the tool (if in a Post context).
|
||||
* topic_id (number): ID of the topic (if in a Post context).
|
||||
* private_message (boolean): Whether the context is a private message (in Post context).
|
||||
* message_id (number): ID of the chat message triggering the tool (if in Chat context).
|
||||
* channel_id (number): ID of the chat channel (if in Chat context).
|
||||
* user (Object): Details of the user invoking the tool/persona (structure may vary, often null or SystemUser details unless explicitly passed).
|
||||
* participants (string): Comma-separated list of usernames in a PM (if applicable).
|
||||
* // ... other potential context-specific properties added by the calling environment.
|
||||
*
|
||||
* Constraints
|
||||
*
|
||||
* Execution Time: ≤ 2000ms (default timeout in milliseconds) - This timer *pauses* during external HTTP requests or LLM calls initiated via `http.*` or `llm.generate`, but applies to the script's own processing time.
|
||||
* Memory: ≤ 10MB (V8 heap limit)
|
||||
* Stack Depth: ≤ 20 (Marshal stack depth limit for Ruby interop)
|
||||
* HTTP Requests: ≤ 20 per execution
|
||||
* Exceeding limits will result in errors or termination (e.g., timeout error, out-of-memory error, TooManyRequestsError).
|
||||
*
|
||||
* Security
|
||||
*
|
||||
* Sandboxed Environment: The script runs in a restricted V8 JavaScript environment (via MiniRacer).
|
||||
* No direct access to browser or environment, browser globals (like `window` or `document`), or the host system's file system.
|
||||
* Network requests are proxied through the Discourse backend, not made directly from the sandbox.
|
||||
*/
|
||||
JS
|
||||
end
|
||||
|
||||
|
@ -75,9 +75,9 @@ module DiscourseAi
|
||||
def force_tool_if_needed(prompt, context)
|
||||
return if prompt.tool_choice == :none
|
||||
|
||||
context[:chosen_tools] ||= []
|
||||
context.chosen_tools ||= []
|
||||
forced_tools = persona.force_tool_use.map { |tool| tool.name }
|
||||
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
|
||||
force_tool = forced_tools.find { |name| !context.chosen_tools.include?(name) }
|
||||
|
||||
if force_tool && persona.forced_tool_count > 0
|
||||
user_turns = prompt.messages.select { |m| m[:type] == :user }.length
|
||||
@ -85,7 +85,7 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
if force_tool
|
||||
context[:chosen_tools] << force_tool
|
||||
context.chosen_tools << force_tool
|
||||
prompt.tool_choice = force_tool
|
||||
else
|
||||
prompt.tool_choice = nil
|
||||
@ -93,6 +93,9 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def reply(context, &update_blk)
|
||||
unless context.is_a?(BotContext)
|
||||
raise ArgumentError, "context must be an instance of BotContext"
|
||||
end
|
||||
llm = DiscourseAi::Completions::Llm.proxy(model)
|
||||
prompt = persona.craft_prompt(context, llm: llm)
|
||||
|
||||
@ -100,7 +103,7 @@ module DiscourseAi
|
||||
ongoing_chain = true
|
||||
raw_context = []
|
||||
|
||||
user = context[:user]
|
||||
user = context.user
|
||||
|
||||
llm_kwargs = { user: user }
|
||||
llm_kwargs[:temperature] = persona.temperature if persona.temperature
|
||||
@ -277,27 +280,15 @@ module DiscourseAi
|
||||
name: tool.name,
|
||||
}
|
||||
|
||||
if tool.standalone?
|
||||
standalone_context =
|
||||
context.dup.merge(
|
||||
conversation_context: [
|
||||
context[:conversation_context].last,
|
||||
tool_call_message,
|
||||
tool_message,
|
||||
],
|
||||
)
|
||||
prompt = persona.craft_prompt(standalone_context)
|
||||
else
|
||||
prompt.push(**tool_call_message)
|
||||
prompt.push(**tool_message)
|
||||
end
|
||||
prompt.push(**tool_call_message)
|
||||
prompt.push(**tool_message)
|
||||
|
||||
raw_context << [tool_call_message[:content], tool_call_id, "tool_call", tool.name]
|
||||
raw_context << [invocation_result_json, tool_call_id, "tool", tool.name]
|
||||
end
|
||||
|
||||
def invoke_tool(tool, llm, cancel, context, &update_blk)
|
||||
show_placeholder = !context[:skip_tool_details] && !tool.class.allow_partial_tool_calls?
|
||||
show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls?
|
||||
|
||||
update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
|
||||
|
||||
|
107
lib/ai_bot/bot_context.rb
Normal file
107
lib/ai_bot/bot_context.rb
Normal file
@ -0,0 +1,107 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module AiBot
|
||||
class BotContext
|
||||
attr_accessor :messages,
|
||||
:topic_id,
|
||||
:post_id,
|
||||
:private_message,
|
||||
:custom_instructions,
|
||||
:user,
|
||||
:skip_tool_details,
|
||||
:participants,
|
||||
:chosen_tools,
|
||||
:message_id,
|
||||
:channel_id,
|
||||
:context_post_ids
|
||||
|
||||
def initialize(
|
||||
post: nil,
|
||||
participants: nil,
|
||||
user: nil,
|
||||
skip_tool_details: nil,
|
||||
messages: [],
|
||||
custom_instructions: nil,
|
||||
site_url: nil,
|
||||
site_title: nil,
|
||||
site_description: nil,
|
||||
time: nil,
|
||||
message_id: nil,
|
||||
channel_id: nil,
|
||||
context_post_ids: nil
|
||||
)
|
||||
@participants = participants
|
||||
@user = user
|
||||
@skip_tool_details = skip_tool_details
|
||||
@messages = messages
|
||||
@custom_instructions = custom_instructions
|
||||
|
||||
@message_id = message_id
|
||||
@channel_id = channel_id
|
||||
@context_post_ids = context_post_ids
|
||||
|
||||
@site_url = site_url
|
||||
@site_title = site_title
|
||||
@site_description = site_description
|
||||
@time = time
|
||||
|
||||
if post
|
||||
@post_id = post.id
|
||||
@topic_id = post.topic_id
|
||||
@private_message = post.topic.private_message?
|
||||
@participants ||= post.topic.allowed_users.map(&:username).join(", ") if @private_message
|
||||
@user = post.user
|
||||
end
|
||||
end
|
||||
|
||||
# these are strings that can be safely interpolated into templates
|
||||
TEMPLATE_PARAMS = %w[time site_url site_title site_description participants]
|
||||
|
||||
def lookup_template_param(key)
|
||||
public_send(key.to_sym) if TEMPLATE_PARAMS.include?(key)
|
||||
end
|
||||
|
||||
def time
|
||||
@time ||= Time.zone.now
|
||||
end
|
||||
|
||||
def site_url
|
||||
@site_url ||= Discourse.base_url
|
||||
end
|
||||
|
||||
def site_title
|
||||
@site_title ||= SiteSetting.title
|
||||
end
|
||||
|
||||
def site_description
|
||||
@site_description ||= SiteSetting.site_description
|
||||
end
|
||||
|
||||
def private_message?
|
||||
@private_message
|
||||
end
|
||||
|
||||
def to_json
|
||||
{
|
||||
messages: @messages,
|
||||
topic_id: @topic_id,
|
||||
post_id: @post_id,
|
||||
private_message: @private_message,
|
||||
custom_instructions: @custom_instructions,
|
||||
username: @user&.username,
|
||||
user_id: @user&.id,
|
||||
participants: @participants,
|
||||
chosen_tools: @chosen_tools,
|
||||
message_id: @message_id,
|
||||
channel_id: @channel_id,
|
||||
context_post_ids: @context_post_ids,
|
||||
site_url: @site_url,
|
||||
site_title: @site_title,
|
||||
site_description: @site_description,
|
||||
skip_tool_details: @skip_tool_details,
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -163,7 +163,7 @@ module DiscourseAi
|
||||
def craft_prompt(context, llm: nil)
|
||||
system_insts =
|
||||
system_prompt.gsub(/\{(\w+)\}/) do |match|
|
||||
found = context[match[1..-2].to_sym]
|
||||
found = context.lookup_template_param(match[1..-2])
|
||||
found.nil? ? match : found.to_s
|
||||
end
|
||||
|
||||
@ -180,16 +180,16 @@ module DiscourseAi
|
||||
)
|
||||
end
|
||||
|
||||
if context[:custom_instructions].present?
|
||||
if context.custom_instructions.present?
|
||||
prompt_insts << "\n"
|
||||
prompt_insts << context[:custom_instructions]
|
||||
prompt_insts << context.custom_instructions
|
||||
end
|
||||
|
||||
fragments_guidance =
|
||||
rag_fragments_prompt(
|
||||
context[:conversation_context].to_a,
|
||||
context.messages,
|
||||
llm: question_consolidator_llm,
|
||||
user: context[:user],
|
||||
user: context.user,
|
||||
)&.strip
|
||||
|
||||
prompt_insts << fragments_guidance if fragments_guidance.present?
|
||||
@ -197,9 +197,9 @@ module DiscourseAi
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
prompt_insts,
|
||||
messages: context[:conversation_context].to_a,
|
||||
topic_id: context[:topic_id],
|
||||
post_id: context[:post_id],
|
||||
messages: context.messages,
|
||||
topic_id: context.topic_id,
|
||||
post_id: context.post_id,
|
||||
)
|
||||
|
||||
prompt.max_pixels = self.class.vision_max_pixels if self.class.vision_enabled
|
||||
|
@ -227,90 +227,17 @@ module DiscourseAi
|
||||
schedule_bot_reply(post) if can_attach?(post)
|
||||
end
|
||||
|
||||
def conversation_context(post, style: nil)
|
||||
# Pay attention to the `post_number <= ?` here.
|
||||
# We want to inject the last post as context because they are translated differently.
|
||||
|
||||
# also setting default to 40, allowing huge contexts costs lots of tokens
|
||||
max_posts = 40
|
||||
if bot.persona.class.respond_to?(:max_context_posts)
|
||||
max_posts = bot.persona.class.max_context_posts || 40
|
||||
end
|
||||
|
||||
post_types = [Post.types[:regular]]
|
||||
post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper]
|
||||
|
||||
context =
|
||||
post
|
||||
.topic
|
||||
.posts
|
||||
.joins(:user)
|
||||
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
|
||||
.where("post_number <= ?", post.post_number)
|
||||
.order("post_number desc")
|
||||
.where("post_type in (?)", post_types)
|
||||
.limit(max_posts)
|
||||
.pluck(
|
||||
"posts.raw",
|
||||
"users.username",
|
||||
"post_custom_prompts.custom_prompt",
|
||||
"(
|
||||
SELECT array_agg(ref.upload_id)
|
||||
FROM upload_references ref
|
||||
WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
|
||||
) as upload_ids",
|
||||
)
|
||||
|
||||
builder = DiscourseAi::Completions::PromptMessagesBuilder.new
|
||||
builder.topic = post.topic
|
||||
|
||||
context.reverse_each do |raw, username, custom_prompt, upload_ids|
|
||||
custom_prompt_translation =
|
||||
Proc.new do |message|
|
||||
# We can't keep backwards-compatibility for stored functions.
|
||||
# Tool syntax requires a tool_call_id which we don't have.
|
||||
if message[2] != "function"
|
||||
custom_context = {
|
||||
content: message[0],
|
||||
type: message[2].present? ? message[2].to_sym : :model,
|
||||
}
|
||||
|
||||
custom_context[:id] = message[1] if custom_context[:type] != :model
|
||||
custom_context[:name] = message[3] if message[3]
|
||||
|
||||
thinking = message[4]
|
||||
custom_context[:thinking] = thinking if thinking
|
||||
|
||||
builder.push(**custom_context)
|
||||
end
|
||||
end
|
||||
|
||||
if custom_prompt.present?
|
||||
custom_prompt.each(&custom_prompt_translation)
|
||||
else
|
||||
context = {
|
||||
content: raw,
|
||||
type: (available_bot_usernames.include?(username) ? :model : :user),
|
||||
}
|
||||
|
||||
context[:id] = username if context[:type] == :user
|
||||
|
||||
if upload_ids.present? && context[:type] == :user && bot.persona.class.vision_enabled
|
||||
context[:upload_ids] = upload_ids.compact
|
||||
end
|
||||
|
||||
builder.push(**context)
|
||||
end
|
||||
end
|
||||
|
||||
builder.to_a(style: style || (post.topic.private_message? ? :bot : :topic))
|
||||
end
|
||||
|
||||
def title_playground(post, user)
|
||||
context = conversation_context(post)
|
||||
messages =
|
||||
DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post(
|
||||
post,
|
||||
max_posts: 5,
|
||||
bot_usernames: available_bot_usernames,
|
||||
include_uploads: bot.persona.class.vision_enabled,
|
||||
)
|
||||
|
||||
bot
|
||||
.get_updated_title(context, post, user)
|
||||
.get_updated_title(messages, post, user)
|
||||
.tap do |new_title|
|
||||
PostRevisor.new(post.topic.first_post, post.topic).revise!(
|
||||
bot.bot_user,
|
||||
@ -326,83 +253,6 @@ module DiscourseAi
|
||||
)
|
||||
end
|
||||
|
||||
def chat_context(message, channel, persona_user, context_post_ids)
|
||||
has_vision = bot.persona.class.vision_enabled
|
||||
include_thread_titles = !channel.direct_message_channel? && !message.thread_id
|
||||
|
||||
current_id = message.id
|
||||
if !channel.direct_message_channel?
|
||||
# we are interacting via mentions ... strip mention
|
||||
instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
|
||||
end
|
||||
|
||||
messages = nil
|
||||
|
||||
max_messages = 40
|
||||
if bot.persona.class.respond_to?(:max_context_posts)
|
||||
max_messages = bot.persona.class.max_context_posts || 40
|
||||
end
|
||||
|
||||
if !message.thread_id && channel.direct_message_channel?
|
||||
messages = [message]
|
||||
elsif !channel.direct_message_channel? && !message.thread_id
|
||||
messages =
|
||||
Chat::Message
|
||||
.joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
|
||||
.where(chat_channel_id: channel.id)
|
||||
.where(
|
||||
"chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
|
||||
)
|
||||
.order(id: :desc)
|
||||
.limit(max_messages)
|
||||
.to_a
|
||||
.reverse
|
||||
end
|
||||
|
||||
messages ||=
|
||||
ChatSDK::Thread.last_messages(
|
||||
thread_id: message.thread_id,
|
||||
guardian: Discourse.system_user.guardian,
|
||||
page_size: max_messages,
|
||||
)
|
||||
|
||||
builder = DiscourseAi::Completions::PromptMessagesBuilder.new
|
||||
|
||||
guardian = Guardian.new(message.user)
|
||||
if context_post_ids
|
||||
builder.set_chat_context_posts(context_post_ids, guardian, include_uploads: has_vision)
|
||||
end
|
||||
|
||||
messages.each do |m|
|
||||
# restore stripped message
|
||||
m.message = instruction_message if m.id == current_id && instruction_message
|
||||
|
||||
if available_bot_user_ids.include?(m.user_id)
|
||||
builder.push(type: :model, content: m.message)
|
||||
else
|
||||
upload_ids = nil
|
||||
upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present?
|
||||
mapped_message = m.message
|
||||
|
||||
thread_title = nil
|
||||
thread_title = m.thread&.title if include_thread_titles && m.thread_id
|
||||
mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
|
||||
|
||||
builder.push(
|
||||
type: :user,
|
||||
content: mapped_message,
|
||||
name: m.user.username,
|
||||
upload_ids: upload_ids,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
builder.to_a(
|
||||
limit: max_messages,
|
||||
style: channel.direct_message_channel? ? :chat_with_context : :chat,
|
||||
)
|
||||
end
|
||||
|
||||
def reply_to_chat_message(message, channel, context_post_ids)
|
||||
persona_user = User.find(bot.persona.class.user_id)
|
||||
|
||||
@ -410,10 +260,32 @@ module DiscourseAi
|
||||
|
||||
context_post_ids = nil if !channel.direct_message_channel?
|
||||
|
||||
max_chat_messages = 40
|
||||
if bot.persona.class.respond_to?(:max_context_posts)
|
||||
max_chat_messages = bot.persona.class.max_context_posts || 40
|
||||
end
|
||||
|
||||
if !channel.direct_message_channel?
|
||||
# we are interacting via mentions ... strip mention
|
||||
instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
|
||||
end
|
||||
|
||||
context =
|
||||
get_context(
|
||||
participants: participants.join(", "),
|
||||
conversation_context: chat_context(message, channel, persona_user, context_post_ids),
|
||||
BotContext.new(
|
||||
participants: participants,
|
||||
message_id: message.id,
|
||||
channel_id: channel.id,
|
||||
context_post_ids: context_post_ids,
|
||||
messages:
|
||||
DiscourseAi::Completions::PromptMessagesBuilder.messages_from_chat(
|
||||
message,
|
||||
channel: channel,
|
||||
context_post_ids: context_post_ids,
|
||||
include_uploads: bot.persona.class.vision_enabled,
|
||||
max_messages: max_chat_messages,
|
||||
bot_user_ids: available_bot_user_ids,
|
||||
instruction_message: instruction_message,
|
||||
),
|
||||
user: message.user,
|
||||
skip_tool_details: true,
|
||||
)
|
||||
@ -460,22 +332,6 @@ module DiscourseAi
|
||||
reply
|
||||
end
|
||||
|
||||
def get_context(participants:, conversation_context:, user:, skip_tool_details: nil)
|
||||
result = {
|
||||
site_url: Discourse.base_url,
|
||||
site_title: SiteSetting.title,
|
||||
site_description: SiteSetting.site_description,
|
||||
time: Time.zone.now,
|
||||
participants: participants,
|
||||
conversation_context: conversation_context,
|
||||
user: user,
|
||||
}
|
||||
|
||||
result[:skip_tool_details] = true if skip_tool_details
|
||||
|
||||
result
|
||||
end
|
||||
|
||||
def reply_to(
|
||||
post,
|
||||
custom_instructions: nil,
|
||||
@ -509,16 +365,25 @@ module DiscourseAi
|
||||
end
|
||||
)
|
||||
|
||||
# safeguard
|
||||
max_context_posts = 40
|
||||
if bot.persona.class.respond_to?(:max_context_posts)
|
||||
max_context_posts = bot.persona.class.max_context_posts || 40
|
||||
end
|
||||
|
||||
context =
|
||||
get_context(
|
||||
participants: post.topic.allowed_users.map(&:username).join(", "),
|
||||
conversation_context: conversation_context(post, style: context_style),
|
||||
user: post.user,
|
||||
BotContext.new(
|
||||
post: post,
|
||||
custom_instructions: custom_instructions,
|
||||
messages:
|
||||
DiscourseAi::Completions::PromptMessagesBuilder.messages_from_post(
|
||||
post,
|
||||
style: context_style,
|
||||
max_posts: max_context_posts,
|
||||
include_uploads: bot.persona.class.vision_enabled,
|
||||
bot_usernames: available_bot_usernames,
|
||||
),
|
||||
)
|
||||
context[:post_id] = post.id
|
||||
context[:topic_id] = post.topic_id
|
||||
context[:private_message] = post.topic.private_message?
|
||||
context[:custom_instructions] = custom_instructions
|
||||
|
||||
reply_user = bot.bot_user
|
||||
if bot.persona.class.respond_to?(:user_id)
|
||||
@ -562,7 +427,7 @@ module DiscourseAi
|
||||
Discourse.redis.setex(redis_stream_key, 60, 1)
|
||||
end
|
||||
|
||||
context[:skip_tool_details] ||= !bot.persona.class.tool_details
|
||||
context.skip_tool_details ||= !bot.persona.class.tool_details
|
||||
|
||||
post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply
|
||||
|
||||
|
@ -13,7 +13,13 @@ module DiscourseAi
|
||||
MARSHAL_STACK_DEPTH = 20
|
||||
MAX_HTTP_REQUESTS = 20
|
||||
|
||||
def initialize(parameters:, llm:, bot_user:, context: {}, tool:, timeout: nil)
|
||||
def initialize(parameters:, llm:, bot_user:, context: nil, tool:, timeout: nil)
|
||||
if context && !context.is_a?(DiscourseAi::AiBot::BotContext)
|
||||
raise ArgumentError, "context must be a BotContext object"
|
||||
end
|
||||
|
||||
context ||= DiscourseAi::AiBot::BotContext.new
|
||||
|
||||
@parameters = parameters
|
||||
@llm = llm
|
||||
@bot_user = bot_user
|
||||
@ -99,7 +105,7 @@ module DiscourseAi
|
||||
},
|
||||
};
|
||||
|
||||
const context = #{JSON.generate(@context)};
|
||||
const context = #{JSON.generate(@context.to_json)};
|
||||
|
||||
function details() { return ""; };
|
||||
JS
|
||||
@ -240,13 +246,13 @@ module DiscourseAi
|
||||
def llm_user
|
||||
@llm_user ||=
|
||||
begin
|
||||
@context[:llm_user] || post&.user || @bot_user
|
||||
post&.user || @bot_user
|
||||
end
|
||||
end
|
||||
|
||||
def post
|
||||
return @post if defined?(@post)
|
||||
post_id = @context[:post_id]
|
||||
post_id = @context.post_id
|
||||
@post = post_id && Post.find_by(id: post_id)
|
||||
end
|
||||
|
||||
@ -336,8 +342,8 @@ module DiscourseAi
|
||||
bot = DiscourseAi::AiBot::Bot.as(@bot_user || persona.user, persona: persona)
|
||||
playground = DiscourseAi::AiBot::Playground.new(bot)
|
||||
|
||||
if @context[:post_id]
|
||||
post = Post.find_by(id: @context[:post_id])
|
||||
if @context.post_id
|
||||
post = Post.find_by(id: @context.post_id)
|
||||
return { error: "Post not found" } if post.nil?
|
||||
|
||||
reply_post =
|
||||
@ -354,13 +360,13 @@ module DiscourseAi
|
||||
else
|
||||
return { error: "Failed to create reply" }
|
||||
end
|
||||
elsif @context[:message_id] && @context[:channel_id]
|
||||
message = Chat::Message.find_by(id: @context[:message_id])
|
||||
channel = Chat::Channel.find_by(id: @context[:channel_id])
|
||||
elsif @context.message_id && @context.channel_id
|
||||
message = Chat::Message.find_by(id: @context.message_id)
|
||||
channel = Chat::Channel.find_by(id: @context.channel_id)
|
||||
return { error: "Message or channel not found" } if message.nil? || channel.nil?
|
||||
|
||||
reply =
|
||||
playground.reply_to_chat_message(message, channel, @context[:context_post_ids])
|
||||
playground.reply_to_chat_message(message, channel, @context.context_post_ids)
|
||||
|
||||
if reply
|
||||
return { success: true, message_id: reply.id }
|
||||
@ -457,7 +463,7 @@ module DiscourseAi
|
||||
UploadCreator.new(
|
||||
file,
|
||||
filename,
|
||||
for_private_message: @context[:private_message],
|
||||
for_private_message: @context.private_message,
|
||||
).create_for(@bot_user.id)
|
||||
|
||||
{ id: upload.id, short_url: upload.short_url, url: upload.url }
|
||||
|
@ -108,7 +108,7 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def invoke
|
||||
post = Post.find_by(id: context[:post_id])
|
||||
post = Post.find_by(id: context.post_id)
|
||||
return error_response("No post context found") unless post
|
||||
|
||||
partial_response = +""
|
||||
|
@ -111,7 +111,7 @@ module DiscourseAi
|
||||
UploadCreator.new(
|
||||
file,
|
||||
"image.png",
|
||||
for_private_message: context[:private_message],
|
||||
for_private_message: context.private_message?,
|
||||
).create_for(bot_user.id),
|
||||
}
|
||||
end
|
||||
|
@ -131,7 +131,7 @@ module DiscourseAi
|
||||
UploadCreator.new(
|
||||
file,
|
||||
"image.png",
|
||||
for_private_message: context[:private_message],
|
||||
for_private_message: context.private_message,
|
||||
).create_for(bot_user.id),
|
||||
seed: image[:seed],
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ module DiscourseAi
|
||||
|
||||
def invoke
|
||||
not_found = { topic_id: topic_id, description: "Topic not found" }
|
||||
guardian = Guardian.new(context[:user]) if options[:read_private] && context[:user]
|
||||
guardian = Guardian.new(context.user) if options[:read_private] && context.user
|
||||
guardian ||= Guardian.new
|
||||
|
||||
@title = ""
|
||||
|
@ -59,7 +59,7 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def post
|
||||
@post ||= Post.find_by(id: context[:post_id])
|
||||
@post ||= Post.find_by(id: context.post_id)
|
||||
end
|
||||
|
||||
def handle_discourse_artifact(uri)
|
||||
|
@ -130,7 +130,7 @@ module DiscourseAi
|
||||
after: parameters[:after],
|
||||
status: parameters[:status],
|
||||
max_results: max_results,
|
||||
current_user: options[:search_private] ? context[:user] : nil,
|
||||
current_user: options[:search_private] ? context.user : nil,
|
||||
)
|
||||
|
||||
@last_num_results = results[:rows]&.length || 0
|
||||
|
@ -40,10 +40,6 @@ module DiscourseAi
|
||||
false
|
||||
end
|
||||
|
||||
def standalone?
|
||||
true
|
||||
end
|
||||
|
||||
def custom_raw
|
||||
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
|
||||
end
|
||||
|
@ -56,14 +56,17 @@ module DiscourseAi
|
||||
persona_options: {},
|
||||
bot_user:,
|
||||
llm:,
|
||||
context: {}
|
||||
context: nil
|
||||
)
|
||||
@parameters = parameters
|
||||
@tool_call_id = tool_call_id
|
||||
@persona_options = persona_options
|
||||
@bot_user = bot_user
|
||||
@llm = llm
|
||||
@context = context
|
||||
@context = context.nil? ? DiscourseAi::AiBot::BotContext.new(messages: []) : context
|
||||
if !@context.is_a?(DiscourseAi::AiBot::BotContext)
|
||||
raise ArgumentError, "context must be a DiscourseAi::AiBot::Context"
|
||||
end
|
||||
end
|
||||
|
||||
def name
|
||||
@ -108,10 +111,6 @@ module DiscourseAi
|
||||
true
|
||||
end
|
||||
|
||||
def standalone?
|
||||
false
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def fetch_default_branch(repo)
|
||||
|
@ -39,7 +39,7 @@ module DiscourseAi
|
||||
def self.inject_prompt(prompt:, context:, persona:)
|
||||
return if persona.options["do_not_echo_artifact"].to_s == "true"
|
||||
# we inject the current artifact content into the last user message
|
||||
if topic_id = context[:topic_id]
|
||||
if topic_id = context.topic_id
|
||||
posts = Post.where(topic_id: topic_id)
|
||||
artifact = AiArtifact.order("id desc").where(post: posts).first
|
||||
if artifact
|
||||
@ -113,7 +113,7 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def invoke
|
||||
post = Post.find_by(id: context[:post_id])
|
||||
post = Post.find_by(id: context.post_id)
|
||||
return error_response("No post context found") unless post
|
||||
|
||||
artifact = AiArtifact.find_by(id: parameters[:artifact_id])
|
||||
|
@ -188,9 +188,10 @@ module DiscourseAi
|
||||
messages: [
|
||||
{
|
||||
type: :user,
|
||||
content:
|
||||
content: [
|
||||
"Describe this image in a single sentence#{custom_locale_instructions(user)}",
|
||||
upload_ids: [upload.id],
|
||||
{ upload_id: upload.id },
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
@ -187,7 +187,10 @@ module DiscourseAi
|
||||
prompt = DiscourseAi::Completions::Prompt.new(system_prompt)
|
||||
args = { type: :user, content: context }
|
||||
upload_ids = post.upload_ids
|
||||
args[:upload_ids] = upload_ids.take(3) if upload_ids.present?
|
||||
if upload_ids.present?
|
||||
args[:content] = [args[:content]]
|
||||
upload_ids.take(3).each { |upload_id| args[:content] << { upload_id: upload_id } }
|
||||
end
|
||||
prompt.push(**args)
|
||||
prompt
|
||||
end
|
||||
|
@ -7,11 +7,7 @@ module DiscourseAi
|
||||
return if !tool
|
||||
return if !tool.parameters.blank?
|
||||
|
||||
context = {
|
||||
post_id: post.id,
|
||||
automation_id: automation&.id,
|
||||
automation_name: automation&.name,
|
||||
}
|
||||
context = DiscourseAi::AiBot::BotContext.new(post: post)
|
||||
|
||||
runner = tool.runner({}, llm: nil, bot_user: Discourse.system_user, context: context)
|
||||
runner.invoke
|
||||
|
@ -42,7 +42,12 @@ module DiscourseAi
|
||||
|
||||
content = llm.tokenizer.truncate(content, max_post_tokens) if max_post_tokens.present?
|
||||
|
||||
prompt.push(type: :user, content: content, upload_ids: post.upload_ids)
|
||||
if post.upload_ids.present?
|
||||
content = [content]
|
||||
content.concat(post.upload_ids.map { |upload_id| { upload_id: upload_id } })
|
||||
end
|
||||
|
||||
prompt.push(type: :user, content: content)
|
||||
|
||||
result = nil
|
||||
|
||||
|
@ -17,13 +17,13 @@ module DiscourseAi
|
||||
llm_model.provider == "open_ai" || llm_model.provider == "azure"
|
||||
end
|
||||
|
||||
def translate
|
||||
def embed_user_ids?
|
||||
return @embed_user_ids if defined?(@embed_user_ids)
|
||||
|
||||
@embed_user_ids =
|
||||
prompt.messages.any? do |m|
|
||||
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
|
||||
end
|
||||
|
||||
super
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
@ -102,35 +102,47 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
user_message = { role: "user", content: msg[:content] }
|
||||
content_array = []
|
||||
|
||||
user_message = { role: "user" }
|
||||
|
||||
if msg[:id]
|
||||
if @embed_user_ids
|
||||
user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
|
||||
if embed_user_ids?
|
||||
content_array << "#{msg[:id]}: "
|
||||
else
|
||||
user_message[:name] = msg[:id]
|
||||
end
|
||||
end
|
||||
|
||||
user_message[:content] = inline_images(user_message[:content], msg) if vision_support?
|
||||
content_array << msg[:content]
|
||||
|
||||
content_array =
|
||||
to_encoded_content_array(
|
||||
content: content_array.flatten,
|
||||
image_encoder: ->(details) { image_node(details) },
|
||||
text_encoder: ->(text) { { type: "text", text: text } },
|
||||
allow_vision: vision_support?,
|
||||
)
|
||||
|
||||
user_message[:content] = no_array_if_only_text(content_array)
|
||||
user_message
|
||||
end
|
||||
|
||||
def inline_images(content, message)
|
||||
encoded_uploads = prompt.encoded_uploads(message)
|
||||
return content if encoded_uploads.blank?
|
||||
def no_array_if_only_text(content_array)
|
||||
if content_array.size == 1 && content_array.first[:type] == "text"
|
||||
content_array.first[:text]
|
||||
else
|
||||
content_array
|
||||
end
|
||||
end
|
||||
|
||||
content_w_imgs =
|
||||
encoded_uploads.reduce([]) do |memo, details|
|
||||
memo << {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
content_w_imgs << { type: "text", text: message[:content] }
|
||||
def image_node(details)
|
||||
{
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
def per_message_overhead
|
||||
|
@ -87,9 +87,9 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def model_msg(msg)
|
||||
if msg[:thinking] || msg[:redacted_thinking_signature]
|
||||
content_array = []
|
||||
content_array = []
|
||||
|
||||
if msg[:thinking] || msg[:redacted_thinking_signature]
|
||||
if msg[:thinking]
|
||||
content_array << {
|
||||
type: "thinking",
|
||||
@ -104,13 +104,19 @@ module DiscourseAi
|
||||
data: msg[:redacted_thinking_signature],
|
||||
}
|
||||
end
|
||||
|
||||
content_array << { type: "text", text: msg[:content] }
|
||||
|
||||
{ role: "assistant", content: content_array }
|
||||
else
|
||||
{ role: "assistant", content: msg[:content] }
|
||||
end
|
||||
|
||||
# other encoder is used to pass through thinking
|
||||
content_array =
|
||||
to_encoded_content_array(
|
||||
content: [content_array, msg[:content]].flatten,
|
||||
image_encoder: ->(details) {},
|
||||
text_encoder: ->(text) { { type: "text", text: text } },
|
||||
other_encoder: ->(details) { details },
|
||||
allow_vision: false,
|
||||
)
|
||||
|
||||
{ role: "assistant", content: no_array_if_only_text(content_array) }
|
||||
end
|
||||
|
||||
def system_msg(msg)
|
||||
@ -124,31 +130,39 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
content = +""
|
||||
content << "#{msg[:id]}: " if msg[:id]
|
||||
content << msg[:content]
|
||||
content = inline_images(content, msg) if vision_support?
|
||||
content_array = []
|
||||
content_array << "#{msg[:id]}: " if msg[:id]
|
||||
content_array.concat([msg[:content]].flatten)
|
||||
|
||||
{ role: "user", content: content }
|
||||
content_array =
|
||||
to_encoded_content_array(
|
||||
content: content_array,
|
||||
image_encoder: ->(details) { image_node(details) },
|
||||
text_encoder: ->(text) { { type: "text", text: text } },
|
||||
allow_vision: vision_support?,
|
||||
)
|
||||
|
||||
{ role: "user", content: no_array_if_only_text(content_array) }
|
||||
end
|
||||
|
||||
def inline_images(content, message)
|
||||
encoded_uploads = prompt.encoded_uploads(message)
|
||||
return content if encoded_uploads.blank?
|
||||
# keeping our payload as backward compatible as possible
|
||||
def no_array_if_only_text(content_array)
|
||||
if content_array.length == 1 && content_array.first[:type] == "text"
|
||||
content_array.first[:text]
|
||||
else
|
||||
content_array
|
||||
end
|
||||
end
|
||||
|
||||
content_w_imgs =
|
||||
encoded_uploads.reduce([]) do |memo, details|
|
||||
memo << {
|
||||
source: {
|
||||
type: "base64",
|
||||
data: details[:base64],
|
||||
media_type: details[:mime_type],
|
||||
},
|
||||
type: "image",
|
||||
}
|
||||
end
|
||||
|
||||
content_w_imgs << { type: "text", text: content }
|
||||
def image_node(details)
|
||||
{
|
||||
source: {
|
||||
type: "base64",
|
||||
data: details[:base64],
|
||||
media_type: details[:mime_type],
|
||||
},
|
||||
type: "image",
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -110,9 +110,9 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
user_message = { role: "USER", message: msg[:content] }
|
||||
user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
|
||||
|
||||
content = prompt.text_only(msg)
|
||||
user_message = { role: "USER", message: content }
|
||||
user_message[:message] = "#{msg[:id]}: #{content}" if msg[:id]
|
||||
user_message
|
||||
end
|
||||
end
|
||||
|
@ -227,6 +227,38 @@ module DiscourseAi
|
||||
msg = msg.merge(content: new_content)
|
||||
user_msg(msg)
|
||||
end
|
||||
|
||||
def to_encoded_content_array(
|
||||
content:,
|
||||
image_encoder:,
|
||||
text_encoder:,
|
||||
other_encoder: nil,
|
||||
allow_vision:
|
||||
)
|
||||
content = [content] if !content.is_a?(Array)
|
||||
|
||||
current_string = +""
|
||||
result = []
|
||||
|
||||
content.each do |c|
|
||||
if c.is_a?(String)
|
||||
current_string << c
|
||||
elsif c.is_a?(Hash) && c.key?(:upload_id) && allow_vision
|
||||
if !current_string.empty?
|
||||
result << text_encoder.call(current_string)
|
||||
current_string = +""
|
||||
end
|
||||
encoded = prompt.encode_upload(c[:upload_id])
|
||||
result << image_encoder.call(encoded) if encoded
|
||||
elsif other_encoder
|
||||
encoded = other_encoder.call(c)
|
||||
result << encoded if encoded
|
||||
end
|
||||
end
|
||||
|
||||
result << text_encoder.call(current_string) if !current_string.empty?
|
||||
result
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -106,28 +106,29 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
if beta_api?
|
||||
# support new format with multiple parts
|
||||
result = { role: "user", parts: [{ text: msg[:content] }] }
|
||||
return result unless vision_support?
|
||||
content_array = []
|
||||
content_array << "#{msg[:id]}: " if msg[:id]
|
||||
|
||||
upload_parts = uploaded_parts(msg)
|
||||
result[:parts].concat(upload_parts) if upload_parts.present?
|
||||
result
|
||||
content_array << msg[:content]
|
||||
content_array.flatten!
|
||||
|
||||
content_array =
|
||||
to_encoded_content_array(
|
||||
content: content_array,
|
||||
image_encoder: ->(details) { image_node(details) },
|
||||
text_encoder: ->(text) { { text: text } },
|
||||
allow_vision: vision_support? && beta_api?,
|
||||
)
|
||||
|
||||
if beta_api?
|
||||
{ role: "user", parts: content_array }
|
||||
else
|
||||
{ role: "user", parts: { text: msg[:content] } }
|
||||
{ role: "user", parts: content_array.first }
|
||||
end
|
||||
end
|
||||
|
||||
def uploaded_parts(message)
|
||||
encoded_uploads = prompt.encoded_uploads(message)
|
||||
result = []
|
||||
if encoded_uploads.present?
|
||||
encoded_uploads.each do |details|
|
||||
result << { inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
|
||||
end
|
||||
end
|
||||
result
|
||||
def image_node(details)
|
||||
{ inlineData: { mimeType: details[:mime_type], data: details[:base64] } }
|
||||
end
|
||||
|
||||
def tool_call_msg(msg)
|
||||
|
@ -155,7 +155,7 @@ module DiscourseAi
|
||||
end
|
||||
end
|
||||
|
||||
{ role: "user", content: msg[:content], images: images }
|
||||
{ role: "user", content: prompt.text_only(msg), images: images }
|
||||
end
|
||||
|
||||
def model_msg(msg)
|
||||
|
@ -69,7 +69,7 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
user_message = { role: "user", content: msg[:content] }
|
||||
user_message = { role: "user", content: prompt.text_only(msg) }
|
||||
|
||||
encoded_uploads = prompt.encoded_uploads(msg)
|
||||
if encoded_uploads.present?
|
||||
|
@ -3,7 +3,7 @@
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class OpenAiCompatible < Dialect
|
||||
class OpenAiCompatible < ChatGpt
|
||||
class << self
|
||||
def can_translate?(_llm_model)
|
||||
# fallback dialect
|
||||
@ -43,58 +43,6 @@ module DiscourseAi
|
||||
|
||||
translated.unshift(user_msg)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def system_msg(msg)
|
||||
msg = { role: "system", content: msg[:content] }
|
||||
|
||||
if tools_dialect.instructions.present?
|
||||
msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
|
||||
end
|
||||
|
||||
msg
|
||||
end
|
||||
|
||||
def model_msg(msg)
|
||||
{ role: "assistant", content: msg[:content] }
|
||||
end
|
||||
|
||||
def tool_call_msg(msg)
|
||||
translated = tools_dialect.from_raw_tool_call(msg)
|
||||
{ role: "assistant", content: translated }
|
||||
end
|
||||
|
||||
def tool_msg(msg)
|
||||
translated = tools_dialect.from_raw_tool(msg)
|
||||
{ role: "user", content: translated }
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
content = +""
|
||||
content << "#{msg[:id]}: " if msg[:id]
|
||||
content << msg[:content]
|
||||
|
||||
message = { role: "user", content: content }
|
||||
|
||||
message[:content] = inline_images(message[:content], msg) if vision_support?
|
||||
|
||||
message
|
||||
end
|
||||
|
||||
def inline_images(content, message)
|
||||
encoded_uploads = prompt.encoded_uploads(message)
|
||||
return content if encoded_uploads.blank?
|
||||
|
||||
encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
|
||||
memo << {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
|
||||
},
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -89,7 +89,6 @@ module DiscourseAi
|
||||
content:,
|
||||
id: nil,
|
||||
name: nil,
|
||||
upload_ids: nil,
|
||||
thinking: nil,
|
||||
thinking_signature: nil,
|
||||
redacted_thinking_signature: nil
|
||||
@ -98,7 +97,6 @@ module DiscourseAi
|
||||
new_message = { type: type, content: content }
|
||||
new_message[:name] = name.to_s if name
|
||||
new_message[:id] = id.to_s if id
|
||||
new_message[:upload_ids] = upload_ids if upload_ids
|
||||
new_message[:thinking] = thinking if thinking
|
||||
new_message[:thinking_signature] = thinking_signature if thinking_signature
|
||||
new_message[
|
||||
@ -115,11 +113,44 @@ module DiscourseAi
|
||||
tools.present?
|
||||
end
|
||||
|
||||
# helper method to get base64 encoded uploads
|
||||
# at the correct dimentions
|
||||
def encoded_uploads(message)
|
||||
return [] if message[:upload_ids].blank?
|
||||
UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
|
||||
if message[:content].is_a?(Array)
|
||||
upload_ids =
|
||||
message[:content]
|
||||
.map do |content|
|
||||
content[:upload_id] if content.is_a?(Hash) && content.key?(:upload_id)
|
||||
end
|
||||
.compact
|
||||
if !upload_ids.empty?
|
||||
return UploadEncoder.encode(upload_ids: upload_ids, max_pixels: max_pixels)
|
||||
end
|
||||
end
|
||||
|
||||
[]
|
||||
end
|
||||
|
||||
def text_only(message)
|
||||
if message[:content].is_a?(Array)
|
||||
message[:content].map { |element| element if element.is_a?(String) }.compact.join
|
||||
else
|
||||
message[:content]
|
||||
end
|
||||
end
|
||||
|
||||
def encode_upload(upload_id)
|
||||
UploadEncoder.encode(upload_ids: [upload_id], max_pixels: max_pixels).first
|
||||
end
|
||||
|
||||
def content_with_encoded_uploads(content)
|
||||
return [content] unless content.is_a?(Array)
|
||||
|
||||
content.map do |c|
|
||||
if c.is_a?(Hash) && c.key?(:upload_id)
|
||||
encode_upload(c[:upload_id])
|
||||
else
|
||||
c
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def ==(other)
|
||||
@ -150,7 +181,6 @@ module DiscourseAi
|
||||
content
|
||||
id
|
||||
name
|
||||
upload_ids
|
||||
thinking
|
||||
thinking_signature
|
||||
redacted_thinking_signature
|
||||
@ -159,15 +189,17 @@ module DiscourseAi
|
||||
raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
|
||||
end
|
||||
|
||||
if message[:type] == :upload_ids && !message[:upload_ids].is_a?(Array)
|
||||
raise ArgumentError, "upload_ids must be an array of ids"
|
||||
if message[:content].is_a?(Array)
|
||||
message[:content].each do |content|
|
||||
if !content.is_a?(String) && !(content.is_a?(Hash) && content.keys == [:upload_id])
|
||||
raise ArgumentError, "Array message content must be a string or {upload_id: ...} "
|
||||
end
|
||||
end
|
||||
else
|
||||
if !message[:content].is_a?(String)
|
||||
raise ArgumentError, "Message content must be a string or an array"
|
||||
end
|
||||
end
|
||||
|
||||
if message[:upload_ids].present? && message[:type] != :user
|
||||
raise ArgumentError, "upload_ids are only supported for users"
|
||||
end
|
||||
|
||||
raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
|
||||
end
|
||||
|
||||
def validate_turn(last_turn, new_turn)
|
||||
|
@ -9,6 +9,154 @@ module DiscourseAi
|
||||
attr_reader :chat_context_post_upload_ids
|
||||
attr_accessor :topic
|
||||
|
||||
def self.messages_from_chat(
|
||||
message,
|
||||
channel:,
|
||||
context_post_ids:,
|
||||
max_messages:,
|
||||
include_uploads:,
|
||||
bot_user_ids:,
|
||||
instruction_message: nil
|
||||
)
|
||||
include_thread_titles = !channel.direct_message_channel? && !message.thread_id
|
||||
|
||||
current_id = message.id
|
||||
messages = nil
|
||||
|
||||
if !message.thread_id && channel.direct_message_channel?
|
||||
messages = [message]
|
||||
elsif !channel.direct_message_channel? && !message.thread_id
|
||||
messages =
|
||||
Chat::Message
|
||||
.joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
|
||||
.where(chat_channel_id: channel.id)
|
||||
.where(
|
||||
"chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
|
||||
)
|
||||
.order(id: :desc)
|
||||
.limit(max_messages)
|
||||
.to_a
|
||||
.reverse
|
||||
end
|
||||
|
||||
messages ||=
|
||||
ChatSDK::Thread.last_messages(
|
||||
thread_id: message.thread_id,
|
||||
guardian: Discourse.system_user.guardian,
|
||||
page_size: max_messages,
|
||||
)
|
||||
|
||||
builder = new
|
||||
|
||||
guardian = Guardian.new(message.user)
|
||||
if context_post_ids
|
||||
builder.set_chat_context_posts(
|
||||
context_post_ids,
|
||||
guardian,
|
||||
include_uploads: include_uploads,
|
||||
)
|
||||
end
|
||||
|
||||
messages.each do |m|
|
||||
# restore stripped message
|
||||
m.message = instruction_message if m.id == current_id && instruction_message
|
||||
|
||||
if bot_user_ids.include?(m.user_id)
|
||||
builder.push(type: :model, content: m.message)
|
||||
else
|
||||
upload_ids = nil
|
||||
upload_ids = m.uploads.map(&:id) if include_uploads && m.uploads.present?
|
||||
mapped_message = m.message
|
||||
|
||||
thread_title = nil
|
||||
thread_title = m.thread&.title if include_thread_titles && m.thread_id
|
||||
mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
|
||||
|
||||
builder.push(
|
||||
type: :user,
|
||||
content: mapped_message,
|
||||
name: m.user.username,
|
||||
upload_ids: upload_ids,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
builder.to_a(
|
||||
limit: max_messages,
|
||||
style: channel.direct_message_channel? ? :chat_with_context : :chat,
|
||||
)
|
||||
end
|
||||
|
||||
def self.messages_from_post(post, style: nil, max_posts:, bot_usernames:, include_uploads:)
|
||||
# Pay attention to the `post_number <= ?` here.
|
||||
# We want to inject the last post as context because they are translated differently.
|
||||
|
||||
post_types = [Post.types[:regular]]
|
||||
post_types << Post.types[:whisper] if post.post_type == Post.types[:whisper]
|
||||
|
||||
context =
|
||||
post
|
||||
.topic
|
||||
.posts
|
||||
.joins(:user)
|
||||
.joins("LEFT JOIN post_custom_prompts ON post_custom_prompts.post_id = posts.id")
|
||||
.where("post_number <= ?", post.post_number)
|
||||
.order("post_number desc")
|
||||
.where("post_type in (?)", post_types)
|
||||
.limit(max_posts)
|
||||
.pluck(
|
||||
"posts.raw",
|
||||
"users.username",
|
||||
"post_custom_prompts.custom_prompt",
|
||||
"(
|
||||
SELECT array_agg(ref.upload_id)
|
||||
FROM upload_references ref
|
||||
WHERE ref.target_type = 'Post' AND ref.target_id = posts.id
|
||||
) as upload_ids",
|
||||
)
|
||||
|
||||
builder = new
|
||||
builder.topic = post.topic
|
||||
|
||||
context.reverse_each do |raw, username, custom_prompt, upload_ids|
|
||||
custom_prompt_translation =
|
||||
Proc.new do |message|
|
||||
# We can't keep backwards-compatibility for stored functions.
|
||||
# Tool syntax requires a tool_call_id which we don't have.
|
||||
if message[2] != "function"
|
||||
custom_context = {
|
||||
content: message[0],
|
||||
type: message[2].present? ? message[2].to_sym : :model,
|
||||
}
|
||||
|
||||
custom_context[:id] = message[1] if custom_context[:type] != :model
|
||||
custom_context[:name] = message[3] if message[3]
|
||||
|
||||
thinking = message[4]
|
||||
custom_context[:thinking] = thinking if thinking
|
||||
|
||||
builder.push(**custom_context)
|
||||
end
|
||||
end
|
||||
|
||||
if custom_prompt.present?
|
||||
custom_prompt.each(&custom_prompt_translation)
|
||||
else
|
||||
context = { content: raw, type: (bot_usernames.include?(username) ? :model : :user) }
|
||||
|
||||
context[:id] = username if context[:type] == :user
|
||||
|
||||
if upload_ids.present? && context[:type] == :user && include_uploads
|
||||
context[:upload_ids] = upload_ids.compact
|
||||
end
|
||||
|
||||
builder.push(**context)
|
||||
end
|
||||
end
|
||||
|
||||
builder.to_a(style: style || (post.topic.private_message? ? :bot : :topic))
|
||||
end
|
||||
|
||||
def initialize
|
||||
@raw_messages = []
|
||||
end
|
||||
@ -68,12 +216,19 @@ module DiscourseAi
|
||||
|
||||
if message[:type] == :user
|
||||
old_name = last_message.delete(:name)
|
||||
last_message[:content] = "#{old_name}: #{last_message[:content]}" if old_name
|
||||
last_message[:content] = ["#{old_name}: ", last_message[:content]].flatten if old_name
|
||||
|
||||
new_content = message[:content]
|
||||
new_content = "#{message[:name]}: #{new_content}" if message[:name]
|
||||
new_content = ["#{message[:name]}: ", new_content].flatten if message[:name]
|
||||
|
||||
last_message[:content] += "\n#{new_content}"
|
||||
if !last_message[:content].is_a?(Array)
|
||||
last_message[:content] = [last_message[:content]]
|
||||
end
|
||||
last_message[:content].concat(["\n", new_content].flatten)
|
||||
|
||||
compressed =
|
||||
compress_messages_buffer(last_message[:content], max_uploads: MAX_TOPIC_UPLOADS)
|
||||
last_message[:content] = compressed
|
||||
else
|
||||
last_message[:content] = message[:content]
|
||||
end
|
||||
@ -111,9 +266,9 @@ module DiscourseAi
|
||||
end
|
||||
raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array)
|
||||
|
||||
content = [content, *upload_ids.map { |upload_id| { upload_id: upload_id } }] if upload_ids
|
||||
message = { type: type, content: content }
|
||||
message[:name] = name.to_s if name
|
||||
message[:upload_ids] = upload_ids if upload_ids
|
||||
message[:id] = id.to_s if id
|
||||
if thinking
|
||||
message[:thinking] = thinking["thinking"] if thinking["thinking"]
|
||||
@ -132,67 +287,62 @@ module DiscourseAi
|
||||
|
||||
def topic_array
|
||||
raw_messages = @raw_messages.dup
|
||||
user_content = +"You are operating in a Discourse forum.\n\n"
|
||||
content_array = []
|
||||
content_array << "You are operating in a Discourse forum.\n\n"
|
||||
|
||||
if @topic
|
||||
if @topic.private_message?
|
||||
user_content << "Private message info.\n"
|
||||
content_array << "Private message info.\n"
|
||||
else
|
||||
user_content << "Topic information:\n"
|
||||
content_array << "Topic information:\n"
|
||||
end
|
||||
|
||||
user_content << "- URL: #{@topic.url}\n"
|
||||
user_content << "- Title: #{@topic.title}\n"
|
||||
content_array << "- URL: #{@topic.url}\n"
|
||||
content_array << "- Title: #{@topic.title}\n"
|
||||
if SiteSetting.tagging_enabled
|
||||
tags = @topic.tags.pluck(:name)
|
||||
tags -= DiscourseTagging.hidden_tag_names if tags.present?
|
||||
user_content << "- Tags: #{tags.join(", ")}\n" if tags.present?
|
||||
content_array << "- Tags: #{tags.join(", ")}\n" if tags.present?
|
||||
end
|
||||
if !@topic.private_message?
|
||||
user_content << "- Category: #{@topic.category.name}\n" if @topic.category
|
||||
content_array << "- Category: #{@topic.category.name}\n" if @topic.category
|
||||
end
|
||||
user_content << "- Number of replies: #{@topic.posts_count - 1}\n\n"
|
||||
content_array << "- Number of replies: #{@topic.posts_count - 1}\n\n"
|
||||
end
|
||||
|
||||
last_user_message = raw_messages.pop
|
||||
|
||||
upload_ids = []
|
||||
if raw_messages.present?
|
||||
user_content << "Here is the conversation so far:\n"
|
||||
content_array << "Here is the conversation so far:\n"
|
||||
raw_messages.each do |message|
|
||||
user_content << "#{message[:name] || "User"}: #{message[:content]}\n"
|
||||
upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
|
||||
content_array << "#{message[:name] || "User"}: "
|
||||
content_array << message[:content]
|
||||
content_array << "\n\n"
|
||||
end
|
||||
end
|
||||
|
||||
if last_user_message
|
||||
user_content << "You are responding to #{last_user_message[:name] || "User"} who just said:\n #{last_user_message[:content]}"
|
||||
if last_user_message[:upload_ids].present?
|
||||
upload_ids.concat(last_user_message[:upload_ids])
|
||||
end
|
||||
content_array << "You are responding to #{last_user_message[:name] || "User"} who just said:\n"
|
||||
content_array << last_user_message[:content]
|
||||
end
|
||||
|
||||
user_message = { type: :user, content: user_content }
|
||||
content_array =
|
||||
compress_messages_buffer(content_array.flatten, max_uploads: MAX_TOPIC_UPLOADS)
|
||||
|
||||
if upload_ids.present?
|
||||
user_message[:upload_ids] = upload_ids[-MAX_TOPIC_UPLOADS..-1] || upload_ids
|
||||
end
|
||||
user_message = { type: :user, content: content_array }
|
||||
|
||||
[user_message]
|
||||
end
|
||||
|
||||
def chat_array(limit:)
|
||||
if @raw_messages.length > 1
|
||||
buffer =
|
||||
+"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{"
|
||||
|
||||
upload_ids = []
|
||||
buffer = [
|
||||
+"You are replying inside a Discourse chat channel. Here is a summary of the conversation so far:\n{{{",
|
||||
]
|
||||
|
||||
@raw_messages[0..-2].each do |message|
|
||||
buffer << "\n"
|
||||
|
||||
upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
|
||||
|
||||
if message[:type] == :user
|
||||
buffer << "#{message[:name] || "User"}: "
|
||||
else
|
||||
@ -209,16 +359,44 @@ module DiscourseAi
|
||||
end
|
||||
|
||||
last_message = @raw_messages[-1]
|
||||
buffer << "#{last_message[:name] || "User"}: #{last_message[:content]} "
|
||||
buffer << "#{last_message[:name] || "User"}: "
|
||||
buffer << last_message[:content]
|
||||
|
||||
buffer = compress_messages_buffer(buffer.flatten, max_uploads: MAX_CHAT_UPLOADS)
|
||||
|
||||
message = { type: :user, content: buffer }
|
||||
upload_ids.concat(last_message[:upload_ids]) if last_message[:upload_ids].present?
|
||||
|
||||
message[:upload_ids] = upload_ids[-MAX_CHAT_UPLOADS..-1] ||
|
||||
upload_ids if upload_ids.present?
|
||||
|
||||
[message]
|
||||
end
|
||||
|
||||
# caps uploads to maximum uploads allowed in message stream
|
||||
# and concats string elements
|
||||
def compress_messages_buffer(buffer, max_uploads:)
|
||||
compressed = []
|
||||
current_text = +""
|
||||
upload_count = 0
|
||||
|
||||
buffer.each do |item|
|
||||
if item.is_a?(String)
|
||||
current_text << item
|
||||
elsif item.is_a?(Hash)
|
||||
compressed << current_text if current_text.present?
|
||||
compressed << item
|
||||
current_text = +""
|
||||
upload_count += 1
|
||||
end
|
||||
end
|
||||
|
||||
compressed << current_text if current_text.present?
|
||||
|
||||
if upload_count > max_uploads
|
||||
counter = max_uploads - upload_count
|
||||
compressed.delete_if { |item| item.is_a?(Hash) && (counter += 1) > 0 }
|
||||
end
|
||||
|
||||
compressed = compressed[0] if compressed.length == 1 && compressed[0].is_a?(String)
|
||||
|
||||
compressed
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||
expect(context.image_generation_scenario).to eq(
|
||||
{
|
||||
messages: [
|
||||
{ role: "user", parts: [{ text: "draw a cat" }] },
|
||||
{ role: "user", parts: [{ text: "user1: draw a cat" }] },
|
||||
{
|
||||
role: "model",
|
||||
parts: [{ functionCall: { name: "draw", args: { picture: "Cat" } } }],
|
||||
@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||
],
|
||||
},
|
||||
{ role: "model", parts: { text: "Ok." } },
|
||||
{ role: "user", parts: [{ text: "draw another cat" }] },
|
||||
{ role: "user", parts: [{ text: "user1: draw another cat" }] },
|
||||
],
|
||||
system_instruction: context.system_insts,
|
||||
},
|
||||
@ -52,12 +52,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||
expect(context.multi_turn_scenario).to eq(
|
||||
{
|
||||
messages: [
|
||||
{ role: "user", parts: [{ text: "This is a message by a user" }] },
|
||||
{ role: "user", parts: [{ text: "user1: This is a message by a user" }] },
|
||||
{
|
||||
role: "model",
|
||||
parts: [{ text: "I'm a previous bot reply, that's why there's no user" }],
|
||||
},
|
||||
{ role: "user", parts: [{ text: "This is a new message by a user" }] },
|
||||
{ role: "user", parts: [{ text: "user1: This is a new message by a user" }] },
|
||||
{
|
||||
role: "model",
|
||||
parts: [
|
||||
|
@ -29,7 +29,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are image bot",
|
||||
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
|
||||
messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
|
||||
)
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
@ -41,7 +41,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
|
||||
content = dialect.translate[1][:content]
|
||||
|
||||
expect(content).to eq(
|
||||
[{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }],
|
||||
[{ type: "text", text: "user1: hello" }, { type: "image_url", image_url: { url: image } }],
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -37,7 +37,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Nova do
|
||||
|
||||
it "properly formats messages with images" do
|
||||
messages = [
|
||||
{ type: :user, id: "user1", content: "What's in this image?", upload_ids: [upload.id] },
|
||||
{
|
||||
type: :user,
|
||||
id: "user1",
|
||||
content: ["What's in this image?", { upload_id: upload.id }],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = DiscourseAi::Completions::Prompt.new(messages: messages)
|
||||
|
@ -34,8 +34,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
|
||||
messages: [
|
||||
{
|
||||
type: :user,
|
||||
content: "Describe this image in a single sentence.",
|
||||
upload_ids: [upload.id],
|
||||
content: ["Describe this image in a single sentence.", { upload_id: upload.id }],
|
||||
},
|
||||
],
|
||||
)
|
||||
@ -49,10 +48,15 @@ RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
|
||||
|
||||
expect(translated_messages.length).to eq(1)
|
||||
|
||||
# no system message support here
|
||||
expected_user_message = {
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: prompt.messages.map { |m| m[:content] }.join("\n") },
|
||||
{
|
||||
type: "text",
|
||||
text:
|
||||
"You are a bot specializing in image captioning.\nDescribe this image in a single sentence.",
|
||||
},
|
||||
{
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
|
@ -271,7 +271,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are image bot",
|
||||
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
|
||||
messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
|
||||
)
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
@ -283,6 +283,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "user1: hello" },
|
||||
{
|
||||
type: "image",
|
||||
source: {
|
||||
@ -291,7 +292,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
data: encoded[0][:base64],
|
||||
},
|
||||
},
|
||||
{ type: "text", text: "user1: hello" },
|
||||
],
|
||||
},
|
||||
],
|
||||
|
@ -211,7 +211,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are image bot",
|
||||
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
|
||||
messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
|
||||
)
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
@ -248,7 +248,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||
{
|
||||
"role" => "user",
|
||||
"parts" => [
|
||||
{ "text" => "hello" },
|
||||
{ "text" => "user1: hello" },
|
||||
{ "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
|
||||
],
|
||||
},
|
||||
|
@ -492,7 +492,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are image bot",
|
||||
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
|
||||
messages: [type: :user, id: "user1", content: ["hello", { upload_id: upload100x100.id }]],
|
||||
)
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
@ -517,13 +517,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text", text: "hello" },
|
||||
{
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: "data:#{encoded[0][:mime_type]};base64,#{encoded[0][:base64]}",
|
||||
},
|
||||
},
|
||||
{ type: "text", text: "hello" },
|
||||
],
|
||||
name: "user1",
|
||||
},
|
||||
|
@ -2,6 +2,36 @@
|
||||
|
||||
describe DiscourseAi::Completions::PromptMessagesBuilder do
|
||||
let(:builder) { DiscourseAi::Completions::PromptMessagesBuilder.new }
|
||||
fab!(:user)
|
||||
fab!(:admin)
|
||||
fab!(:bot_user) { Fabricate(:user) }
|
||||
fab!(:other_user) { Fabricate(:user) }
|
||||
|
||||
fab!(:image_upload1) do
|
||||
Fabricate(:upload, user: user, original_filename: "image.png", extension: "png")
|
||||
end
|
||||
fab!(:image_upload2) do
|
||||
Fabricate(:upload, user: user, original_filename: "image.png", extension: "png")
|
||||
end
|
||||
|
||||
it "correctly merges user messages with uploads" do
|
||||
builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1])
|
||||
builder.push(type: :user, content: "World", name: "Bob", upload_ids: [2])
|
||||
|
||||
messages = builder.to_a
|
||||
|
||||
# Check the structure of the merged message
|
||||
expect(messages.length).to eq(1)
|
||||
expect(messages[0][:type]).to eq(:user)
|
||||
|
||||
# The content should contain the text and both uploads
|
||||
content = messages[0][:content]
|
||||
expect(content).to be_an(Array)
|
||||
expect(content[0]).to eq("Alice: Hello")
|
||||
expect(content[1]).to eq({ upload_id: 1 })
|
||||
expect(content[2]).to eq("\nBob: World")
|
||||
expect(content[3]).to eq({ upload_id: 2 })
|
||||
end
|
||||
|
||||
it "should allow merging user messages" do
|
||||
builder.push(type: :user, content: "Hello", name: "Alice")
|
||||
@ -14,7 +44,7 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
|
||||
builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1, 2])
|
||||
|
||||
expect(builder.to_a).to eq(
|
||||
[{ type: :user, name: "Alice", content: "Hello", upload_ids: [1, 2] }],
|
||||
[{ type: :user, content: ["Hello", { upload_id: 1 }, { upload_id: 2 }], name: "Alice" }],
|
||||
)
|
||||
end
|
||||
|
||||
@ -64,4 +94,319 @@ describe DiscourseAi::Completions::PromptMessagesBuilder do
|
||||
expect(content).to include("Alice")
|
||||
expect(content).to include("How do I solve this")
|
||||
end
|
||||
|
||||
describe ".messages_from_chat" do
|
||||
fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, bot_user]) }
|
||||
fab!(:dm_message1) do
|
||||
Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "Hello bot")
|
||||
end
|
||||
fab!(:dm_message2) do
|
||||
Fabricate(:chat_message, chat_channel: dm_channel, user: bot_user, message: "Hello human")
|
||||
end
|
||||
fab!(:dm_message3) do
|
||||
Fabricate(:chat_message, chat_channel: dm_channel, user: user, message: "How are you?")
|
||||
end
|
||||
|
||||
fab!(:public_channel) { Fabricate(:category_channel) }
|
||||
fab!(:public_message1) do
|
||||
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Hello everyone")
|
||||
end
|
||||
fab!(:public_message2) do
|
||||
Fabricate(:chat_message, chat_channel: public_channel, user: bot_user, message: "Hi there")
|
||||
end
|
||||
|
||||
fab!(:thread_original) do
|
||||
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Thread starter")
|
||||
end
|
||||
fab!(:thread) do
|
||||
Fabricate(:chat_thread, channel: public_channel, original_message: thread_original)
|
||||
end
|
||||
fab!(:thread_reply1) do
|
||||
Fabricate(
|
||||
:chat_message,
|
||||
chat_channel: public_channel,
|
||||
user: other_user,
|
||||
message: "Thread reply",
|
||||
thread: thread,
|
||||
)
|
||||
end
|
||||
|
||||
fab!(:upload) { Fabricate(:upload, user: user) }
|
||||
fab!(:message_with_upload) do
|
||||
Fabricate(
|
||||
:chat_message,
|
||||
chat_channel: dm_channel,
|
||||
user: user,
|
||||
message: "Check this image",
|
||||
upload_ids: [upload.id],
|
||||
)
|
||||
end
|
||||
|
||||
it "processes messages from direct message channels" do
|
||||
context =
|
||||
described_class.messages_from_chat(
|
||||
dm_message3,
|
||||
channel: dm_channel,
|
||||
context_post_ids: nil,
|
||||
max_messages: 10,
|
||||
include_uploads: false,
|
||||
bot_user_ids: [bot_user.id],
|
||||
instruction_message: nil,
|
||||
)
|
||||
|
||||
# this is all we got cause it is assuming threading
|
||||
expect(context).to eq([{ type: :user, content: "How are you?", name: user.username }])
|
||||
end
|
||||
|
||||
it "includes uploads when include_uploads is true" do
|
||||
message_with_upload.reload
|
||||
expect(message_with_upload.uploads).to include(upload)
|
||||
|
||||
context =
|
||||
described_class.messages_from_chat(
|
||||
message_with_upload,
|
||||
channel: dm_channel,
|
||||
context_post_ids: nil,
|
||||
max_messages: 10,
|
||||
include_uploads: true,
|
||||
bot_user_ids: [bot_user.id],
|
||||
instruction_message: nil,
|
||||
)
|
||||
|
||||
# Find the message with upload
|
||||
message = context.find { |m| m[:content] == ["Check this image", { upload_id: upload.id }] }
|
||||
expect(message).to be_present
|
||||
end
|
||||
|
||||
it "doesn't include uploads when include_uploads is false" do
|
||||
# Make sure the upload is associated with the message
|
||||
message_with_upload.reload
|
||||
expect(message_with_upload.uploads).to include(upload)
|
||||
|
||||
context =
|
||||
described_class.messages_from_chat(
|
||||
message_with_upload,
|
||||
channel: dm_channel,
|
||||
context_post_ids: nil,
|
||||
max_messages: 10,
|
||||
include_uploads: false,
|
||||
bot_user_ids: [bot_user.id],
|
||||
instruction_message: nil,
|
||||
)
|
||||
|
||||
# Find the message with upload
|
||||
message = context.find { |m| m[:content] == "Check this image" }
|
||||
expect(message).to be_present
|
||||
expect(message[:upload_ids]).to be_nil
|
||||
end
|
||||
|
||||
it "properly handles uploads in public channels with multiple users" do
|
||||
_first_message =
|
||||
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "First message")
|
||||
|
||||
_message_with_upload =
|
||||
Fabricate(
|
||||
:chat_message,
|
||||
chat_channel: public_channel,
|
||||
user: other_user,
|
||||
message: "Message with image",
|
||||
upload_ids: [upload.id],
|
||||
)
|
||||
|
||||
last_message =
|
||||
Fabricate(:chat_message, chat_channel: public_channel, user: user, message: "Final message")
|
||||
|
||||
context =
|
||||
described_class.messages_from_chat(
|
||||
last_message,
|
||||
channel: public_channel,
|
||||
context_post_ids: nil,
|
||||
max_messages: 3,
|
||||
include_uploads: true,
|
||||
bot_user_ids: [bot_user.id],
|
||||
instruction_message: nil,
|
||||
)
|
||||
|
||||
expect(context.length).to eq(1)
|
||||
content = context.first[:content]
|
||||
|
||||
expect(content.length).to eq(3)
|
||||
expect(content[0]).to include("First message")
|
||||
expect(content[0]).to include("Message with image")
|
||||
expect(content[1]).to include({ upload_id: upload.id })
|
||||
expect(content[2]).to include("Final message")
|
||||
end
|
||||
end
|
||||
|
||||
describe ".messages_from_post" do
|
||||
fab!(:pm) do
|
||||
Fabricate(
|
||||
:private_message_topic,
|
||||
title: "This is my special PM",
|
||||
user: user,
|
||||
topic_allowed_users: [
|
||||
Fabricate.build(:topic_allowed_user, user: user),
|
||||
Fabricate.build(:topic_allowed_user, user: bot_user),
|
||||
],
|
||||
)
|
||||
end
|
||||
fab!(:first_post) do
|
||||
Fabricate(:post, topic: pm, user: user, post_number: 1, raw: "This is a reply by the user")
|
||||
end
|
||||
fab!(:second_post) do
|
||||
Fabricate(:post, topic: pm, user: bot_user, post_number: 2, raw: "This is a bot reply")
|
||||
end
|
||||
fab!(:third_post) do
|
||||
Fabricate(
|
||||
:post,
|
||||
topic: pm,
|
||||
user: user,
|
||||
post_number: 3,
|
||||
raw: "This is a second reply by the user",
|
||||
)
|
||||
end
|
||||
|
||||
it "handles uploads correctly in topic style messages" do
|
||||
# Use Discourse's upload format in the post raw content
|
||||
upload_markdown = ""
|
||||
|
||||
post_with_upload =
|
||||
Fabricate(
|
||||
:post,
|
||||
topic: pm,
|
||||
user: admin,
|
||||
raw: "This is the original #{upload_markdown} I just added",
|
||||
)
|
||||
|
||||
UploadReference.create!(target: post_with_upload, upload: image_upload1)
|
||||
|
||||
upload2_markdown = ""
|
||||
|
||||
post2_with_upload =
|
||||
Fabricate(
|
||||
:post,
|
||||
topic: pm,
|
||||
user: admin,
|
||||
raw: "This post has a different image #{upload2_markdown} I just added",
|
||||
)
|
||||
|
||||
UploadReference.create!(target: post2_with_upload, upload: image_upload2)
|
||||
|
||||
messages =
|
||||
described_class.messages_from_post(
|
||||
post2_with_upload,
|
||||
style: :topic,
|
||||
max_posts: 3,
|
||||
bot_usernames: [bot_user.username],
|
||||
include_uploads: true,
|
||||
)
|
||||
|
||||
# this is not quite ideal yet, images are attached at the end of the post
|
||||
# long term we may want to extract them out using a regex and create N parts
|
||||
# so people can talk about multiple images in a single post
|
||||
# this is the initial ground work though
|
||||
|
||||
expect(messages.length).to eq(1)
|
||||
content = messages[0][:content]
|
||||
|
||||
# first part
|
||||
# first image
|
||||
# second part
|
||||
# second image
|
||||
expect(content.length).to eq(4)
|
||||
expect(content[0]).to include("This is the original")
|
||||
expect(content[1]).to eq({ upload_id: image_upload1.id })
|
||||
expect(content[2]).to include("different image")
|
||||
expect(content[3]).to eq({ upload_id: image_upload2.id })
|
||||
end
|
||||
|
||||
context "with limited context" do
|
||||
it "respects max_context_posts" do
|
||||
context =
|
||||
described_class.messages_from_post(
|
||||
third_post,
|
||||
max_posts: 1,
|
||||
bot_usernames: [bot_user.username],
|
||||
include_uploads: false,
|
||||
)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[{ type: :user, id: user.username, content: third_post.raw }],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
it "includes previous posts ordered by post_number" do
|
||||
context =
|
||||
described_class.messages_from_post(
|
||||
third_post,
|
||||
max_posts: 10,
|
||||
bot_usernames: [bot_user.username],
|
||||
include_uploads: false,
|
||||
)
|
||||
|
||||
expect(context).to eq(
|
||||
[
|
||||
{ type: :user, content: "This is a reply by the user", id: user.username },
|
||||
{ type: :model, content: "This is a bot reply" },
|
||||
{ type: :user, content: "This is a second reply by the user", id: user.username },
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
it "only include regular posts" do
|
||||
first_post.update!(post_type: Post.types[:whisper])
|
||||
|
||||
context =
|
||||
described_class.messages_from_post(
|
||||
third_post,
|
||||
max_posts: 10,
|
||||
bot_usernames: [bot_user.username],
|
||||
include_uploads: false,
|
||||
)
|
||||
|
||||
# skips leading model reply which makes no sense cause first post was whisper
|
||||
expect(context).to eq(
|
||||
[{ type: :user, content: "This is a second reply by the user", id: user.username }],
|
||||
)
|
||||
end
|
||||
|
||||
context "with custom prompts" do
|
||||
it "When post custom prompt is present, we use that instead of the post content" do
|
||||
custom_prompt = [
|
||||
[
|
||||
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
|
||||
"time",
|
||||
"tool_call",
|
||||
],
|
||||
[
|
||||
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
|
||||
"time",
|
||||
"tool",
|
||||
],
|
||||
["I replied to the time command", bot_user.username],
|
||||
]
|
||||
|
||||
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
|
||||
|
||||
context =
|
||||
described_class.messages_from_post(
|
||||
third_post,
|
||||
max_posts: 10,
|
||||
bot_usernames: [bot_user.username],
|
||||
include_uploads: false,
|
||||
)
|
||||
|
||||
expect(context).to eq(
|
||||
[
|
||||
{ type: :user, content: "This is a reply by the user", id: user.username },
|
||||
{ type: :tool_call, content: custom_prompt.first.first, id: "time" },
|
||||
{ type: :tool, id: "time", content: custom_prompt.second.first },
|
||||
{ type: :model, content: custom_prompt.third.first },
|
||||
{ type: :user, content: "This is a second reply by the user", id: user.username },
|
||||
],
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -25,34 +25,21 @@ RSpec.describe DiscourseAi::Completions::Prompt do
|
||||
end
|
||||
|
||||
describe "image support" do
|
||||
it "allows adding uploads to messages" do
|
||||
it "allows adding uploads inline in messages" do
|
||||
upload = UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
|
||||
|
||||
prompt.max_pixels = 300
|
||||
prompt.push(type: :user, content: "hello", upload_ids: [upload.id])
|
||||
prompt.push(
|
||||
type: :user,
|
||||
content: ["this is an image", { upload_id: upload.id }, "this was an image"],
|
||||
)
|
||||
|
||||
expect(prompt.messages.last[:upload_ids]).to eq([upload.id])
|
||||
expect(prompt.max_pixels).to eq(300)
|
||||
encoded = prompt.content_with_encoded_uploads(prompt.messages.last[:content])
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
|
||||
expect(encoded.length).to eq(1)
|
||||
expect(encoded[0][:mime_type]).to eq("image/jpeg")
|
||||
|
||||
old_base64 = encoded[0][:base64]
|
||||
|
||||
prompt.max_pixels = 1_000_000
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
|
||||
expect(encoded.length).to eq(1)
|
||||
expect(encoded[0][:mime_type]).to eq("image/jpeg")
|
||||
|
||||
new_base64 = encoded[0][:base64]
|
||||
|
||||
expect(new_base64.length).to be > old_base64.length
|
||||
expect(new_base64.length).to be > 0
|
||||
expect(old_base64.length).to be > 0
|
||||
expect(encoded.length).to eq(3)
|
||||
expect(encoded[0]).to eq("this is an image")
|
||||
expect(encoded[1][:mime_type]).to eq("image/jpeg")
|
||||
expect(encoded[2]).to eq("this was an image")
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -52,7 +52,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
||||
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: personaClass.new)
|
||||
bot.reply(
|
||||
{ conversation_context: [{ type: :user, content: "test" }] },
|
||||
DiscourseAi::AiBot::BotContext.new(messages: [{ type: :user, content: "test" }]),
|
||||
) do |_partial, _cancel, _placeholder|
|
||||
# we just need the block so bot has something to call with results
|
||||
end
|
||||
@ -74,7 +74,10 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
||||
|
||||
HTML
|
||||
|
||||
context = { conversation_context: [{ type: :user, content: "Does my site has tags?" }] }
|
||||
context =
|
||||
DiscourseAi::AiBot::BotContext.new(
|
||||
messages: [{ type: :user, content: "Does my site has tags?" }],
|
||||
)
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(llm_responses) do
|
||||
bot.reply(context) do |_bot_reply_post, cancel, placeholder|
|
||||
|
@ -36,13 +36,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
end
|
||||
|
||||
let(:context) do
|
||||
{
|
||||
DiscourseAi::AiBot::BotContext.new(
|
||||
site_url: Discourse.base_url,
|
||||
site_title: "test site title",
|
||||
site_description: "test site description",
|
||||
time: Time.zone.now,
|
||||
participants: topic_with_users.allowed_users.map(&:username).join(", "),
|
||||
}
|
||||
)
|
||||
end
|
||||
|
||||
fab!(:admin)
|
||||
@ -307,7 +307,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
let(:ai_persona) { DiscourseAi::AiBot::Personas::Persona.all(user: user).first.new }
|
||||
|
||||
let(:with_cc) do
|
||||
context.merge(conversation_context: [{ content: "Tell me the time", type: :user }])
|
||||
context.messages = [{ content: "Tell me the time", type: :user }]
|
||||
context
|
||||
end
|
||||
|
||||
context "when a persona has no uploads" do
|
||||
@ -345,17 +346,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
DiscourseAi::AiBot::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new
|
||||
|
||||
# this means that we will consolidate
|
||||
ctx =
|
||||
with_cc.merge(
|
||||
conversation_context: [
|
||||
{ content: "Tell me the time", type: :user },
|
||||
{ content: "the time is 1", type: :model },
|
||||
{ content: "in france?", type: :user },
|
||||
],
|
||||
)
|
||||
context.messages = [
|
||||
{ content: "Tell me the time", type: :user },
|
||||
{ content: "the time is 1", type: :model },
|
||||
{ content: "in france?", type: :user },
|
||||
]
|
||||
|
||||
DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do
|
||||
custom_persona.craft_prompt(ctx).messages.first[:content]
|
||||
custom_persona.craft_prompt(context).messages.first[:content]
|
||||
end
|
||||
|
||||
message =
|
||||
@ -397,7 +395,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
|
||||
|
||||
EmbeddingsGenerationStubs.hugging_face_service(
|
||||
with_cc.dig(:conversation_context, 0, :content),
|
||||
with_cc.messages.dig(0, :content),
|
||||
prompt_cc_embeddings,
|
||||
)
|
||||
end
|
||||
|
@ -267,7 +267,10 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||
prompts = inner_prompts
|
||||
end
|
||||
|
||||
expect(prompts[0].messages[1][:upload_ids]).to eq([upload.id])
|
||||
content = prompts[0].messages[1][:content]
|
||||
|
||||
expect(content).to include({ upload_id: upload.id })
|
||||
|
||||
expect(prompts[0].max_pixels).to eq(1000)
|
||||
|
||||
post.topic.reload
|
||||
@ -1154,79 +1157,4 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||
expect(playground.available_bot_usernames).to include(persona.user.username)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#conversation_context" do
|
||||
context "with limited context" do
|
||||
before do
|
||||
@old_persona = playground.bot.persona
|
||||
persona = Fabricate(:ai_persona, max_context_posts: 1)
|
||||
playground.bot.persona = persona.class_instance.new
|
||||
end
|
||||
|
||||
after { playground.bot.persona = @old_persona }
|
||||
|
||||
it "respects max_context_post" do
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[{ type: :user, id: user.username, content: third_post.raw }],
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
xit "includes previous posts ordered by post_number" do
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[
|
||||
{ type: :user, id: user.username, content: third_post.raw },
|
||||
{ type: :model, content: second_post.raw },
|
||||
{ type: :user, id: user.username, content: first_post.raw },
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
xit "only include regular posts" do
|
||||
first_post.update!(post_type: Post.types[:whisper])
|
||||
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
# skips leading model reply which makes no sense cause first post was whisper
|
||||
expect(context).to contain_exactly(
|
||||
*[{ type: :user, id: user.username, content: third_post.raw }],
|
||||
)
|
||||
end
|
||||
|
||||
context "with custom prompts" do
|
||||
it "When post custom prompt is present, we use that instead of the post content" do
|
||||
custom_prompt = [
|
||||
[
|
||||
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
|
||||
"time",
|
||||
"tool_call",
|
||||
],
|
||||
[
|
||||
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
|
||||
"time",
|
||||
"tool",
|
||||
],
|
||||
["I replied to the time command", bot_user.username],
|
||||
]
|
||||
|
||||
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
|
||||
|
||||
context = playground.conversation_context(third_post)
|
||||
|
||||
expect(context).to contain_exactly(
|
||||
*[
|
||||
{ type: :user, id: user.username, content: first_post.raw },
|
||||
{ type: :tool_call, content: custom_prompt.first.first, id: "time" },
|
||||
{ type: :tool, id: "time", content: custom_prompt.second.first },
|
||||
{ type: :model, content: custom_prompt.third.first },
|
||||
{ type: :user, id: user.username, content: third_post.raw },
|
||||
],
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -34,9 +34,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::CreateArtifact do
|
||||
{ html_body: "hello" },
|
||||
bot_user: Fabricate(:user),
|
||||
llm: llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(post: post),
|
||||
)
|
||||
|
||||
tool.parameters = { name: "hello", specification: "hello spec" }
|
||||
|
@ -15,9 +15,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
let(:dall_e) do
|
||||
described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user, context: {})
|
||||
end
|
||||
let(:dall_e) { described_class.new({ prompts: prompts }, llm: llm, bot_user: bot_user) }
|
||||
|
||||
let(:base64_image) do
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
@ -30,8 +28,6 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
|
||||
{ prompts: ["a cat"], aspect_ratio: "tall" },
|
||||
llm: llm,
|
||||
bot_user: bot_user,
|
||||
context: {
|
||||
},
|
||||
)
|
||||
|
||||
data = [{ b64_json: base64_image, revised_prompt: "a tall cat" }]
|
||||
|
@ -9,8 +9,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
|
||||
{ prompts: prompts, seeds: [99, 32] },
|
||||
bot_user: bot_user,
|
||||
llm: llm,
|
||||
context: {
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new,
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -25,9 +25,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
|
||||
{ url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/#{artifact.id}" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post2.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -46,9 +44,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
|
||||
{ url: "invalid-url" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -62,9 +58,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
|
||||
{ url: "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/99999" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -97,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
|
||||
{ url: "https://example.com" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -128,9 +120,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::ReadArtifact do
|
||||
{ url: "https://example.com" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
|
@ -56,9 +56,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
||||
persona_options: {
|
||||
"read_private" => true,
|
||||
},
|
||||
context: {
|
||||
user: admin,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(user: admin),
|
||||
)
|
||||
results = tool.invoke
|
||||
expect(results[:content]).to include("hello there")
|
||||
@ -68,9 +66,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
||||
{ topic_id: topic_with_tags.id, post_numbers: [post1.post_number] },
|
||||
bot_user: bot_user,
|
||||
llm: llm,
|
||||
context: {
|
||||
user: admin,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(user: admin),
|
||||
)
|
||||
|
||||
results = tool.invoke
|
||||
|
@ -60,9 +60,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
||||
persona_options: persona_options,
|
||||
bot_user: bot_user,
|
||||
llm: llm,
|
||||
context: {
|
||||
user: user,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(user: user),
|
||||
)
|
||||
|
||||
expect(search.options[:base_query]).to eq("#funny")
|
||||
|
@ -47,9 +47,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
persona_options: {
|
||||
"update_algorithm" => "full",
|
||||
},
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -93,9 +91,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
persona_options: {
|
||||
"update_algorithm" => "full",
|
||||
},
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -119,9 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
{ artifact_id: artifact.id, instructions: "Invalid update" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -135,9 +129,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
{ artifact_id: -1, instructions: "Update something" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -163,9 +155,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
persona_options: {
|
||||
"update_algorithm" => "full",
|
||||
},
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
)
|
||||
|
||||
tool.invoke {}
|
||||
@ -196,9 +186,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
persona_options: {
|
||||
"update_algorithm" => "full",
|
||||
},
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
)
|
||||
.invoke {}
|
||||
end
|
||||
@ -224,9 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
persona_options: {
|
||||
"update_algorithm" => "full",
|
||||
},
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
)
|
||||
|
||||
result = tool.invoke {}
|
||||
@ -276,9 +262,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
{ artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
persona_options: {
|
||||
"update_algorithm" => "diff",
|
||||
},
|
||||
@ -346,9 +330,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::UpdateArtifact do
|
||||
{ artifact_id: artifact.id, instructions: "Change the text to Updated and color to red" },
|
||||
bot_user: bot_user,
|
||||
llm: llm_model.to_llm,
|
||||
context: {
|
||||
post_id: post.id,
|
||||
},
|
||||
context: DiscourseAi::AiBot::BotContext.new(messages: [], post: post),
|
||||
persona_options: {
|
||||
"update_algorithm" => "diff",
|
||||
},
|
||||
|
@ -255,11 +255,12 @@ RSpec.describe DiscourseAi::AiModeration::SpamScanner do
|
||||
prompt = _prompts.first
|
||||
end
|
||||
|
||||
content = prompt.messages[1][:content]
|
||||
# its an array so lets just stringify it to make testing easier
|
||||
content = prompt.messages[1][:content][0]
|
||||
expect(content).to include(post.topic.title)
|
||||
expect(content).to include(post.raw)
|
||||
|
||||
upload_ids = prompt.messages[1][:upload_ids]
|
||||
upload_ids = prompt.messages[1][:content].map { |m| m[:upload_id] if m.is_a?(Hash) }.compact
|
||||
expect(upload_ids).to be_present
|
||||
expect(upload_ids).to eq(post.upload_ids)
|
||||
|
||||
|
@ -199,7 +199,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
||||
|
||||
triage_prompt = DiscourseAi::Completions::Llm.prompts.last
|
||||
|
||||
expect(triage_prompt.messages.last[:upload_ids]).to contain_exactly(post_upload.id)
|
||||
expect(triage_prompt.messages.last[:content].last).to eq({ upload_id: post_upload.id })
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -5,6 +5,7 @@ RSpec.describe AiTool do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
fab!(:topic)
|
||||
fab!(:post) { Fabricate(:post, topic: topic, raw: "bananas are a tasty fruit") }
|
||||
fab!(:bot_user) { Discourse.system_user }
|
||||
|
||||
def create_tool(
|
||||
parameters: nil,
|
||||
@ -16,7 +17,8 @@ RSpec.describe AiTool do
|
||||
name: "test #{SecureRandom.uuid}",
|
||||
tool_name: "test_#{SecureRandom.uuid.underscore}",
|
||||
description: "test",
|
||||
parameters: parameters || [{ name: "query", type: "string", desciption: "perform a search" }],
|
||||
parameters:
|
||||
parameters || [{ name: "query", type: "string", description: "perform a search" }],
|
||||
script: script || "function invoke(params) { return params; }",
|
||||
created_by_id: 1,
|
||||
summary: "Test tool summary",
|
||||
@ -32,11 +34,11 @@ RSpec.describe AiTool do
|
||||
{
|
||||
name: tool.tool_name,
|
||||
description: "test",
|
||||
parameters: [{ name: "query", type: "string", desciption: "perform a search" }],
|
||||
parameters: [{ name: "query", type: "string", description: "perform a search" }],
|
||||
},
|
||||
)
|
||||
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
|
||||
|
||||
expect(runner.invoke).to eq("query" => "test")
|
||||
end
|
||||
@ -57,7 +59,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "data" => "test data" }, llm: nil, bot_user: nil)
|
||||
|
||||
stub_request(verb, "https://example.com/api").with(
|
||||
body: "{\"data\":\"test data\"}",
|
||||
@ -83,7 +85,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
|
||||
|
||||
stub_request(:get, "https://example.com/test").with(
|
||||
headers: {
|
||||
@ -110,7 +112,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({}, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({}, llm: nil, bot_user: nil)
|
||||
|
||||
stub_request(:get, "https://example.com/").to_return(
|
||||
status: 200,
|
||||
@ -134,7 +136,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
|
||||
|
||||
stub_request(:get, "https://example.com/test").with(
|
||||
headers: {
|
||||
@ -160,13 +162,16 @@ RSpec.describe AiTool do
|
||||
}
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
|
||||
|
||||
stub_request(:get, "https://example.com/test").to_return do
|
||||
sleep 0.01
|
||||
{ status: 200, body: "Hello World", headers: {} }
|
||||
end
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
|
||||
|
||||
runner.timeout = 10
|
||||
|
||||
@ -184,7 +189,7 @@ RSpec.describe AiTool do
|
||||
|
||||
tool = create_tool(script: script)
|
||||
|
||||
runner = tool.runner({}, llm: llm, bot_user: nil, context: {})
|
||||
runner = tool.runner({}, llm: llm, bot_user: nil)
|
||||
result = runner.invoke
|
||||
|
||||
expect(result).to eq("Hello")
|
||||
@ -209,7 +214,7 @@ RSpec.describe AiTool do
|
||||
responses = ["Hello ", "World"]
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
|
||||
runner = tool.runner({}, llm: llm, bot_user: nil, context: {})
|
||||
runner = tool.runner({}, llm: llm, bot_user: nil)
|
||||
result = runner.invoke
|
||||
prompts = _prompts
|
||||
end
|
||||
@ -232,7 +237,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "query" => "test" }, llm: nil, bot_user: nil)
|
||||
|
||||
runner.timeout = 5
|
||||
|
||||
@ -295,7 +300,7 @@ RSpec.describe AiTool do
|
||||
|
||||
RagDocumentFragment.link_target_and_uploads(tool, [upload1.id, upload2.id])
|
||||
|
||||
result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke
|
||||
result = tool.runner({}, llm: nil, bot_user: nil).invoke
|
||||
|
||||
expected = [
|
||||
[{ "fragment" => "44 45 46 47 48 49 50", "metadata" => nil }],
|
||||
@ -316,7 +321,7 @@ RSpec.describe AiTool do
|
||||
# this part of the API is a bit awkward, maybe we should do it
|
||||
# automatically
|
||||
RagDocumentFragment.update_target_uploads(tool, [upload1.id, upload2.id])
|
||||
result = tool.runner({}, llm: nil, bot_user: nil, context: {}).invoke
|
||||
result = tool.runner({}, llm: nil, bot_user: nil).invoke
|
||||
|
||||
expected = [
|
||||
[{ "fragment" => "48 49 50", "metadata" => nil }],
|
||||
@ -340,7 +345,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "topic_id" => topic.id }, llm: nil, bot_user: nil)
|
||||
|
||||
result = runner.invoke
|
||||
|
||||
@ -364,7 +369,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "post_id" => post.id }, llm: nil, bot_user: nil)
|
||||
|
||||
result = runner.invoke
|
||||
post_hash = result["post"]
|
||||
@ -393,7 +398,7 @@ RSpec.describe AiTool do
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {})
|
||||
runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil)
|
||||
|
||||
result = runner.invoke
|
||||
|
||||
@ -401,4 +406,158 @@ RSpec.describe AiTool do
|
||||
expect(result["rows"].first["title"]).to eq(topic.title)
|
||||
end
|
||||
end
|
||||
|
||||
context "when using the chat API" do
|
||||
before(:each) do
|
||||
skip "Chat plugin tests skipped because Chat module is not defined." unless defined?(Chat)
|
||||
SiteSetting.chat_enabled = true
|
||||
end
|
||||
|
||||
fab!(:chat_user) { Fabricate(:user) }
|
||||
fab!(:chat_channel) do
|
||||
Fabricate(:chat_channel).tap do |channel|
|
||||
Fabricate(
|
||||
:user_chat_channel_membership,
|
||||
user: chat_user,
|
||||
chat_channel: channel,
|
||||
following: true,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
it "can create a chat message" do
|
||||
script = <<~JS
|
||||
function invoke(params) {
|
||||
return discourse.createChatMessage({
|
||||
channel_name: params.channel_name,
|
||||
username: params.username,
|
||||
message: params.message
|
||||
});
|
||||
}
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner =
|
||||
tool.runner(
|
||||
{
|
||||
"channel_name" => chat_channel.name,
|
||||
"username" => chat_user.username,
|
||||
"message" => "Hello from the tool!",
|
||||
},
|
||||
llm: nil,
|
||||
bot_user: bot_user, # The user *running* the tool doesn't affect sender
|
||||
)
|
||||
|
||||
initial_message_count = Chat::Message.count
|
||||
result = runner.invoke
|
||||
|
||||
expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}"
|
||||
expect(result["message"]).to eq("Hello from the tool!")
|
||||
expect(result["created_at"]).to be_present
|
||||
expect(result).not_to have_key("error")
|
||||
|
||||
# Verify message was actually created in the database
|
||||
expect(Chat::Message.count).to eq(initial_message_count + 1)
|
||||
created_message = Chat::Message.find_by(id: result["message_id"])
|
||||
|
||||
expect(created_message).not_to be_nil
|
||||
expect(created_message.message).to eq("Hello from the tool!")
|
||||
expect(created_message.user_id).to eq(chat_user.id) # Message is sent AS the specified user
|
||||
expect(created_message.chat_channel_id).to eq(chat_channel.id)
|
||||
end
|
||||
|
||||
it "can create a chat message using channel slug" do
|
||||
chat_channel.update!(name: "My Test Channel", slug: "my-test-channel")
|
||||
expect(chat_channel.slug).to eq("my-test-channel")
|
||||
|
||||
script = <<~JS
|
||||
function invoke(params) {
|
||||
return discourse.createChatMessage({
|
||||
channel_name: params.channel_slug, // Using slug here
|
||||
username: params.username,
|
||||
message: params.message
|
||||
});
|
||||
}
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner =
|
||||
tool.runner(
|
||||
{
|
||||
"channel_slug" => chat_channel.slug,
|
||||
"username" => chat_user.username,
|
||||
"message" => "Hello via slug!",
|
||||
},
|
||||
llm: nil,
|
||||
bot_user: bot_user,
|
||||
)
|
||||
|
||||
result = runner.invoke
|
||||
|
||||
expect(result["success"]).to eq(true), "Tool invocation failed: #{result["error"]}"
|
||||
# see: https://github.com/rubyjs/mini_racer/issues/348
|
||||
# expect(result["message_id"]).to be_a(Integer)
|
||||
|
||||
created_message = Chat::Message.find_by(id: result["message_id"])
|
||||
expect(created_message).not_to be_nil
|
||||
expect(created_message.message).to eq("Hello via slug!")
|
||||
expect(created_message.chat_channel_id).to eq(chat_channel.id)
|
||||
end
|
||||
|
||||
it "returns an error if the channel is not found" do
|
||||
script = <<~JS
|
||||
function invoke(params) {
|
||||
return discourse.createChatMessage({
|
||||
channel_name: "non_existent_channel",
|
||||
username: params.username,
|
||||
message: params.message
|
||||
});
|
||||
}
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner =
|
||||
tool.runner(
|
||||
{ "username" => chat_user.username, "message" => "Test" },
|
||||
llm: nil,
|
||||
bot_user: bot_user,
|
||||
)
|
||||
|
||||
initial_message_count = Chat::Message.count
|
||||
expect { runner.invoke }.to raise_error(
|
||||
MiniRacer::RuntimeError,
|
||||
/Channel not found: non_existent_channel/,
|
||||
)
|
||||
|
||||
expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created
|
||||
end
|
||||
|
||||
it "returns an error if the user is not found" do
|
||||
script = <<~JS
|
||||
function invoke(params) {
|
||||
return discourse.createChatMessage({
|
||||
channel_name: params.channel_name,
|
||||
username: "non_existent_user",
|
||||
message: params.message
|
||||
});
|
||||
}
|
||||
JS
|
||||
|
||||
tool = create_tool(script: script)
|
||||
runner =
|
||||
tool.runner(
|
||||
{ "channel_name" => chat_channel.name, "message" => "Test" },
|
||||
llm: nil,
|
||||
bot_user: bot_user,
|
||||
)
|
||||
|
||||
initial_message_count = Chat::Message.count
|
||||
expect { runner.invoke }.to raise_error(
|
||||
MiniRacer::RuntimeError,
|
||||
/User not found: non_existent_user/,
|
||||
)
|
||||
|
||||
expect(Chat::Message.count).to eq(initial_message_count) # Verify no message created
|
||||
end
|
||||
end
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user