discourse-ai/lib/personas/tool_runner.rb
Sam e15984029d
FEATURE: allow tools to amend personas (#1250)
Add API methods to AI tools for reading and updating personas, enabling
more flexible AI workflows. This allows custom tools to:

- Fetch persona information through discourse.getPersona()
- Update personas with modified settings via discourse.updatePersona()
- Also update using persona.update()

These APIs enable new use cases like "trainable" moderation bots, where
users with appropriate permissions can set and refine moderation rules
through direct chat interactions, without needing admin panel access.

Also adds a special API scope which allows people to lean on the API
for similar actions.

Additionally, adds a rather powerful hidden feature that allows custom tools
to inject content into the context unconditionally; it can be used for memory and similar features.
2025-04-09 15:48:25 +10:00

706 lines
22 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Personas
# Runs JavaScript-based custom AI tools inside a sandboxed MiniRacer (V8)
# context, exposing a constrained Discourse API surface (http, llm, index,
# upload, chain, discourse) to the tool script.
class ToolRunner
  attr_reader :tool, :parameters, :llm
  attr_accessor :running_attached_function, :timeout, :custom_raw

  # Raised when a single tool run exceeds MAX_HTTP_REQUESTS outbound calls.
  TooManyRequestsError = Class.new(StandardError)

  DEFAULT_TIMEOUT = 2000 # JS execution budget in milliseconds (see eval_with_timeout)
  MAX_MEMORY = 10_000_000 # V8 heap limit in bytes
  MARSHAL_STACK_DEPTH = 20 # marshal depth cap passed to MiniRacer
  MAX_HTTP_REQUESTS = 20 # per ToolRunner instance
# Builds a runner for one tool invocation.
#
# parameters - the arguments the LLM supplied for the tool call
# llm        - completion backend used by the JS `llm.generate` API
# bot_user   - user the tool acts as (uploads, replies)
# context    - optional BotContext describing where the call originated
# tool       - the AiTool record whose script will be evaluated
# timeout    - optional JS budget in milliseconds (DEFAULT_TIMEOUT if nil)
def initialize(parameters:, llm:, bot_user:, context: nil, tool:, timeout: nil)
  unless context.nil? || context.is_a?(DiscourseAi::Personas::BotContext)
    raise ArgumentError, "context must be a BotContext object"
  end

  @parameters = parameters
  @llm = llm
  @bot_user = bot_user
  @context = context || DiscourseAi::Personas::BotContext.new
  @tool = tool
  @timeout = timeout || DEFAULT_TIMEOUT
  @running_attached_function = false
  @http_requests_made = 0
end
# Lazily builds the shared sandboxed V8 context: attaches every Ruby
# callback, then evaluates the framework bootstrap script exactly once.
def mini_racer_context
  @mini_racer_context ||=
    MiniRacer::Context
      .new(max_memory: MAX_MEMORY, marshal_stack_depth: MARSHAL_STACK_DEPTH)
      .tap do |ctx|
        attach_truncate(ctx)
        attach_http(ctx)
        attach_index(ctx)
        attach_upload(ctx)
        attach_chain(ctx)
        attach_discourse(ctx)
        ctx.eval(framework_script)
      end
end
# Returns the JavaScript bootstrap evaluated once per MiniRacer context.
# It defines the globals available to tool scripts — http, llm, index,
# upload, chain, discourse and context — each delegating to a `_`-prefixed
# Ruby lambda attached by the attach_* methods, plus a default no-op
# details() that tool scripts may override.
def framework_script
  # Build one "verb: function(url, options)" member per HTTP verb.
  http_methods = %i[get post put patch delete].map { |method| <<~JS }.join("\n")
#{method}: function(url, options) {
return _http_#{method}(url, options);
},
JS
  # NOTE: the heredoc below is runtime JS source shipped to V8 verbatim.
  <<~JS
const http = {
#{http_methods}
};

const llm = {
truncate: _llm_truncate,
generate: _llm_generate,
};

const index = {
search: _index_search,
}

const upload = {
create: _upload_create,
}

const chain = {
setCustomRaw: _chain_set_custom_raw,
};

