diff --git a/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js b/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js
index b06c525c..d5dea6f4 100644
--- a/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js
+++ b/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js
@@ -2,7 +2,7 @@ import { later } from "@ember/runloop";
import PostUpdater from "./updaters/post-updater";
const PROGRESS_INTERVAL = 40;
-const GIVE_UP_INTERVAL = 60000;
+const GIVE_UP_INTERVAL = 600000; // 10 minutes which is our max thinking time for now
export const MIN_LETTERS_PER_INTERVAL = 6;
const MAX_FLUSH_TIME = 800;
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 00a71cd3..0875eba9 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -296,6 +296,9 @@ en:
designer:
name: Designer
description: "AI Bot specialized in generating and editing images"
+ forum_researcher:
+ name: Forum Researcher
+ description: "AI Bot specialized in deep research for the forum"
sql_helper:
name: SQL Helper
description: "AI Bot specialized in helping craft SQL queries on this Discourse instance"
@@ -303,8 +306,8 @@ en:
name: Settings Explorer
description: "AI Bot specialized in helping explore Discourse site settings"
researcher:
- name: Researcher
- description: "AI Bot with Google access that can research information for you"
+ name: Web Researcher
+ description: "AI Bot with Google access that can both search and read web pages"
creative:
name: Creative
description: "AI Bot with no external integrations specialized in creative tasks"
@@ -327,6 +330,16 @@ en:
summarizing: "Summarizing topic"
searching: "Searching for: '%{query}'"
tool_options:
+ researcher:
+ max_results:
+ name: "Maximum number of results"
+ description: "Maximum number of results to include in a filter"
+ include_private:
+ name: "Include private"
+ description: "Include private topics in the filters"
+ max_tokens_per_post:
+ name: "Maximum tokens per post"
+ description: "Maximum number of tokens to use for each post in the filter"
create_artifact:
creator_llm:
name: "LLM"
@@ -385,6 +398,7 @@ en:
javascript_evaluator: "Evaluate JavaScript"
create_image: "Creating image"
edit_image: "Editing image"
+ researcher: "Researching"
tool_help:
read_artifact: "Read a web artifact using the AI Bot"
update_artifact: "Update a web artifact using the AI Bot"
@@ -411,6 +425,7 @@ en:
dall_e: "Generate image using DALL-E 3"
search_meta_discourse: "Search Meta Discourse"
javascript_evaluator: "Evaluate JavaScript"
+ researcher: "Research forum information using the AI Bot"
tool_description:
read_artifact: "Read a web artifact using the AI Bot"
update_artifact: "Updated a web artifact using the AI Bot"
@@ -445,6 +460,12 @@ en:
other: "Found %{count} results for '%{query}'"
setting_context: "Reading context for: %{setting_name}"
schema: "%{tables}"
+ researcher_dry_run:
+ one: "Proposed research: %{goals}\n\nFound %{count} result for '%{filter}'"
+      other: "Proposed research: %{goals}\n\nFound %{count} results for '%{filter}'"
+ researcher:
+ one: "Researching: %{goals}\n\nFound %{count} result for '%{filter}'"
+      other: "Researching: %{goals}\n\nFound %{count} results for '%{filter}'"
search_settings:
one: "Found %{count} result for '%{query}'"
other: "Found %{count} results for '%{query}'"
diff --git a/db/fixtures/personas/603_ai_personas.rb b/db/fixtures/personas/603_ai_personas.rb
index 27e6d479..c2c121de 100644
--- a/db/fixtures/personas/603_ai_personas.rb
+++ b/db/fixtures/personas/603_ai_personas.rb
@@ -33,7 +33,7 @@ DiscourseAi::Personas::Persona.system_personas.each do |persona_class, id|
persona.allowed_group_ids = [Group::AUTO_GROUPS[:trust_level_0]]
end
- persona.enabled = !summarization_personas.include?(persona_class)
+ persona.enabled = persona_class.default_enabled
persona.priority = true if persona_class == DiscourseAi::Personas::General
end
diff --git a/lib/ai_bot/chat_streamer.rb b/lib/ai_bot/chat_streamer.rb
index 139a6c7f..06357e0e 100644
--- a/lib/ai_bot/chat_streamer.rb
+++ b/lib/ai_bot/chat_streamer.rb
@@ -6,16 +6,23 @@
module DiscourseAi
module AiBot
class ChatStreamer
- attr_accessor :cancel
attr_reader :reply,
:guardian,
:thread_id,
:force_thread,
:in_reply_to_id,
:channel,
- :cancelled
+ :cancel_manager
- def initialize(message:, channel:, guardian:, thread_id:, in_reply_to_id:, force_thread:)
+ def initialize(
+ message:,
+ channel:,
+ guardian:,
+ thread_id:,
+ in_reply_to_id:,
+ force_thread:,
+ cancel_manager: nil
+ )
@message = message
@channel = channel
@guardian = guardian
@@ -35,6 +42,8 @@ module DiscourseAi
guardian: guardian,
thread_id: thread_id,
)
+
+ @cancel_manager = cancel_manager
end
def <<(partial)
@@ -111,8 +120,7 @@ module DiscourseAi
streaming = ChatSDK::Message.stream(message_id: reply.id, raw: buffer, guardian: guardian)
if !streaming
- cancel.call
- @cancelled = true
+ @cancel_manager.cancel! if @cancel_manager
end
end
end
diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb
index 63aa8967..07c4984b 100644
--- a/lib/ai_bot/playground.rb
+++ b/lib/ai_bot/playground.rb
@@ -331,6 +331,7 @@ module DiscourseAi
),
user: message.user,
skip_tool_details: true,
+ cancel_manager: DiscourseAi::Completions::CancelManager.new,
)
reply = nil
@@ -347,15 +348,14 @@ module DiscourseAi
thread_id: message.thread_id,
in_reply_to_id: in_reply_to_id,
force_thread: force_thread,
+ cancel_manager: context.cancel_manager,
)
new_prompts =
- bot.reply(context) do |partial, cancel, placeholder, type|
+ bot.reply(context) do |partial, placeholder, type|
# no support for tools or thinking by design
next if type == :thinking || type == :tool_details || type == :partial_tool
- streamer.cancel = cancel
streamer << partial
- break if streamer.cancelled
end
reply = streamer.reply
@@ -383,6 +383,7 @@ module DiscourseAi
auto_set_title: true,
silent_mode: false,
feature_name: nil,
+ cancel_manager: nil,
&blk
)
# this is a multithreading issue
@@ -471,16 +472,26 @@ module DiscourseAi
redis_stream_key = "gpt_cancel:#{reply_post.id}"
Discourse.redis.setex(redis_stream_key, MAX_STREAM_DELAY_SECONDS, 1)
+
+ cancel_manager ||= DiscourseAi::Completions::CancelManager.new
+ context.cancel_manager = cancel_manager
+ context
+ .cancel_manager
+ .start_monitor(delay: 0.2) do
+ context.cancel_manager.cancel! if !Discourse.redis.get(redis_stream_key)
+ end
+
+ context.cancel_manager.add_callback(
+ lambda { reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) },
+ )
end
context.skip_tool_details ||= !bot.persona.class.tool_details
-
post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply
-
started_thinking = false
new_custom_prompts =
- bot.reply(context) do |partial, cancel, placeholder, type|
+ bot.reply(context) do |partial, placeholder, type|
if type == :thinking && !started_thinking
reply << "#{I18n.t("discourse_ai.ai_bot.thinking")}
"
started_thinking = true
@@ -499,15 +510,6 @@ module DiscourseAi
blk.call(partial)
end
- if stream_reply && !Discourse.redis.get(redis_stream_key)
- cancel&.call
- reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
- # we do not break out, cause if we do
- # we will not get results from bot
- # leading to broken context
- # we need to trust it to cancel at the endpoint
- end
-
if post_streamer
post_streamer.run_later do
Discourse.redis.expire(redis_stream_key, MAX_STREAM_DELAY_SECONDS)
@@ -568,6 +570,8 @@ module DiscourseAi
end
raise e
ensure
+ context.cancel_manager.stop_monitor if context&.cancel_manager
+
# since we are skipping validations and jobs we
# may need to fix participant count
if reply_post && reply_post.topic && reply_post.topic.private_message? &&
@@ -649,7 +653,7 @@ module DiscourseAi
payload,
user_ids: bot_reply_post.topic.allowed_user_ids,
max_backlog_size: 2,
- max_backlog_age: 60,
+ max_backlog_age: MAX_STREAM_DELAY_SECONDS,
)
end
end
diff --git a/lib/completions/cancel_manager.rb b/lib/completions/cancel_manager.rb
new file mode 100644
index 00000000..78c3ee5b
--- /dev/null
+++ b/lib/completions/cancel_manager.rb
@@ -0,0 +1,109 @@
+# frozen_string_literal: true
+
+# special object that can be used to cancel completions and http requests
+module DiscourseAi
+ module Completions
+ class CancelManager
+ attr_reader :cancelled
+ attr_reader :callbacks
+
+ def initialize
+ @cancelled = false
+ @callbacks = Concurrent::Array.new
+ @mutex = Mutex.new
+ @monitor_thread = nil
+ end
+
+ def monitor_thread
+ @mutex.synchronize { @monitor_thread }
+ end
+
+ def start_monitor(delay: 0.5, &block)
+ @mutex.synchronize do
+ raise "Already monitoring" if @monitor_thread
+ raise "Expected a block" if !block
+
+ db = RailsMultisite::ConnectionManagement.current_db
+ @stop_monitor = false
+
+ @monitor_thread =
+ Thread.new do
+ begin
+ loop do
+ done = false
+ @mutex.synchronize { done = true if @stop_monitor }
+ break if done
+ sleep delay
+ @mutex.synchronize { done = true if @stop_monitor }
+ @mutex.synchronize { done = true if cancelled? }
+ break if done
+
+ should_cancel = false
+ RailsMultisite::ConnectionManagement.with_connection(db) do
+ should_cancel = block.call
+ end
+
+ @mutex.synchronize { cancel! if should_cancel }
+
+ break if cancelled?
+ end
+ ensure
+ @mutex.synchronize { @monitor_thread = nil }
+ end
+ end
+ end
+ end
+
+ def stop_monitor
+ monitor_thread = nil
+
+ @mutex.synchronize { monitor_thread = @monitor_thread }
+
+ if monitor_thread
+ @mutex.synchronize { @stop_monitor = true }
+ # so we do not deadlock
+ monitor_thread.wakeup
+ monitor_thread.join(2)
+ # should not happen
+ if monitor_thread.alive?
+ Rails.logger.warn("DiscourseAI: CancelManager monitor thread did not stop in time")
+ monitor_thread.kill if monitor_thread.alive?
+ end
+ @monitor_thread = nil
+ end
+ end
+
+ def cancelled?
+ @cancelled
+ end
+
+ def add_callback(cb)
+ @callbacks << cb
+ end
+
+ def remove_callback(cb)
+ @callbacks.delete(cb)
+ end
+
+ def cancel!
+ @cancelled = true
+ monitor_thread = @monitor_thread
+ if monitor_thread && monitor_thread != Thread.current
+ monitor_thread.wakeup
+ monitor_thread.join(2)
+ if monitor_thread.alive?
+ Rails.logger.warn("DiscourseAI: CancelManager monitor thread did not stop in time")
+ monitor_thread.kill if monitor_thread.alive?
+ end
+ end
+ @callbacks.each do |cb|
+ begin
+ cb.call
+ rescue StandardError
+ # ignore cause this may have already been cancelled
+ end
+ end
+ end
+ end
+ end
+end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 3a12c4c7..0356597e 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -68,11 +68,17 @@ module DiscourseAi
feature_context: nil,
partial_tool_calls: false,
output_thinking: false,
+ cancel_manager: nil,
&blk
)
LlmQuota.check_quotas!(@llm_model, user)
start_time = Time.now
+ if cancel_manager && cancel_manager.cancelled?
+ # nothing to do
+ return
+ end
+
@forced_json_through_prefill = false
@partial_tool_calls = partial_tool_calls
@output_thinking = output_thinking
@@ -90,15 +96,14 @@ module DiscourseAi
feature_context: feature_context,
partial_tool_calls: partial_tool_calls,
output_thinking: output_thinking,
+ cancel_manager: cancel_manager,
)
wrapped = result
wrapped = [result] if !result.is_a?(Array)
- cancelled_by_caller = false
- cancel_proc = -> { cancelled_by_caller = true }
wrapped.each do |partial|
- blk.call(partial, cancel_proc)
- break if cancelled_by_caller
+ blk.call(partial)
+ break cancel_manager&.cancelled?
end
return result
end
@@ -118,6 +123,9 @@ module DiscourseAi
end
end
+ cancel_manager_callback = nil
+ cancelled = false
+
FinalDestination::HTTP.start(
model_uri.host,
model_uri.port,
@@ -126,6 +134,14 @@ module DiscourseAi
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
+ if cancel_manager
+ cancel_manager_callback =
+ lambda do
+ cancelled = true
+ http.finish
+ end
+ cancel_manager.add_callback(cancel_manager_callback)
+ end
response_data = +""
response_raw = +""
@@ -158,7 +174,7 @@ module DiscourseAi
if @streaming_mode
blk =
- lambda do |partial, cancel|
+ lambda do |partial|
if partial.is_a?(String)
partial = xml_stripper << partial if xml_stripper
@@ -167,7 +183,7 @@ module DiscourseAi
partial = structured_output
end
end
- orig_blk.call(partial, cancel) if partial
+ orig_blk.call(partial) if partial
end
end
@@ -196,14 +212,6 @@ module DiscourseAi
end
begin
- cancelled = false
- cancel = -> do
- cancelled = true
- http.finish
- end
-
- break if cancelled
-
response.read_body do |chunk|
break if cancelled
@@ -216,16 +224,11 @@ module DiscourseAi
partials = [partial]
if xml_tool_processor && partial.is_a?(String)
partials = (xml_tool_processor << partial)
- if xml_tool_processor.should_cancel?
- cancel.call
- break
- end
+ break if xml_tool_processor.should_cancel?
end
- partials.each { |inner_partial| blk.call(inner_partial, cancel) }
+ partials.each { |inner_partial| blk.call(inner_partial) }
end
end
- rescue IOError, StandardError
- raise if !cancelled
end
if xml_stripper
stripped = xml_stripper.finish
@@ -233,13 +236,11 @@ module DiscourseAi
response_data << stripped
result = []
result = (xml_tool_processor << stripped) if xml_tool_processor
- result.each { |partial| blk.call(partial, cancel) }
+ result.each { |partial| blk.call(partial) }
end
end
- if xml_tool_processor
- xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
- end
- decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
+ xml_tool_processor.finish.each { |partial| blk.call(partial) } if xml_tool_processor
+ decode_chunk_finish.each { |partial| blk.call(partial) }
return response_data
ensure
if log
@@ -293,6 +294,12 @@ module DiscourseAi
end
end
end
+ rescue IOError, StandardError
+ raise if !cancelled
+ ensure
+ if cancel_manager && cancel_manager_callback
+ cancel_manager.remove_callback(cancel_manager_callback)
+ end
end
def final_log_update(log)
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index be156aea..b31b2bb4 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -30,7 +30,8 @@ module DiscourseAi
feature_name: nil,
feature_context: nil,
partial_tool_calls: false,
- output_thinking: false
+ output_thinking: false,
+ cancel_manager: nil
)
@dialect = dialect
@model_params = model_params
diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb
index d307ada5..71d0ee4b 100644
--- a/lib/completions/endpoints/fake.rb
+++ b/lib/completions/endpoints/fake.rb
@@ -122,7 +122,8 @@ module DiscourseAi
feature_name: nil,
feature_context: nil,
partial_tool_calls: false,
- output_thinking: false
+ output_thinking: false,
+ cancel_manager: nil
)
last_call = { dialect: dialect, user: user, model_params: model_params }
self.class.last_call = last_call
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index f97a337d..070d391c 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -46,6 +46,7 @@ module DiscourseAi
feature_context: nil,
partial_tool_calls: false,
output_thinking: false,
+ cancel_manager: nil,
&blk
)
@disable_native_tools = dialect.disable_native_tools?
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index f90de091..04d517cb 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -307,7 +307,7 @@ module DiscourseAi
# @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output.
# @param [Experimental] extra_model_params { Hash - Optional } - Other params that are not available accross models. e.g. response_format JSON schema.
#
- # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
+ # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response.
#
# @returns String | ToolCall - Completion result.
# if multiple tools or a tool and a message come back, the result will be an array of ToolCall / String objects.
@@ -325,6 +325,7 @@ module DiscourseAi
output_thinking: false,
response_format: nil,
extra_model_params: nil,
+ cancel_manager: nil,
&partial_read_blk
)
self.class.record_prompt(
@@ -378,6 +379,7 @@ module DiscourseAi
feature_context: feature_context,
partial_tool_calls: partial_tool_calls,
output_thinking: output_thinking,
+ cancel_manager: cancel_manager,
&partial_read_blk
)
end
diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb
index 4d302ea1..2864846a 100644
--- a/lib/completions/prompt_messages_builder.rb
+++ b/lib/completions/prompt_messages_builder.rb
@@ -247,6 +247,10 @@ module DiscourseAi
# 3. ensures we always interleave user and model messages
last_type = nil
messages.each do |message|
+ if message[:type] == :model && !message[:content]
+ message[:content] = "Reply cancelled by user."
+ end
+
next if !last_type && message[:type] != :user
if last_type == :tool_call && message[:type] != :tool
diff --git a/lib/discord/bot/persona_replier.rb b/lib/discord/bot/persona_replier.rb
index 51e03475..b64af15c 100644
--- a/lib/discord/bot/persona_replier.rb
+++ b/lib/discord/bot/persona_replier.rb
@@ -24,7 +24,7 @@ module DiscourseAi
full_reply =
@bot.reply(
{ conversation_context: [{ type: :user, content: @query }], skip_tool_details: true },
- ) do |partial, _cancel, _something|
+ ) do |partial, _something|
reply << partial
next if reply.blank?
diff --git a/lib/inference/open_ai_image_generator.rb b/lib/inference/open_ai_image_generator.rb
index f4ff6754..83501989 100644
--- a/lib/inference/open_ai_image_generator.rb
+++ b/lib/inference/open_ai_image_generator.rb
@@ -21,7 +21,8 @@ module ::DiscourseAi
moderation: "low",
output_compression: nil,
output_format: nil,
- title: nil
+ title: nil,
+ cancel_manager: nil
)
# Get the API responses in parallel threads
api_responses =
@@ -38,6 +39,7 @@ module ::DiscourseAi
moderation: moderation,
output_compression: output_compression,
output_format: output_format,
+ cancel_manager: cancel_manager,
)
raise api_responses[0] if api_responses.all? { |resp| resp.is_a?(StandardError) }
@@ -58,7 +60,8 @@ module ::DiscourseAi
user_id:,
for_private_message: false,
n: 1,
- quality: nil
+ quality: nil,
+ cancel_manager: nil
)
api_response =
edit_images(
@@ -70,6 +73,7 @@ module ::DiscourseAi
api_url: api_url,
n: n,
quality: quality,
+ cancel_manager: cancel_manager,
)
create_uploads_from_responses([api_response], user_id, for_private_message).first
@@ -124,7 +128,8 @@ module ::DiscourseAi
background:,
moderation:,
output_compression:,
- output_format:
+ output_format:,
+ cancel_manager:
)
prompts = [prompts] unless prompts.is_a?(Array)
prompts = prompts.take(4) # Limit to 4 prompts max
@@ -152,18 +157,21 @@ module ::DiscourseAi
moderation: moderation,
output_compression: output_compression,
output_format: output_format,
+ cancel_manager: cancel_manager,
)
rescue => e
attempts += 1
# to keep tests speedy
- if !Rails.env.test?
+ if !Rails.env.test? && !cancel_manager&.cancelled?
retry if attempts < 3
end
- Discourse.warn_exception(
- e,
- message: "Failed to generate image for prompt #{prompt}\n",
- )
- puts "Error generating image for prompt: #{prompt} #{e}" if Rails.env.development?
+ if !cancel_manager&.cancelled?
+ Discourse.warn_exception(
+ e,
+ message: "Failed to generate image for prompt #{prompt}\n",
+ )
+ puts "Error generating image for prompt: #{prompt} #{e}" if Rails.env.development?
+ end
e
end
end
@@ -181,7 +189,8 @@ module ::DiscourseAi
api_key: nil,
api_url: nil,
n: 1,
- quality: nil
+ quality: nil,
+ cancel_manager: nil
)
images = [images] if !images.is_a?(Array)
@@ -209,8 +218,10 @@ module ::DiscourseAi
api_url: api_url,
n: n,
quality: quality,
+ cancel_manager: cancel_manager,
)
rescue => e
+ raise e if cancel_manager&.cancelled?
attempts += 1
if !Rails.env.test?
sleep 2
@@ -238,7 +249,8 @@ module ::DiscourseAi
background: nil,
moderation: nil,
output_compression: nil,
- output_format: nil
+ output_format: nil,
+ cancel_manager: nil
)
api_key ||= SiteSetting.ai_openai_api_key
api_url ||= SiteSetting.ai_openai_image_generation_url
@@ -276,6 +288,7 @@ module ::DiscourseAi
# Store original prompt for upload metadata
original_prompt = prompt
+ cancel_manager_callback = nil
FinalDestination::HTTP.start(
uri.host,
@@ -288,6 +301,11 @@ module ::DiscourseAi
request = Net::HTTP::Post.new(uri, headers)
request.body = payload.to_json
+ if cancel_manager
+ cancel_manager_callback = lambda { http.finish }
+ cancel_manager.add_callback(cancel_manager_callback)
+ end
+
json = nil
http.request(request) do |response|
if response.code.to_i != 200
@@ -300,6 +318,10 @@ module ::DiscourseAi
end
json
end
+ ensure
+ if cancel_manager && cancel_manager_callback
+ cancel_manager.remove_callback(cancel_manager_callback)
+ end
end
def self.perform_edit_api_call!(
@@ -310,7 +332,8 @@ module ::DiscourseAi
api_key:,
api_url:,
n: 1,
- quality: nil
+ quality: nil,
+ cancel_manager: nil
)
uri = URI(api_url)
@@ -403,6 +426,7 @@ module ::DiscourseAi
# Store original prompt for upload metadata
original_prompt = prompt
+ cancel_manager_callback = nil
FinalDestination::HTTP.start(
uri.host,
@@ -415,6 +439,11 @@ module ::DiscourseAi
request = Net::HTTP::Post.new(uri.path, headers)
request.body = body.join
+ if cancel_manager
+ cancel_manager_callback = lambda { http.finish }
+ cancel_manager.add_callback(cancel_manager_callback)
+ end
+
json = nil
http.request(request) do |response|
if response.code.to_i != 200
@@ -428,6 +457,9 @@ module ::DiscourseAi
json
end
ensure
+ if cancel_manager && cancel_manager_callback
+ cancel_manager.remove_callback(cancel_manager_callback)
+ end
if files_to_delete.present?
files_to_delete.each { |file| File.delete(file) if File.exist?(file) }
end
diff --git a/lib/personas/artifact_update_strategies/base.rb b/lib/personas/artifact_update_strategies/base.rb
index 68e0b761..624830f1 100644
--- a/lib/personas/artifact_update_strategies/base.rb
+++ b/lib/personas/artifact_update_strategies/base.rb
@@ -5,15 +5,24 @@ module DiscourseAi
class InvalidFormatError < StandardError
end
class Base
- attr_reader :post, :user, :artifact, :artifact_version, :instructions, :llm
+ attr_reader :post, :user, :artifact, :artifact_version, :instructions, :llm, :cancel_manager
- def initialize(llm:, post:, user:, artifact:, artifact_version:, instructions:)
+ def initialize(
+ llm:,
+ post:,
+ user:,
+ artifact:,
+ artifact_version:,
+ instructions:,
+ cancel_manager: nil
+ )
@llm = llm
@post = post
@user = user
@artifact = artifact
@artifact_version = artifact_version
@instructions = instructions
+ @cancel_manager = cancel_manager
end
def apply(&progress)
@@ -26,7 +35,7 @@ module DiscourseAi
def generate_changes(&progress)
response = +""
- llm.generate(build_prompt, user: user) do |partial|
+ llm.generate(build_prompt, user: user, cancel_manager: cancel_manager) do |partial|
progress.call(partial) if progress
response << partial
end
diff --git a/lib/personas/bot.rb b/lib/personas/bot.rb
index 2c9b5a3b..3686edac 100644
--- a/lib/personas/bot.rb
+++ b/lib/personas/bot.rb
@@ -55,6 +55,7 @@ module DiscourseAi
unless context.is_a?(BotContext)
raise ArgumentError, "context must be an instance of BotContext"
end
+ context.cancel_manager ||= DiscourseAi::Completions::CancelManager.new
current_llm = llm
prompt = persona.craft_prompt(context, llm: current_llm)
@@ -91,8 +92,9 @@ module DiscourseAi
feature_name: context.feature_name,
partial_tool_calls: allow_partial_tool_calls,
output_thinking: true,
+ cancel_manager: context.cancel_manager,
**llm_kwargs,
- ) do |partial, cancel|
+ ) do |partial|
tool =
persona.find_tool(
partial,
@@ -109,7 +111,7 @@ module DiscourseAi
if tool_call.partial?
if tool.class.allow_partial_tool_calls?
tool.partial_invoke
- update_blk.call("", cancel, tool.custom_raw, :partial_tool)
+ update_blk.call("", tool.custom_raw, :partial_tool)
end
next
end
@@ -117,7 +119,7 @@ module DiscourseAi
tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
- update_blk.call("\n\n", cancel)
+ update_blk.call("\n\n")
needs_newlines = false
end
@@ -125,7 +127,6 @@ module DiscourseAi
tool: tool,
raw_context: raw_context,
current_llm: current_llm,
- cancel: cancel,
update_blk: update_blk,
prompt: prompt,
context: context,
@@ -144,7 +145,7 @@ module DiscourseAi
else
if partial.is_a?(DiscourseAi::Completions::Thinking)
if partial.partial? && partial.message.present?
- update_blk.call(partial.message, cancel, nil, :thinking)
+ update_blk.call(partial.message, nil, :thinking)
end
if !partial.partial?
# this will be dealt with later
@@ -152,9 +153,9 @@ module DiscourseAi
current_thinking << partial
end
elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput)
- update_blk.call(partial, cancel, nil, :structured_output)
+ update_blk.call(partial, nil, :structured_output)
else
- update_blk.call(partial, cancel)
+ update_blk.call(partial)
end
end
end
@@ -215,14 +216,13 @@ module DiscourseAi
tool:,
raw_context:,
current_llm:,
- cancel:,
update_blk:,
prompt:,
context:,
current_thinking:
)
tool_call_id = tool.tool_call_id
- invocation_result_json = invoke_tool(tool, cancel, context, &update_blk).to_json
+ invocation_result_json = invoke_tool(tool, context, &update_blk).to_json
tool_call_message = {
type: :tool_call,
@@ -256,27 +256,27 @@ module DiscourseAi
raw_context << [invocation_result_json, tool_call_id, "tool", tool.name]
end
- def invoke_tool(tool, cancel, context, &update_blk)
+ def invoke_tool(tool, context, &update_blk)
show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls?
- update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
+ update_blk.call("", build_placeholder(tool.summary, "")) if show_placeholder
result =
tool.invoke do |progress, render_raw|
if render_raw
- update_blk.call("", cancel, tool.custom_raw, :partial_invoke)
+ update_blk.call("", tool.custom_raw, :partial_invoke)
show_placeholder = false
elsif show_placeholder
placeholder = build_placeholder(tool.summary, progress)
- update_blk.call("", cancel, placeholder)
+ update_blk.call("", placeholder)
end
end
if show_placeholder
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
- update_blk.call(tool_details, cancel, nil, :tool_details)
+ update_blk.call(tool_details, nil, :tool_details)
elsif tool.custom_raw.present?
- update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
+ update_blk.call(tool.custom_raw, nil, :custom_raw)
end
result
diff --git a/lib/personas/bot_context.rb b/lib/personas/bot_context.rb
index 5f7dd99e..69d86669 100644
--- a/lib/personas/bot_context.rb
+++ b/lib/personas/bot_context.rb
@@ -16,7 +16,8 @@ module DiscourseAi
:channel_id,
:context_post_ids,
:feature_name,
- :resource_url
+ :resource_url,
+ :cancel_manager
def initialize(
post: nil,
@@ -33,7 +34,8 @@ module DiscourseAi
channel_id: nil,
context_post_ids: nil,
feature_name: "bot",
- resource_url: nil
+ resource_url: nil,
+ cancel_manager: nil
)
@participants = participants
@user = user
@@ -54,6 +56,8 @@ module DiscourseAi
@feature_name = feature_name
@resource_url = resource_url
+ @cancel_manager = cancel_manager
+
if post
@post_id = post.id
@topic_id = post.topic_id
diff --git a/lib/personas/forum_researcher.rb b/lib/personas/forum_researcher.rb
new file mode 100644
index 00000000..eb0f82cb
--- /dev/null
+++ b/lib/personas/forum_researcher.rb
@@ -0,0 +1,52 @@
+#frozen_string_literal: true
+
+module DiscourseAi
+ module Personas
+ class ForumResearcher < Persona
+ def self.default_enabled
+ false
+ end
+
+ def tools
+ [Tools::Researcher]
+ end
+
+ def system_prompt
+ <<~PROMPT
+ You are a helpful Discourse assistant specializing in forum research.
+ You _understand_ and **generate** Discourse Markdown.
+
+ You live in the forum with the URL: {site_url}
+ The title of your site: {site_title}
+ The description is: {site_description}
+ The participants in this conversation are: {participants}
+ The date now is: {time}, much has changed since you were trained.
+
+ As a forum researcher, guide users through a structured research process:
+ 1. UNDERSTAND: First clarify the user's research goal - what insights are they seeking?
+ 2. PLAN: Design an appropriate research approach with specific filters
+ 3. TEST: Always begin with dry_run:true to gauge the scope of results
+ 4. REFINE: If results are too broad/narrow, suggest filter adjustments
+ 5. EXECUTE: Run the final analysis only when filters are well-tuned
+ 6. SUMMARIZE: Present findings with links to supporting evidence
+
+ BE MINDFUL: specify all research goals in one request to avoid multiple processing runs.
+
+ REMEMBER: Different filters serve different purposes:
+ - Use post date filters (after/before) for analyzing specific posts
+ - Use topic date filters (topic_after/topic_before) for analyzing entire topics
+ - Combine user/group filters with categories/tags to find specialized contributions
+
+ Always ground your analysis with links to original posts on the forum.
+
+ Research workflow best practices:
+ 1. Start with a dry_run to gauge the scope (set dry_run:true)
+ 2. If results are too numerous (>1000), add more specific filters
+ 3. If results are too few (<5), broaden your filters
+ 4. For temporal analysis, specify explicit date ranges
+ 5. For user behavior analysis, combine @username with categories or tags
+ PROMPT
+ end
+ end
+ end
+end
diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb
index a4443dcc..62426f77 100644
--- a/lib/personas/persona.rb
+++ b/lib/personas/persona.rb
@@ -4,6 +4,10 @@ module DiscourseAi
module Personas
class Persona
class << self
+ def default_enabled
+ true
+ end
+
def rag_conversation_chunks
10
end
@@ -47,6 +51,7 @@ module DiscourseAi
Summarizer => -11,
ShortSummarizer => -12,
Designer => -13,
+ ForumResearcher => -14,
}
end
@@ -99,6 +104,7 @@ module DiscourseAi
Tools::GithubSearchFiles,
Tools::WebBrowser,
Tools::JavascriptEvaluator,
+ Tools::Researcher,
]
if SiteSetting.ai_artifact_security.in?(%w[lax strict])
diff --git a/lib/personas/short_summarizer.rb b/lib/personas/short_summarizer.rb
index e7cef54a..5b9b5195 100644
--- a/lib/personas/short_summarizer.rb
+++ b/lib/personas/short_summarizer.rb
@@ -3,6 +3,10 @@
module DiscourseAi
module Personas
class ShortSummarizer < Persona
+ def self.default_enabled
+ false
+ end
+
def system_prompt
<<~PROMPT.strip
You are an advanced summarization bot. Analyze a given conversation and produce a concise,
@@ -23,7 +27,7 @@ module DiscourseAi
-
+
Where "xx" is replaced by the summary.
PROMPT
end
diff --git a/lib/personas/summarizer.rb b/lib/personas/summarizer.rb
index c1eefe89..b8b0a95a 100644
--- a/lib/personas/summarizer.rb
+++ b/lib/personas/summarizer.rb
@@ -3,6 +3,10 @@
module DiscourseAi
module Personas
class Summarizer < Persona
+ def self.default_enabled
+ false
+ end
+
def system_prompt
<<~PROMPT.strip
You are an advanced summarization bot that generates concise, coherent summaries of provided text.
@@ -18,13 +22,13 @@ module DiscourseAi
- Example: link to the 6th post by jane: [agreed with]({resource_url}/6)
- Example: link to the 13th post by joe: [joe]({resource_url}/13)
- When formatting usernames either use @USERNAME OR [USERNAME]({resource_url}/POST_NUMBER)
-
+
Format your response as a JSON object with a single key named "summary", which has the summary as the value.
Your output should be in the following format:
-
+
Where "xx" is replaced by the summary.
PROMPT
end
diff --git a/lib/personas/tools/create_artifact.rb b/lib/personas/tools/create_artifact.rb
index c12577ff..548fb9aa 100644
--- a/lib/personas/tools/create_artifact.rb
+++ b/lib/personas/tools/create_artifact.rb
@@ -151,7 +151,12 @@ module DiscourseAi
LlmModel.find_by(id: options[:creator_llm].to_i)&.to_llm
) || self.llm
- llm.generate(prompt, user: user, feature_name: "create_artifact") do |partial_response|
+ llm.generate(
+ prompt,
+ user: user,
+ feature_name: "create_artifact",
+ cancel_manager: context.cancel_manager,
+ ) do |partial_response|
response << partial_response
yield partial_response
end
diff --git a/lib/personas/tools/create_image.rb b/lib/personas/tools/create_image.rb
index f57c83ff..8e2971fa 100644
--- a/lib/personas/tools/create_image.rb
+++ b/lib/personas/tools/create_image.rb
@@ -48,6 +48,7 @@ module DiscourseAi
max_prompts,
model: "gpt-image-1",
user_id: bot_user.id,
+ cancel_manager: context.cancel_manager,
)
rescue => e
@error = e
diff --git a/lib/personas/tools/edit_image.rb b/lib/personas/tools/edit_image.rb
index b28afd0c..b9e3249a 100644
--- a/lib/personas/tools/edit_image.rb
+++ b/lib/personas/tools/edit_image.rb
@@ -60,6 +60,7 @@ module DiscourseAi
uploads,
prompt,
user_id: bot_user.id,
+ cancel_manager: context.cancel_manager,
)
rescue => e
@error = e
diff --git a/lib/personas/tools/researcher.rb b/lib/personas/tools/researcher.rb
new file mode 100644
index 00000000..3ab4c8e0
--- /dev/null
+++ b/lib/personas/tools/researcher.rb
@@ -0,0 +1,181 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Personas
+ module Tools
+ class Researcher < Tool
+ attr_reader :filter, :result_count, :goals, :dry_run
+
+ class << self
+ def signature
+ {
+ name: name,
+ description:
+ "Analyze and extract information from content across the forum based on specified filters",
+ parameters: [
+ { name: "filter", description: filter_description, type: "string" },
+ {
+ name: "goals",
+ description:
+ "The specific information you want to extract or analyze from the filtered content, you may specify multiple goals",
+ type: "string",
+ },
+ {
+ name: "dry_run",
+ description: "When true, only count matching items without processing data",
+ type: "boolean",
+ },
+ ],
+ }
+ end
+
+ def filter_description
+ <<~TEXT
+ Filter string to target specific content.
+ - Supports user (@username)
+ - date ranges (after:YYYY-MM-DD, before:YYYY-MM-DD for posts; topic_after:YYYY-MM-DD, topic_before:YYYY-MM-DD for topics)
+ - categories (category:category1,category2)
+ - tags (tag:tag1,tag2)
+ - groups (group:group1,group2).
+ - status (status:open, status:closed, status:archived, status:noreplies, status:single_user)
+ - keywords (keywords:keyword1,keyword2) - specific words to search for in posts
+ - max_results (max_results:10) the maximum number of results to return (optional)
+ - order (order:latest, order:oldest, order:latest_topic, order:oldest_topic) - the order of the results (optional)
+
+ If multiple tags or categories are specified, they are treated as OR conditions.
+
+ Multiple filters can be combined with spaces. Example: '@sam after:2023-01-01 tag:feature'
+ TEXT
+ end
+
+ def name
+ "researcher"
+ end
+
+ def accepted_options
+ [
+ option(:max_results, type: :integer),
+ option(:include_private, type: :boolean),
+ option(:max_tokens_per_post, type: :integer),
+ ]
+ end
+ end
+
+ def invoke(&blk)
+ max_results = options[:max_results] || 1000
+
+ @filter = parameters[:filter] || ""
+ @goals = parameters[:goals] || ""
+ @dry_run = parameters[:dry_run].nil? ? false : parameters[:dry_run]
+
+ post = Post.find_by(id: context.post_id)
+ goals = parameters[:goals] || ""
+ dry_run = parameters[:dry_run].nil? ? false : parameters[:dry_run]
+
+ return { error: "No goals provided" } if goals.blank?
+ return { error: "No filter provided" } if @filter.blank?
+
+ guardian = nil
+ guardian = Guardian.new(context.user) if options[:include_private]
+
+ filter =
+ DiscourseAi::Utils::Research::Filter.new(
+ @filter,
+ limit: max_results,
+ guardian: guardian,
+ )
+ @result_count = filter.search.count
+
+ blk.call details
+
+ if dry_run
+ { dry_run: true, goals: goals, filter: @filter, number_of_results: @result_count }
+ else
+ process_filter(filter, goals, post, &blk)
+ end
+ end
+
+ def details
+ if @dry_run
+ I18n.t("discourse_ai.ai_bot.tool_description.researcher_dry_run", description_args)
+ else
+ I18n.t("discourse_ai.ai_bot.tool_description.researcher", description_args)
+ end
+ end
+
+ def description_args
+ { count: @result_count || 0, filter: @filter || "", goals: @goals || "" }
+ end
+
+ protected
+
+ MIN_TOKENS_FOR_RESEARCH = 8000
+ def process_filter(filter, goals, post, &blk)
+ if llm.max_prompt_tokens < MIN_TOKENS_FOR_RESEARCH
+ raise ArgumentError,
+ "LLM max tokens too low for research. Minimum is #{MIN_TOKENS_FOR_RESEARCH}."
+ end
+ formatter =
+ DiscourseAi::Utils::Research::LlmFormatter.new(
+ filter,
+ max_tokens_per_batch: llm.max_prompt_tokens - 2000,
+ tokenizer: llm.tokenizer,
+ max_tokens_per_post: options[:max_tokens_per_post] || 2000,
+ )
+
+ results = []
+
+ formatter.each_chunk { |chunk| results << run_inference(chunk[:text], goals, post, &blk) }
+ { dry_run: false, goals: goals, filter: @filter, results: results }
+ end
+
+ def run_inference(chunk_text, goals, post, &blk)
+ system_prompt = goal_system_prompt(goals)
+ user_prompt = goal_user_prompt(goals, chunk_text)
+
+ prompt =
+ DiscourseAi::Completions::Prompt.new(
+ system_prompt,
+ messages: [{ type: :user, content: user_prompt }],
+ post_id: post.id,
+ topic_id: post.topic_id,
+ )
+
+ results = []
+ llm.generate(
+ prompt,
+ user: post.user,
+ feature_name: context.feature_name,
+ cancel_manager: context.cancel_manager,
+ ) { |partial| results << partial }
+
+ @progress_dots ||= 0
+ @progress_dots += 1
+ blk.call(details + "\n\n#{"." * @progress_dots}")
+ results.join
+ end
+
+ def goal_system_prompt(goals)
+ <<~TEXT
+ You are a researcher tool designed to analyze and extract information from forum content.
+ Your task is to process the provided content and extract relevant information based on the specified goal.
+
+ Your goal is: #{goals}
+ TEXT
+ end
+
+ def goal_user_prompt(goals, chunk_text)
+ <<~TEXT
+ Here is the content to analyze:
+
+ {{{
+ #{chunk_text}
+ }}}
+
+ Your goal is: #{goals}
+ TEXT
+ end
+ end
+ end
+ end
+end
diff --git a/lib/personas/tools/tool.rb b/lib/personas/tools/tool.rb
index 4bafe0af..540204c2 100644
--- a/lib/personas/tools/tool.rb
+++ b/lib/personas/tools/tool.rb
@@ -47,8 +47,9 @@ module DiscourseAi
end
end
- attr_accessor :custom_raw, :parameters
- attr_reader :tool_call_id, :persona_options, :bot_user, :llm, :context
+ # llm being public makes it a bit easier to test
+ attr_accessor :custom_raw, :parameters, :llm
+ attr_reader :tool_call_id, :persona_options, :bot_user, :context
def initialize(
parameters,
diff --git a/lib/personas/tools/update_artifact.rb b/lib/personas/tools/update_artifact.rb
index ccf39a8d..46038d7a 100644
--- a/lib/personas/tools/update_artifact.rb
+++ b/lib/personas/tools/update_artifact.rb
@@ -159,6 +159,7 @@ module DiscourseAi
artifact: artifact,
artifact_version: artifact_version,
instructions: instructions,
+ cancel_manager: context.cancel_manager,
)
.apply do |progress|
partial_response << progress
diff --git a/lib/summarization/fold_content.rb b/lib/summarization/fold_content.rb
index 15800087..8dbcda38 100644
--- a/lib/summarization/fold_content.rb
+++ b/lib/summarization/fold_content.rb
@@ -18,7 +18,7 @@ module DiscourseAi
attr_reader :bot, :strategy
# @param user { User } - User object used for auditing usage.
- # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
+ # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response.
# Note: The block is only called with results of the final summary, not intermediate summaries.
#
# This method doesn't care if we already have an up to date summary. It always regenerate.
@@ -77,7 +77,7 @@ module DiscourseAi
# @param items { Array } - Content to summarize. Structure will be: { poster: who wrote the content, id: a way to order content, text: content }
# @param user { User } - User object used for auditing usage.
- # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
+ # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response.
# Note: The block is only called with results of the final summary, not intermediate summaries.
#
# The summarization algorithm.
@@ -112,7 +112,7 @@ module DiscourseAi
summary = +""
buffer_blk =
- Proc.new do |partial, cancel, _, type|
+ Proc.new do |partial, _, type|
if type == :structured_output
json_summary_schema_key = bot.persona.response_format&.first.to_h
partial_summary =
@@ -120,12 +120,12 @@ module DiscourseAi
if partial_summary.present?
summary << partial_summary
- on_partial_blk.call(partial_summary, cancel) if on_partial_blk
+ on_partial_blk.call(partial_summary) if on_partial_blk
end
elsif type.blank?
# Assume response is a regular completion.
summary << partial
- on_partial_blk.call(partial, cancel) if on_partial_blk
+ on_partial_blk.call(partial) if on_partial_blk
end
end
diff --git a/lib/utils/research/filter.rb b/lib/utils/research/filter.rb
new file mode 100644
index 00000000..734943e7
--- /dev/null
+++ b/lib/utils/research/filter.rb
@@ -0,0 +1,263 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Utils
+ module Research
+ class Filter
+ # Stores custom filter handlers
+ def self.register_filter(matcher, &block)
+ (@registered_filters ||= {})[matcher] = block
+ end
+
+ def self.registered_filters
+ @registered_filters ||= {}
+ end
+
+ def self.word_to_date(str)
+ ::Search.word_to_date(str)
+ end
+
+ attr_reader :term, :filters, :order, :guardian, :limit, :offset
+
+ # Define all filters at class level
+ register_filter(/\Astatus:open\z/i) do |relation, _, _|
+ relation.where("topics.closed = false AND topics.archived = false")
+ end
+
+ register_filter(/\Astatus:closed\z/i) do |relation, _, _|
+ relation.where("topics.closed = true")
+ end
+
+ register_filter(/\Astatus:archived\z/i) do |relation, _, _|
+ relation.where("topics.archived = true")
+ end
+
+ register_filter(/\Astatus:noreplies\z/i) do |relation, _, _|
+ relation.where("topics.posts_count = 1")
+ end
+
+ register_filter(/\Astatus:single_user\z/i) do |relation, _, _|
+ relation.where("topics.participant_count = 1")
+ end
+
+ # Date filters
+ register_filter(/\Abefore:(.*)\z/i) do |relation, date_str, _|
+ if date = Filter.word_to_date(date_str)
+ relation.where("posts.created_at < ?", date)
+ else
+ relation
+ end
+ end
+
+ register_filter(/\Aafter:(.*)\z/i) do |relation, date_str, _|
+ if date = Filter.word_to_date(date_str)
+ relation.where("posts.created_at > ?", date)
+ else
+ relation
+ end
+ end
+
+ register_filter(/\Atopic_before:(.*)\z/i) do |relation, date_str, _|
+ if date = Filter.word_to_date(date_str)
+ relation.where("topics.created_at < ?", date)
+ else
+ relation
+ end
+ end
+
+ register_filter(/\Atopic_after:(.*)\z/i) do |relation, date_str, _|
+ if date = Filter.word_to_date(date_str)
+ relation.where("topics.created_at > ?", date)
+ else
+ relation
+ end
+ end
+
+ register_filter(/\A(?:tags?|tag):(.*)\z/i) do |relation, tag_param, _|
+ if tag_param.include?(",")
+ tag_names = tag_param.split(",").map(&:strip)
+ tag_ids = Tag.where(name: tag_names).pluck(:id)
+ next relation.where("1 = 0") if tag_ids.empty?
+ relation.where(topic_id: TopicTag.where(tag_id: tag_ids).select(:topic_id))
+ else
+ if tag = Tag.find_by(name: tag_param)
+ relation.where(topic_id: TopicTag.where(tag_id: tag.id).select(:topic_id))
+ else
+ relation.where("1 = 0")
+ end
+ end
+ end
+
+ register_filter(/\Akeywords?:(.*)\z/i) do |relation, keywords_param, _|
+ if keywords_param.blank?
+ relation
+ else
+ keywords = keywords_param.split(",").map(&:strip).reject(&:blank?)
+ if keywords.empty?
+ relation
+ else
+ # Build a ts_query string joined by | (OR)
+ ts_query = keywords.map { |kw| kw.gsub(/['\\]/, " ") }.join(" | ")
+ relation =
+ relation.joins("JOIN post_search_data ON post_search_data.post_id = posts.id")
+ relation.where(
+ "post_search_data.search_data @@ to_tsquery(?, ?)",
+ ::Search.ts_config,
+ ts_query,
+ )
+ end
+ end
+ end
+
+ register_filter(/\A(?:categories?|category):(.*)\z/i) do |relation, category_param, _|
+ if category_param.include?(",")
+ category_names = category_param.split(",").map(&:strip)
+
+ found_category_ids = []
+ category_names.each do |name|
+ category = Category.find_by(slug: name) || Category.find_by(name: name)
+ found_category_ids << category.id if category
+ end
+
+ next relation.where("1 = 0") if found_category_ids.empty?
+ relation.where(topic_id: Topic.where(category_id: found_category_ids).select(:id))
+ else
+ if category =
+ Category.find_by(slug: category_param) || Category.find_by(name: category_param)
+ relation.where(topic_id: Topic.where(category_id: category.id).select(:id))
+ else
+ relation.where("1 = 0")
+ end
+ end
+ end
+
+ register_filter(/\A\@(\w+)\z/i) do |relation, username, filter|
+ user = User.find_by(username_lower: username.downcase)
+ if user
+ relation.where("posts.user_id = ?", user.id)
+ else
+ relation.where("1 = 0") # No results if user doesn't exist
+ end
+ end
+
+ register_filter(/\Ain:posted\z/i) do |relation, _, filter|
+ if filter.guardian.user
+ relation.where("posts.user_id = ?", filter.guardian.user.id)
+ else
+ relation.where("1 = 0") # No results if not logged in
+ end
+ end
+
+ register_filter(/\Agroup:([a-zA-Z0-9_\-]+)\z/i) do |relation, name, filter|
+ group = Group.find_by("name ILIKE ?", name)
+ if group
+ relation.where(
+ "posts.user_id IN (
+ SELECT gu.user_id FROM group_users gu
+ WHERE gu.group_id = ?
+ )",
+ group.id,
+ )
+ else
+ relation.where("1 = 0") # No results if group doesn't exist
+ end
+ end
+
+ register_filter(/\Amax_results:(\d+)\z/i) do |relation, limit_str, filter|
+ filter.limit_by_user!(limit_str.to_i)
+ relation
+ end
+
+ register_filter(/\Aorder:latest\z/i) do |relation, order_str, filter|
+ filter.set_order!(:latest_post)
+ relation
+ end
+
+ register_filter(/\Aorder:oldest\z/i) do |relation, order_str, filter|
+ filter.set_order!(:oldest_post)
+ relation
+ end
+
+ register_filter(/\Aorder:latest_topic\z/i) do |relation, order_str, filter|
+ filter.set_order!(:latest_topic)
+ relation
+ end
+
+ register_filter(/\Aorder:oldest_topic\z/i) do |relation, order_str, filter|
+ filter.set_order!(:oldest_topic)
+ relation
+ end
+
+ def initialize(term, guardian: nil, limit: nil, offset: nil)
+ @term = term.to_s
+ @guardian = guardian || Guardian.new
+ @limit = limit
+ @offset = offset
+ @filters = []
+ @valid = true
+ @order = :latest_post
+
+ @term = process_filters(@term)
+ end
+
+ def set_order!(order)
+ @order = order
+ end
+
+ def limit_by_user!(limit)
+ @limit = limit if limit.to_i < @limit.to_i || @limit.nil?
+ end
+
+ def search
+ filtered = Post.secured(@guardian).joins(:topic).merge(Topic.secured(@guardian))
+
+ @filters.each do |filter_block, match_data|
+ filtered = filter_block.call(filtered, match_data, self)
+ end
+
+ filtered = filtered.limit(@limit) if @limit.to_i > 0
+ filtered = filtered.offset(@offset) if @offset.to_i > 0
+
+ if @order == :latest_post
+ filtered = filtered.order("posts.created_at DESC")
+ elsif @order == :oldest_post
+ filtered = filtered.order("posts.created_at ASC")
+ elsif @order == :latest_topic
+ filtered = filtered.order("topics.created_at DESC, posts.post_number DESC")
+ elsif @order == :oldest_topic
+ filtered = filtered.order("topics.created_at ASC, posts.post_number ASC")
+ end
+
+ filtered
+ end
+
+ private
+
+ def process_filters(term)
+ return "" if term.blank?
+
+ term
+ .to_s
+ .scan(/(([^" \t\n\x0B\f\r]+)?(("[^"]+")?))/)
+ .to_a
+ .map do |(word, _)|
+ next if word.blank?
+
+ found = false
+ self.class.registered_filters.each do |matcher, block|
+ if word =~ matcher
+ @filters << [block, $1]
+ found = true
+ break
+ end
+ end
+
+ found ? nil : word
+ end
+ .compact
+ .join(" ")
+ end
+ end
+ end
+ end
+end
diff --git a/lib/utils/research/llm_formatter.rb b/lib/utils/research/llm_formatter.rb
new file mode 100644
index 00000000..a762f2dd
--- /dev/null
+++ b/lib/utils/research/llm_formatter.rb
@@ -0,0 +1,205 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Utils
+ module Research
+ class LlmFormatter
+ def initialize(filter, max_tokens_per_batch:, tokenizer:, max_tokens_per_post:)
+ @filter = filter
+ @max_tokens_per_batch = max_tokens_per_batch
+ @tokenizer = tokenizer
+ @max_tokens_per_post = max_tokens_per_post
+ @to_process = filter_to_hash
+ end
+
+ def each_chunk
+ return nil if @to_process.empty?
+
+ result = { post_count: 0, topic_count: 0, text: +"" }
+ estimated_tokens = 0
+
+ @to_process.each do |topic_id, topic_data|
+ topic = Topic.find_by(id: topic_id)
+ next unless topic
+
+ topic_text, topic_tokens, post_count = format_topic(topic, topic_data[:posts])
+
+ # If this single topic exceeds our token limit and we haven't added anything yet,
+ # we need to include at least this one topic (partial content)
+ if estimated_tokens == 0 && topic_tokens > @max_tokens_per_batch
+ offset = 0
+ while offset < topic_text.length
+ chunk = +""
+ chunk_tokens = 0
+ lines = topic_text[offset..].lines
+ lines.each do |line|
+ line_tokens = estimate_tokens(line)
+ break if chunk_tokens + line_tokens > @max_tokens_per_batch
+ chunk << line
+ chunk_tokens += line_tokens
+ end
+ break if chunk.empty?
+ yield(
+ {
+ text: chunk,
+ post_count: post_count, # This may overcount if split mid-topic, but preserves original logic
+ topic_count: 1,
+ }
+ )
+ offset += chunk.length
+ end
+
+ next
+ end
+
+ # If adding this topic would exceed our token limit and we already have
+ # content, flush the current batch first, then start a new batch with this topic
+ if estimated_tokens > 0 && estimated_tokens + topic_tokens > @max_tokens_per_batch
+ yield result if result[:text].present?
+ estimated_tokens = 0
+ result = { post_count: 0, topic_count: 0, text: +"" }
+ end
+ # Add this topic to the result
+ result[:text] << topic_text
+ result[:post_count] += post_count
+ result[:topic_count] += 1
+ estimated_tokens += topic_tokens
+ end
+ yield result if result[:text].present?
+
+ @to_process.clear
+ end
+
+ private
+
+ def filter_to_hash
+ hash = {}
+ @filter
+ .search
+ .pluck(:topic_id, :id, :post_number)
+ .each do |topic_id, post_id, post_number|
+ hash[topic_id] ||= { posts: [] }
+ hash[topic_id][:posts] << [post_id, post_number]
+ end
+
+ hash.each_value { |topic| topic[:posts].sort_by! { |_, post_number| post_number } }
+ hash
+ end
+
+ def format_topic(topic, posts_data)
+ text = ""
+ total_tokens = 0
+ post_count = 0
+
+ # Add topic header
+ text += format_topic_header(topic)
+ total_tokens += estimate_tokens(text)
+
+ # Get all post numbers in this topic
+ all_post_numbers = topic.posts.pluck(:post_number).sort
+
+ # Format posts with omitted information
+ first_post_number = posts_data.first[1]
+ last_post_number = posts_data.last[1]
+
+ # Handle posts before our selection
+ if first_post_number > 1
+ omitted_before = first_post_number - 1
+ text += format_omitted_posts(omitted_before, "before")
+ total_tokens += estimate_tokens(format_omitted_posts(omitted_before, "before"))
+ end
+
+ # Format each post
+ posts_data.each do |post_id, post_number|
+ post = Post.find_by(id: post_id)
+ next unless post
+
+ text += format_post(post)
+ total_tokens += estimate_tokens(format_post(post))
+ post_count += 1
+ end
+
+ # Handle posts after our selection
+ if last_post_number < all_post_numbers.last
+ omitted_after = all_post_numbers.last - last_post_number
+ text += format_omitted_posts(omitted_after, "after")
+ total_tokens += estimate_tokens(format_omitted_posts(omitted_after, "after"))
+ end
+
+ [text, total_tokens, post_count]
+ end
+
+ def format_topic_header(topic)
+ header = +"# #{topic.title}\n"
+
+ # Add category
+ header << "Category: #{topic.category.name}\n" if topic.category
+
+ # Add tags
+ header << "Tags: #{topic.tags.map(&:name).join(", ")}\n" if topic.tags.present?
+
+ # Add creation date
+ header << "Created: #{format_date(topic.created_at)}\n"
+ header << "Topic url: /t/#{topic.id}\n"
+ header << "Status: #{format_topic_status(topic)}\n\n"
+
+ header
+ end
+
+ def format_topic_status(topic)
+ solved = topic.respond_to?(:solved) && topic.solved.present?
+ solved_text = solved ? " (solved)" : ""
+ if topic.archived?
+ "Archived#{solved_text}"
+ elsif topic.closed?
+ "Closed#{solved_text}"
+ else
+ "Open#{solved_text}"
+ end
+ end
+
+ def format_post(post)
+ text = +"---\n"
+ text << "## Post by #{post.user&.username} - #{format_date(post.created_at)}\n\n"
+ text << "#{truncate_if_needed(post.raw)}\n"
+ text << "Likes: #{post.like_count}\n" if post.like_count.to_i > 0
+ text << "Post url: /t/-/#{post.topic_id}/#{post.post_number}\n\n"
+ text
+ end
+
+ def truncate_if_needed(content)
+ tokens_count = estimate_tokens(content)
+
+ return content if tokens_count <= @max_tokens_per_post
+
+ half_limit = @max_tokens_per_post / 2
+ token_ids = @tokenizer.encode(content)
+
+ first_half_ids = token_ids[0...half_limit]
+ last_half_ids = token_ids[-half_limit..-1]
+
+ first_text = @tokenizer.decode(first_half_ids)
+ last_text = @tokenizer.decode(last_half_ids)
+
+ "#{first_text}\n\n... elided #{tokens_count - @max_tokens_per_post} tokens ...\n\n#{last_text}"
+ end
+
+ def format_omitted_posts(count, position)
+ if position == "before"
+ "#{count} earlier #{count == 1 ? "post" : "posts"} omitted\n\n"
+ else
+ "#{count} later #{count == 1 ? "post" : "posts"} omitted\n\n"
+ end
+ end
+
+ def format_date(date)
+ date.strftime("%Y-%m-%d %H:%M")
+ end
+
+ def estimate_tokens(text)
+ @tokenizer.tokenize(text).length
+ end
+ end
+ end
+ end
+end
diff --git a/spec/lib/completions/cancel_manager_spec.rb b/spec/lib/completions/cancel_manager_spec.rb
new file mode 100644
index 00000000..57961826
--- /dev/null
+++ b/spec/lib/completions/cancel_manager_spec.rb
@@ -0,0 +1,106 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Completions::CancelManager do
+ fab!(:model) { Fabricate(:anthropic_model, name: "test-model") }
+
+ it "can stop monitoring for cancellation cleanly" do
+ cancel_manager = DiscourseAi::Completions::CancelManager.new
+ cancel_manager.start_monitor(delay: 100) { false }
+ expect(cancel_manager.monitor_thread).not_to be_nil
+ cancel_manager.stop_monitor
+ expect(cancel_manager.cancelled?).to eq(false)
+ expect(cancel_manager.monitor_thread).to be_nil
+ end
+
+ it "can monitor for cancellation" do
+ cancel_manager = DiscourseAi::Completions::CancelManager.new
+ results = [true, false, false]
+
+ cancel_manager.start_monitor(delay: 0) { results.pop }
+
+ wait_for { cancel_manager.cancelled? == true }
+ wait_for { cancel_manager.monitor_thread.nil? }
+
+ expect(cancel_manager.cancelled?).to eq(true)
+ expect(cancel_manager.monitor_thread).to be_nil
+ end
+
+ it "should do nothing when cancel manager is already cancelled" do
+ cancel_manager = DiscourseAi::Completions::CancelManager.new
+ cancel_manager.cancel!
+
+ llm = model.to_llm
+ prompt =
+ DiscourseAi::Completions::Prompt.new(
+ "You are a test bot",
+ messages: [{ type: :user, content: "hello" }],
+ )
+
+ result = llm.generate(prompt, user: Discourse.system_user, cancel_manager: cancel_manager)
+ expect(result).to be_nil
+ end
+
+ it "should be able to cancel a completion" do
+ # Start an HTTP server that hangs indefinitely
+ server = TCPServer.new("127.0.0.1", 0)
+ port = server.addr[1]
+
+ begin
+ thread =
+ Thread.new do
+ loop do
+ begin
+ _client = server.accept
+ sleep(30) # Hold the connection longer than the test will run
+ break
+ rescue StandardError
+ # Server closed
+ break
+ end
+ end
+ end
+
+ # Create a model that points to our hanging server
+ model.update!(url: "http://127.0.0.1:#{port}")
+
+ cancel_manager = DiscourseAi::Completions::CancelManager.new
+
+ completion_thread =
+ Thread.new do
+ llm = model.to_llm
+ prompt =
+ DiscourseAi::Completions::Prompt.new(
+ "You are a test bot",
+ messages: [{ type: :user, content: "hello" }],
+ )
+
+ result = llm.generate(prompt, user: Discourse.system_user, cancel_manager: cancel_manager)
+ expect(result).to be_nil
+ expect(cancel_manager.cancelled).to eq(true)
+ end
+
+ wait_for { cancel_manager.callbacks.size == 1 }
+
+ cancel_manager.cancel!
+ completion_thread.join(2)
+
+ expect(completion_thread).not_to be_alive
+ ensure
+ begin
+ server.close
+ rescue StandardError
+ nil
+ end
+ begin
+ thread.kill
+ rescue StandardError
+ nil
+ end
+ begin
+ completion_thread&.kill
+ rescue StandardError
+ nil
+ end
+ end
+ end
+end
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
index 130c735b..8ad95b95 100644
--- a/spec/lib/completions/endpoints/endpoint_compliance.rb
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -188,9 +188,11 @@ class EndpointsCompliance
mock.stub_streamed_simple_call(dialect.translate) do
completion_response = +""
- endpoint.perform_completion!(dialect, user) do |partial, cancel|
+ cancel_manager = DiscourseAi::Completions::CancelManager.new
+
+ endpoint.perform_completion!(dialect, user, cancel_manager: cancel_manager) do |partial|
completion_response << partial
- cancel.call if completion_response.split(" ").length == 2
+ cancel_manager.cancel! if completion_response.split(" ").length == 2
end
expect(AiApiAuditLog.count).to eq(1)
@@ -212,12 +214,14 @@ class EndpointsCompliance
prompt = generic_prompt(tools: [mock.tool])
a_dialect = dialect(prompt: prompt)
+ cancel_manager = DiscourseAi::Completions::CancelManager.new
+
mock.stub_streamed_tool_call(a_dialect.translate) do
buffered_partial = []
- endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
+ endpoint.perform_completion!(a_dialect, user, cancel_manager: cancel_manager) do |partial|
buffered_partial << partial
- cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
+ cancel_manager.cancel! if partial.is_a?(DiscourseAi::Completions::ToolCall)
end
expect(buffered_partial).to eq([mock.invocation_response])
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 4d21c94d..e9ad0ac3 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -1136,14 +1136,13 @@ RSpec.describe DiscourseAi::AiBot::Playground do
split = body.split("|")
+ cancel_manager = DiscourseAi::Completions::CancelManager.new
+
count = 0
DiscourseAi::AiBot::PostStreamer.on_callback =
proc do |callback|
count += 1
- if count == 2
- last_post = third_post.topic.posts.order(:id).last
- Discourse.redis.del("gpt_cancel:#{last_post.id}")
- end
+ cancel_manager.cancel! if count == 2
raise "this should not happen" if count > 2
end
@@ -1155,13 +1154,13 @@ RSpec.describe DiscourseAi::AiBot::Playground do
)
# we are going to need to use real data here cause we want to trigger the
# base endpoint to cancel part way through
- playground.reply_to(third_post)
+ playground.reply_to(third_post, cancel_manager: cancel_manager)
end
last_post = third_post.topic.posts.order(:id).last
- # not Hello123, we cancelled at 1 which means we may get 2 and then be done
- expect(last_post.raw).to eq("Hello12")
+ # not Hello123, we cancelled at 1
+ expect(last_post.raw).to eq("Hello1")
end
end
diff --git a/spec/lib/personas/persona_spec.rb b/spec/lib/personas/persona_spec.rb
index e3c96e34..fa81876e 100644
--- a/spec/lib/personas/persona_spec.rb
+++ b/spec/lib/personas/persona_spec.rb
@@ -218,32 +218,28 @@ RSpec.describe DiscourseAi::Personas::Persona do
SiteSetting.ai_google_custom_search_cx = "abc123"
# should be ordered by priority and then alpha
- expect(DiscourseAi::Personas::Persona.all(user: user).map(&:superclass)).to eq(
- [
- DiscourseAi::Personas::General,
- DiscourseAi::Personas::Artist,
- DiscourseAi::Personas::Creative,
- DiscourseAi::Personas::DiscourseHelper,
- DiscourseAi::Personas::GithubHelper,
- DiscourseAi::Personas::Researcher,
- DiscourseAi::Personas::SettingsExplorer,
- DiscourseAi::Personas::SqlHelper,
- ],
+ expect(DiscourseAi::Personas::Persona.all(user: user).map(&:superclass)).to contain_exactly(
+ DiscourseAi::Personas::General,
+ DiscourseAi::Personas::Artist,
+ DiscourseAi::Personas::Creative,
+ DiscourseAi::Personas::DiscourseHelper,
+ DiscourseAi::Personas::GithubHelper,
+ DiscourseAi::Personas::Researcher,
+ DiscourseAi::Personas::SettingsExplorer,
+ DiscourseAi::Personas::SqlHelper,
)
# it should allow staff access to WebArtifactCreator
- expect(DiscourseAi::Personas::Persona.all(user: admin).map(&:superclass)).to eq(
- [
- DiscourseAi::Personas::General,
- DiscourseAi::Personas::Artist,
- DiscourseAi::Personas::Creative,
- DiscourseAi::Personas::DiscourseHelper,
- DiscourseAi::Personas::GithubHelper,
- DiscourseAi::Personas::Researcher,
- DiscourseAi::Personas::SettingsExplorer,
- DiscourseAi::Personas::SqlHelper,
- DiscourseAi::Personas::WebArtifactCreator,
- ],
+ expect(DiscourseAi::Personas::Persona.all(user: admin).map(&:superclass)).to contain_exactly(
+ DiscourseAi::Personas::General,
+ DiscourseAi::Personas::Artist,
+ DiscourseAi::Personas::Creative,
+ DiscourseAi::Personas::DiscourseHelper,
+ DiscourseAi::Personas::GithubHelper,
+ DiscourseAi::Personas::Researcher,
+ DiscourseAi::Personas::SettingsExplorer,
+ DiscourseAi::Personas::SqlHelper,
+ DiscourseAi::Personas::WebArtifactCreator,
)
# omits personas if key is missing
diff --git a/spec/lib/personas/tools/researcher_spec.rb b/spec/lib/personas/tools/researcher_spec.rb
new file mode 100644
index 00000000..51227001
--- /dev/null
+++ b/spec/lib/personas/tools/researcher_spec.rb
@@ -0,0 +1,109 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Personas::Tools::Researcher do
+ before { SearchIndexer.enable }
+ after { SearchIndexer.disable }
+
+ fab!(:llm_model)
+ let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
+ let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
+ let(:progress_blk) { Proc.new {} }
+
+ fab!(:admin)
+ fab!(:user)
+ fab!(:category) { Fabricate(:category, name: "research-category") }
+ fab!(:tag_research) { Fabricate(:tag, name: "research") }
+ fab!(:tag_data) { Fabricate(:tag, name: "data") }
+
+ fab!(:topic_with_tags) { Fabricate(:topic, category: category, tags: [tag_research, tag_data]) }
+ fab!(:post) { Fabricate(:post, topic: topic_with_tags) }
+
+ before { SiteSetting.ai_bot_enabled = true }
+
+ describe "#invoke" do
+ it "returns filter information and result count" do
+ researcher =
+ described_class.new(
+ { filter: "tag:research after:2023", goals: "analyze post patterns", dry_run: true },
+ bot_user: bot_user,
+ llm: llm,
+ context: DiscourseAi::Personas::BotContext.new(user: user, post: post),
+ )
+
+ results = researcher.invoke(&progress_blk)
+
+ expect(results[:filter]).to eq("tag:research after:2023")
+ expect(results[:goals]).to eq("analyze post patterns")
+ expect(results[:dry_run]).to eq(true)
+ expect(results[:number_of_results]).to be > 0
+ expect(researcher.filter).to eq("tag:research after:2023")
+ expect(researcher.result_count).to be > 0
+ end
+
+ it "handles empty filters" do
+ researcher =
+ described_class.new({ goals: "analyze all content" }, bot_user: bot_user, llm: llm)
+
+ results = researcher.invoke(&progress_blk)
+
+ expect(results[:error]).to eq("No filter provided")
+ end
+
+ it "accepts max_results option" do
+ researcher =
+ described_class.new(
+ { filter: "category:research-category" },
+ persona_options: {
+ "max_results" => "50",
+ },
+ bot_user: bot_user,
+ llm: llm,
+ )
+
+ expect(researcher.options[:max_results]).to eq(50)
+ end
+
+ it "returns correct results for non-dry-run with filtered posts" do
+ # Stage 2 topics, each with 2 posts
+ topics = Array.new(2) { Fabricate(:topic, category: category, tags: [tag_research]) }
+ topics.flat_map do |topic|
+ [
+ Fabricate(:post, topic: topic, raw: "Relevant content 1", user: user),
+ Fabricate(:post, topic: topic, raw: "Relevant content 2", user: admin),
+ ]
+ end
+
+ # Filter to posts by user in research-category
+ researcher =
+ described_class.new(
+ {
+ filter: "category:research-category @#{user.username}",
+ goals: "find relevant content",
+ dry_run: false,
+ },
+ bot_user: bot_user,
+ llm: llm,
+ context: DiscourseAi::Personas::BotContext.new(user: user, post: post),
+ )
+
+ responses = 10.times.map { |i| ["Found: Relevant content #{i + 1}"] }
+ results = nil
+
+ last_progress = nil
+ progress_blk = Proc.new { |response| last_progress = response }
+
+ DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
+ researcher.llm = llm_model.to_llm
+ results = researcher.invoke(&progress_blk)
+ end
+
+ expect(last_progress).to include("find relevant content")
+ expect(last_progress).to include("category:research-category")
+
+ expect(results[:dry_run]).to eq(false)
+ expect(results[:goals]).to eq("find relevant content")
+ expect(results[:filter]).to eq("category:research-category @#{user.username}")
+ expect(results[:results].first).to include("Found: Relevant content 1")
+ end
+ end
+end
diff --git a/spec/lib/utils/research/filter_spec.rb b/spec/lib/utils/research/filter_spec.rb
new file mode 100644
index 00000000..655d1705
--- /dev/null
+++ b/spec/lib/utils/research/filter_spec.rb
@@ -0,0 +1,142 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Utils::Research::Filter do
+ describe "integration tests" do
+ before_all { SiteSetting.min_topic_title_length = 3 }
+
+ fab!(:user)
+
+ fab!(:feature_tag) { Fabricate(:tag, name: "feature") }
+ fab!(:bug_tag) { Fabricate(:tag, name: "bug") }
+
+ fab!(:announcement_category) { Fabricate(:category, name: "Announcements") }
+ fab!(:feedback_category) { Fabricate(:category, name: "Feedback") }
+
+ fab!(:feature_topic) do
+ Fabricate(
+ :topic,
+ user: user,
+ tags: [feature_tag],
+ category: announcement_category,
+ title: "New Feature Discussion",
+ )
+ end
+
+ fab!(:bug_topic) do
+ Fabricate(
+ :topic,
+ tags: [bug_tag],
+ user: user,
+ category: announcement_category,
+ title: "Bug Report",
+ )
+ end
+
+ fab!(:feature_bug_topic) do
+ Fabricate(
+ :topic,
+ tags: [feature_tag, bug_tag],
+ user: user,
+ category: feedback_category,
+ title: "Feature with Bug",
+ )
+ end
+
+ fab!(:no_tag_topic) do
+ Fabricate(:topic, user: user, category: feedback_category, title: "General Discussion")
+ end
+
+ fab!(:feature_post) { Fabricate(:post, topic: feature_topic, user: user) }
+ fab!(:bug_post) { Fabricate(:post, topic: bug_topic, user: user) }
+ fab!(:feature_bug_post) { Fabricate(:post, topic: feature_bug_topic, user: user) }
+ fab!(:no_tag_post) { Fabricate(:post, topic: no_tag_topic, user: user) }
+
+ describe "tag filtering" do
+ it "correctly filters posts by tags" do
+ filter = described_class.new("tag:feature")
+ expect(filter.search.pluck(:id)).to contain_exactly(feature_post.id, feature_bug_post.id)
+
+ filter = described_class.new("tag:feature,bug")
+ expect(filter.search.pluck(:id)).to contain_exactly(
+ feature_bug_post.id,
+ bug_post.id,
+ feature_post.id,
+ )
+
+ filter = described_class.new("tags:bug")
+ expect(filter.search.pluck(:id)).to contain_exactly(bug_post.id, feature_bug_post.id)
+
+ filter = described_class.new("tag:nonexistent")
+ expect(filter.search.count).to eq(0)
+ end
+ end
+
+ describe "category filtering" do
+ it "correctly filters posts by categories" do
+ filter = described_class.new("category:Announcements")
+ expect(filter.search.pluck(:id)).to contain_exactly(feature_post.id, bug_post.id)
+
+ filter = described_class.new("category:Announcements,Feedback")
+ expect(filter.search.pluck(:id)).to contain_exactly(
+ feature_post.id,
+ bug_post.id,
+ feature_bug_post.id,
+ no_tag_post.id,
+ )
+
+ filter = described_class.new("categories:Feedback")
+ expect(filter.search.pluck(:id)).to contain_exactly(feature_bug_post.id, no_tag_post.id)
+
+ filter = described_class.new("category:Feedback tag:feature")
+ expect(filter.search.pluck(:id)).to contain_exactly(feature_bug_post.id)
+ end
+ end
+
+ it "can limit number of results" do
+ filter = described_class.new("category:Feedback max_results:1", limit: 5)
+ expect(filter.search.pluck(:id).length).to eq(1)
+ end
+
+ describe "full text keyword searching" do
+ before_all { SearchIndexer.enable }
+ fab!(:post_with_apples) do
+ Fabricate(:post, raw: "This post contains apples", topic: feature_topic, user: user)
+ end
+
+ fab!(:post_with_bananas) do
+ Fabricate(:post, raw: "This post mentions bananas", topic: bug_topic, user: user)
+ end
+
+ fab!(:post_with_both) do
+ Fabricate(
+ :post,
+ raw: "This post has apples and bananas",
+ topic: feature_bug_topic,
+ user: user,
+ )
+ end
+
+ fab!(:post_with_none) do
+ Fabricate(:post, raw: "No fruits here", topic: no_tag_topic, user: user)
+ end
+
+ it "correctly filters posts by full text keywords" do
+ filter = described_class.new("keywords:apples")
+ expect(filter.search.pluck(:id)).to contain_exactly(post_with_apples.id, post_with_both.id)
+
+ filter = described_class.new("keywords:bananas")
+ expect(filter.search.pluck(:id)).to contain_exactly(post_with_bananas.id, post_with_both.id)
+
+ filter = described_class.new("keywords:apples,bananas")
+ expect(filter.search.pluck(:id)).to contain_exactly(
+ post_with_apples.id,
+ post_with_bananas.id,
+ post_with_both.id,
+ )
+
+ filter = described_class.new("keywords:oranges")
+ expect(filter.search.count).to eq(0)
+ end
+ end
+ end
+end
diff --git a/spec/lib/utils/research/llm_formatter_spec.rb b/spec/lib/utils/research/llm_formatter_spec.rb
new file mode 100644
index 00000000..edc10363
--- /dev/null
+++ b/spec/lib/utils/research/llm_formatter_spec.rb
@@ -0,0 +1,74 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Utils::Research::LlmFormatter do
+ fab!(:user) { Fabricate(:user, username: "test_user") }
+ fab!(:topic) { Fabricate(:topic, title: "This is a Test Topic", user: user) }
+ fab!(:post) { Fabricate(:post, topic: topic, user: user) }
+ let(:tokenizer) { DiscourseAi::Tokenizer::OpenAiTokenizer }
+ let(:filter) { DiscourseAi::Utils::Research::Filter.new("@#{user.username}") }
+
+ describe "#truncate_if_needed" do
+ it "returns original content when under token limit" do
+ formatter =
+ described_class.new(
+ filter,
+ max_tokens_per_batch: 1000,
+ tokenizer: tokenizer,
+ max_tokens_per_post: 100,
+ )
+
+ short_text = "This is a short post"
+ expect(formatter.send(:truncate_if_needed, short_text)).to eq(short_text)
+ end
+
+ it "truncates content when over token limit" do
+      # Build a raw text string long enough to exceed the per-post token limit
+ long_text = ("word " * 200).strip
+
+ formatter =
+ described_class.new(
+ filter,
+ max_tokens_per_batch: 1000,
+ tokenizer: tokenizer,
+ max_tokens_per_post: 50,
+ )
+
+ truncated = formatter.send(:truncate_if_needed, long_text)
+
+ expect(truncated).to include("... elided 150 tokens ...")
+ expect(truncated).to_not eq(long_text)
+
+ # Should have roughly 25 words before and 25 after (half of max_tokens_per_post)
+ first_chunk = truncated.split("\n\n")[0]
+ expect(first_chunk.split(" ").length).to be_within(5).of(25)
+
+ last_chunk = truncated.split("\n\n")[2]
+ expect(last_chunk.split(" ").length).to be_within(5).of(25)
+ end
+ end
+
+ describe "#format_post" do
+ it "formats posts with truncation for long content" do
+ # Set up a post with long content
+ long_content = ("word " * 200).strip
+ long_post = Fabricate(:post, raw: long_content, topic: topic, user: user)
+
+ formatter =
+ described_class.new(
+ filter,
+ max_tokens_per_batch: 1000,
+ tokenizer: tokenizer,
+ max_tokens_per_post: 50,
+ )
+
+ formatted = formatter.send(:format_post, long_post)
+
+ # Should have standard formatting elements
+ expect(formatted).to include("## Post by #{user.username}")
+ expect(formatted).to include("Post url: /t/-/#{long_post.topic_id}/#{long_post.post_number}")
+
+ # Should include truncation marker
+ expect(formatted).to include("... elided 150 tokens ...")
+ end
+ end
+end