const discourse = {
search: function(params) {
return _discourse_search(params);
},
updatePersona: function(persona_id_or_name, updates) {
const result = _discourse_update_persona(persona_id_or_name, updates);
if (result.error) {
throw new Error(result.error);
}
return result;
},
getPost: _discourse_get_post,
getTopic: _discourse_get_topic,
getUser: _discourse_get_user,
getPersona: function(name) {
const personaDetails = _discourse_get_persona(name);
if (personaDetails.error) {
throw new Error(personaDetails.error);
}

// merge result.persona with {}..
return Object.assign({
update: function(updates) {
const result = _discourse_update_persona(name, updates);
if (result.error) {
throw new Error(result.error);
}
return result;
},
respondTo: function(params) {
const result = _discourse_respond_to_persona(name, params);
if (result.error) {
throw new Error(result.error);
}
return result;
}
}, personaDetails.persona);
},
createChatMessage: function(params) {
const result = _discourse_create_chat_message(params);
if (result.error) {
throw new Error(result.error);
}
return result;
},
};

const context = #{JSON.generate(@context.to_json)};

function details() { return ""; };
JS
end
# Evaluates the tool script's `details()` JS function (overridable by the
# tool) to produce a human-readable summary of what the tool did.
def details
  eval_with_timeout("details()")
end
# Evaluates +script+ in the shared MiniRacer context, terminating the V8
# isolate if the script runs longer than +timeout+ milliseconds.
#
# A watchdog thread polls roughly every millisecond and only counts time
# while no attached Ruby function is running, so slow Ruby callbacks do
# not eat into the JS budget.
#
# script  - JavaScript source to evaluate
# timeout - optional override in milliseconds; defaults to @timeout
#
# Returns the value of the evaluated script.
def eval_with_timeout(script, timeout: nil)
  timeout ||= @timeout
  mutex = Mutex.new
  done = false
  elapsed = 0

  watchdog =
    Thread.new do
      begin
        while !done
          # this is not accurate. but reasonable enough for a timeout
          sleep(0.001)
          elapsed += 1 if !self.running_attached_function
          if elapsed > timeout
            mutex.synchronize { mini_racer_context.stop unless done }
            break
          end
        end
      rescue => e
        STDERR.puts e
        STDERR.puts "FAILED TO TERMINATE DUE TO TIMEOUT"
      end
    end

  rval = mini_racer_context.eval(script)
  mutex.synchronize { done = true }
  # ensure we do not leak a thread in state
  watchdog.join
  watchdog = nil

  rval
ensure
  # FIX: if eval raised, `done` was never set, so the watchdog kept spinning
  # until the timeout elapsed (blocking this join and spuriously stopping
  # the context). Mark done here so the watchdog exits immediately.
  mutex&.synchronize { done = true }
  watchdog&.join
end
# Loads the tool's script into the context and calls its JS `invoke`
# entry point with the tool-call parameters.
# Returns the script's result, or an error hash if the watchdog killed it.
def invoke
  mini_racer_context.eval(tool.script)
  eval_with_timeout("invoke(#{JSON.generate(parameters)})")
rescue MiniRacer::ScriptTerminatedError
  { error: "Script terminated due to timeout" }
end
# True when the tool script defines a `customContext()` JS function, used
# to inject extra content into the bot context unconditionally.
# Any evaluation error is treated as "no custom context".
def has_custom_context?
  mini_racer_context.eval(tool.script)
  mini_racer_context.eval("typeof customContext === 'function'")
rescue StandardError
  false
end
# Evaluates the tool script's `customContext()` function, returning nil on
# any error (including the function not being defined).
# NOTE(review): unlike invoke/details this does not use eval_with_timeout,
# so a runaway customContext() is not killed — confirm that is intended.
def custom_context
  mini_racer_context.eval(tool.script)
  mini_racer_context.eval("customContext()")
rescue StandardError
  nil
end
private

# Hard cap on fragments returned from a RAG search, regardless of `limit`.
MAX_FRAGMENTS = 200

# Performs a semantic (embedding-based) search over the document fragments
# belonging to this tool's attached uploads.
#
# query     - natural-language query string
# filenames - optional array restricting the search to specific uploads
# limit     - maximum number of fragments to return (capped at MAX_FRAGMENTS)
#
# Returns an array of { fragment:, metadata: } hashes in similarity order.
def rag_search(query, filenames: nil, limit: 10)
  limit = limit.to_i
  return [] if limit < 1
  limit = [MAX_FRAGMENTS, limit].min

  # uploads attached to this AiTool
  upload_refs =
    UploadReference.where(target_id: tool.id, target_type: "AiTool").pluck(:upload_id)

  if filenames
    upload_refs = Upload.where(id: upload_refs).where(original_filename: filenames).pluck(:id)
  end

  return [] if upload_refs.empty?

  query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
  fragment_ids =
    DiscourseAi::Embeddings::Schema
      .for(RagDocumentFragment)
      .asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
        # restrict the vector search to fragments owned by this tool
        builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
        rag_document_fragments ON
        rag_document_fragments.id = rag_document_fragment_id AND
        rag_document_fragments.target_id = :target_id AND
        rag_document_fragments.target_type = :target_type
        SQL
      end
      .map(&:rag_document_fragment_id)

  fragments =
    RagDocumentFragment.where(id: fragment_ids, upload_id: upload_refs).pluck(
      :id,
      :fragment,
      :metadata,
    )

  # index by id so results can be emitted in similarity order
  mapped = {}
  fragments.each do |id, fragment, metadata|
    mapped[id] = { fragment: fragment, metadata: metadata }
  end

  # NOTE(review): the upload_id filter above can drop ids that the vector
  # search returned, leaving nil entries here — verify callers tolerate nils.
  fragment_ids.take(limit).map { |fragment_id| mapped[fragment_id] }
end
def attach_truncate(mini_racer_context)
mini_racer_context.attach(
"_llm_truncate",
->(text, length) { @llm.tokenizer.truncate(text, length) },
)
mini_racer_context.attach(
"_llm_generate",
->(prompt) do
in_attached_function do
@llm.generate(
convert_js_prompt_to_ruby(prompt),
user: llm_user,
feature_name: "custom_tool_#{tool.name}",
)
end
end,
)
end
# Converts a prompt supplied from JavaScript (either a plain string or a
# { "messages" => [...] } hash) into a value accepted by @llm.generate.
#
# prompt - String (passed through) or Hash with a "messages" array whose
#          entries each carry a "type" and content
#
# Returns the String or a DiscourseAi::Completions::Prompt.
# Raises Discourse::InvalidParameters when the hash is malformed.
def convert_js_prompt_to_ruby(prompt)
  if prompt.is_a?(String)
    prompt
  elsif prompt.is_a?(Hash)
    messages = prompt["messages"]
    if messages.blank? || !messages.is_a?(Array)
      raise Discourse::InvalidParameters.new("Prompt must have messages")
    end

    messages.each(&:symbolize_keys!)
    messages.each do |message|
      # FIX: a missing "type" previously raised an opaque NoMethodError
      # (nil.to_sym); fail with a clear parameter error instead.
      if message[:type].nil?
        raise Discourse::InvalidParameters.new("Every message must have a type")
      end
      message[:type] = message[:type].to_sym
    end

    # use the already-symbolized local rather than re-reading prompt["messages"]
    DiscourseAi::Completions::Prompt.new(messages: messages)
  else
    raise Discourse::InvalidParameters.new("Prompt must be a string or a hash")
  end
end
# User attributed to LLM calls made from the tool script: the author of the
# triggering post when available, otherwise the bot user. Memoized.
def llm_user
  @llm_user ||= (post&.user || @bot_user)
end
# The post that triggered this tool run, looked up from the context.
# Memoized with defined?() so a nil lookup is also cached.
def post
  return @post if defined?(@post)

  id = @context.post_id
  @post = id ? Post.find_by(id: id) : nil
end
def attach_index(mini_racer_context)
mini_racer_context.attach(
"_index_search",
->(*params) do
in_attached_function do
query, options = params
self.running_attached_function = true
options ||= {}
options = options.symbolize_keys
self.rag_search(query, **options)
end
end,
)
end
# Exposes `_chain_set_custom_raw` (JS `chain.setCustomRaw`), letting a tool
# provide the raw reply body directly instead of returning a result.
def attach_chain(mini_racer_context)
  setter = ->(raw) { self.custom_raw = raw }
  mini_racer_context.attach("_chain_set_custom_raw", setter)
end
# Attaches all `_discourse_*` callbacks backing the JS `discourse` global:
# post/topic/user lookups, search, persona read/update, responding as a
# persona, and chat message creation. Read lookups run with a system-user
# guardian, i.e. unrestricted visibility.
def attach_discourse(mini_racer_context)
  # discourse.getPost(id) -> serialized post (with raw) plus its topic, or null
  mini_racer_context.attach(
    "_discourse_get_post",
    ->(post_id) do
      in_attached_function do
        post = Post.find_by(id: post_id)
        return nil if post.nil?

        guardian = Guardian.new(Discourse.system_user)
        obj =
          recursive_as_json(
            PostSerializer.new(post, scope: guardian, root: false, add_raw: true),
          )
        topic_obj =
          recursive_as_json(
            ListableTopicSerializer.new(post.topic, scope: guardian, root: false),
          )
        obj["topic"] = topic_obj
        obj
      end
    end,
  )

  # discourse.getTopic(id) -> serialized topic or null
  mini_racer_context.attach(
    "_discourse_get_topic",
    ->(topic_id) do
      in_attached_function do
        topic = Topic.find_by(id: topic_id)
        return nil if topic.nil?

        guardian = Guardian.new(Discourse.system_user)
        recursive_as_json(ListableTopicSerializer.new(topic, scope: guardian, root: false))
      end
    end,
  )

  # discourse.getUser(idOrUsername) -> serialized user or null
  mini_racer_context.attach(
    "_discourse_get_user",
    ->(user_id_or_username) do
      in_attached_function do
        user = nil
        # numeric (or numeric-string) arguments are treated as ids
        if user_id_or_username.is_a?(Integer) ||
             user_id_or_username.to_i.to_s == user_id_or_username
          user = User.find_by(id: user_id_or_username.to_i)
        else
          user = User.find_by(username: user_id_or_username)
        end

        return nil if user.nil?

        guardian = Guardian.new(Discourse.system_user)
        recursive_as_json(UserSerializer.new(user, scope: guardian, root: false))
      end
    end,
  )

  # persona.respondTo(params) — have another persona reply in the current
  # post or chat context; returns ids of the created reply or an error hash
  mini_racer_context.attach(
    "_discourse_respond_to_persona",
    ->(persona_name, params) do
      in_attached_function do
        # if we have 1000s of personas this can be slow ... we may need to optimize
        persona_class = AiPersona.all_personas.find { |persona| persona.name == persona_name }
        return { error: "Persona not found" } if persona_class.nil?

        persona = persona_class.new
        bot = DiscourseAi::Personas::Bot.as(@bot_user || persona.user, persona: persona)
        playground = DiscourseAi::AiBot::Playground.new(bot)

        if @context.post_id
          post = Post.find_by(id: @context.post_id)
          return { error: "Post not found" } if post.nil?

          reply_post =
            playground.reply_to(
              post,
              custom_instructions: params["instructions"],
              whisper: params["whisper"],
            )

          if reply_post
            return(
              { success: true, post_id: reply_post.id, post_number: reply_post.post_number }
            )
          else
            return { error: "Failed to create reply" }
          end
        elsif @context.message_id && @context.channel_id
          message = Chat::Message.find_by(id: @context.message_id)
          channel = Chat::Channel.find_by(id: @context.channel_id)
          return { error: "Message or channel not found" } if message.nil? || channel.nil?

          reply =
            playground.reply_to_chat_message(message, channel, @context.context_post_ids)

          if reply
            return { success: true, message_id: reply.id }
          else
            return { error: "Failed to create chat reply" }
          end
        else
          return { error: "No valid context for response" }
        end
      end
    end,
  )

  # discourse.createChatMessage({channel_name, username, message}) —
  # posts a chat message as the given user via ChatSDK
  mini_racer_context.attach(
    "_discourse_create_chat_message",
    ->(params) do
      in_attached_function do
        params = params.symbolize_keys
        channel_name = params[:channel_name]
        username = params[:username]
        message = params[:message]

        # Validate parameters
        return { error: "Missing required parameter: channel_name" } if channel_name.blank?
        return { error: "Missing required parameter: username" } if username.blank?
        return { error: "Missing required parameter: message" } if message.blank?

        # Find the user
        user = User.find_by(username: username)
        return { error: "User not found: #{username}" } if user.nil?

        # Find the channel
        channel = Chat::Channel.find_by(name: channel_name)
        if channel.nil?
          # Try finding by slug if not found by name
          channel = Chat::Channel.find_by(slug: channel_name.parameterize)
        end
        return { error: "Channel not found: #{channel_name}" } if channel.nil?

        begin
          guardian = Guardian.new(user)
          # NOTE: `message` is rebound from the raw string to the created
          # Chat::Message record here (the raw: arg is evaluated first)
          message =
            ChatSDK::Message.create(
              raw: message,
              channel_id: channel.id,
              guardian: guardian,
              enforce_membership: !channel.direct_message_channel?,
            )

          {
            success: true,
            message_id: message.id,
            message: message.message,
            created_at: message.created_at.iso8601,
          }
        rescue => e
          { error: "Failed to create chat message: #{e.message}" }
        end
      end
    end,
  )

  # discourse.search(params) — full site search; with_private: true searches
  # as the system user (all content)
  mini_racer_context.attach(
    "_discourse_search",
    ->(params) do
      in_attached_function do
        search_params = params.symbolize_keys
        if search_params.delete(:with_private)
          search_params[:current_user] = Discourse.system_user
        end
        search_params[:result_style] = :detailed
        results = DiscourseAi::Utils::Search.perform_search(**search_params)
        recursive_as_json(results)
      end
    end,
  )

  # discourse.getPersona(name) — returns a whitelisted slice of persona
  # attributes (wrapped with update/respondTo helpers on the JS side)
  mini_racer_context.attach(
    "_discourse_get_persona",
    ->(persona_name) do
      in_attached_function do
        persona = AiPersona.find_by(name: persona_name)
        return { error: "Persona not found" } if persona.nil?

        # Return a subset of relevant persona attributes
        {
          persona:
            persona.attributes.slice(
              "id",
              "name",
              "description",
              "enabled",
              "system_prompt",
              "temperature",
              "top_p",
              "vision_enabled",
              "tools",
              "max_context_posts",
              "allow_chat_channel_mentions",
              "allow_chat_direct_messages",
              "allow_topic_mentions",
              "allow_personal_messages",
            ),
        }
      end
    end,
  )

  # discourse.updatePersona(idOrName, updates) — applies a whitelisted set
  # of updates (system_prompt, temperature, top_p, description, enabled)
  mini_racer_context.attach(
    "_discourse_update_persona",
    ->(persona_id_or_name, updates) do
      in_attached_function do
        # Find persona by ID or name
        persona = nil
        if persona_id_or_name.is_a?(Integer) ||
             persona_id_or_name.to_i.to_s == persona_id_or_name
          persona = AiPersona.find_by(id: persona_id_or_name.to_i)
        else
          persona = AiPersona.find_by(name: persona_id_or_name)
        end

        return { error: "Persona not found" } if persona.nil?

        # only explicitly whitelisted fields may be changed from a tool
        allowed_updates = {}

        if updates["system_prompt"].present?
          allowed_updates[:system_prompt] = updates["system_prompt"]
        end

        if updates["temperature"].is_a?(Numeric)
          allowed_updates[:temperature] = updates["temperature"]
        end

        allowed_updates[:top_p] = updates["top_p"] if updates["top_p"].is_a?(Numeric)

        if updates["description"].present?
          allowed_updates[:description] = updates["description"]
        end

        allowed_updates[:enabled] = updates["enabled"] if updates["enabled"].is_a?(
          TrueClass,
        ) || updates["enabled"].is_a?(FalseClass)

        if persona.update(allowed_updates)
          return(
            {
              success: true,
              persona:
                persona.attributes.slice(
                  "id",
                  "name",
                  "description",
                  "enabled",
                  "system_prompt",
                  "temperature",
                  "top_p",
                ),
            }
          )
        else
          return { error: persona.errors.full_messages.join(", ") }
        end
      end
    end,
  )
end
# Exposes `_upload_create` (JS `upload.create(filename, base64)`), which
# decodes base64 content into a Tempfile and creates a Discourse upload
# owned by the bot user. Returns { id:, short_url:, url: }.
def attach_upload(mini_racer_context)
  mini_racer_context.attach(
    "_upload_create",
    ->(filename, base_64_content) do
      # (removed a useless begin/end wrapper that had no rescue/ensure)
      in_attached_function do
        # protect against misuse — strip any path components from the name
        filename = File.basename(filename)

        Tempfile.create(filename) do |file|
          file.binmode
          file.write(Base64.decode64(base_64_content))
          file.rewind

          upload =
            UploadCreator.new(
              file,
              filename,
              for_private_message: @context.private_message,
            ).create_for(@bot_user.id)

          { id: upload.id, short_url: upload.short_url, url: upload.url }
        end
      end
    end,
  )
end
# Attaches the `_http_*` callbacks backing the JS `http` global. GET takes
# no body; POST/PUT/PATCH/DELETE forward an optional options["body"].
# Every call counts against the MAX_HTTP_REQUESTS budget.
def attach_http(mini_racer_context)
  mini_racer_context.attach(
    "_http_get",
    ->(url, options) do
      check_http_budget!
      in_attached_function do
        headers = (options && options["headers"]) || {}

        result = {}
        DiscourseAi::Personas::Tools::Tool.send_http_request(
          url,
          headers: headers,
        ) do |response|
          result[:body] = response.body
          result[:status] = response.code.to_i
        end

        result
      end
    end,
  )

  %i[post put patch delete].each do |method|
    mini_racer_context.attach(
      "_http_#{method}",
      ->(url, options) do
        check_http_budget!
        in_attached_function do
          headers = (options && options["headers"]) || {}
          body = options && options["body"]

          result = {}
          DiscourseAi::Personas::Tools::Tool.send_http_request(
            url,
            method: method,
            headers: headers,
            body: body,
          ) do |response|
            result[:body] = response.body
            result[:status] = response.code.to_i
          end

          result
        rescue => e
          # surface failing request details in development, then re-raise
          if Rails.env.development?
            p url
            p options
            p e
            puts e.backtrace
          end
          raise e
        end
      end,
    )
  end
end

# Counts one outbound HTTP request against the per-run budget, raising
# TooManyRequestsError once MAX_HTTP_REQUESTS is exceeded.
# (FIX: this check was duplicated verbatim in the GET and verb-loop lambdas.)
def check_http_budget!
  @http_requests_made += 1
  if @http_requests_made > MAX_HTTP_REQUESTS
    raise TooManyRequestsError.new("Tool made too many HTTP requests")
  end
end
# Wraps a Ruby callback invoked from JS. While the flag is set, the
# watchdog in eval_with_timeout stops counting elapsed time, so slow Ruby
# work (DB, HTTP, LLM calls) does not consume the script's JS budget.
# NOTE(review): not reentrant — a nested call clears the flag when the
# inner block exits; looks fine for current call sites, verify if nesting
# is ever introduced.
def in_attached_function
  self.running_attached_function = true
  yield
ensure
  self.running_attached_function = false
end
# Deeply converts +value+ into plain JSON-compatible Ruby structures.
# Arrays and hashes are walked recursively; serializers and ActiveRecord
# models are expanded via #as_json first; primitives pass through as-is.
def recursive_as_json(value)
  case value
  when Array
    value.map { |element| recursive_as_json(element) }
  when Hash
    value.transform_values { |element| recursive_as_json(element) }
  when ActiveModel::Serializer, ActiveModel::ArraySerializer, ActiveRecord::Base
    recursive_as_json(value.as_json)
  else
    # Handle objects that respond to as_json but aren't matched above;
    # primitives (strings, numbers, booleans, nil) fall through untouched.
    return value unless value.respond_to?(:as_json)

    json = value.as_json
    # if as_json returned the receiver itself, stop to avoid infinite recursion
    json.equal?(value) ? json : recursive_as_json(json)
  end
end
end
end
end