From c34fcc8a9501ec8a63ad236ce7177e54d847860d Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 14 May 2025 12:36:16 +1000 Subject: [PATCH] FEATURE: forum researcher persona for deep research (#1313) This commit introduces a new Forum Researcher persona specialized in deep forum content analysis along with comprehensive improvements to our AI infrastructure. Key additions: New Forum Researcher persona with advanced filtering and analysis capabilities Robust filtering system supporting tags, categories, dates, users, and keywords LLM formatter to efficiently process and chunk research results Infrastructure improvements: Implemented CancelManager class to centrally manage AI completion cancellations Replaced callback-based cancellation with a more robust pattern Added systematic cancellation monitoring with callbacks Other improvements: Added configurable default_enabled flag to control which personas are enabled by default Updated translation strings for the new researcher functionality Added comprehensive specs for the new components Renames Researcher -> Web Researcher This change makes our AI platform more stable while adding powerful research capabilities that can analyze forum trends and surface relevant content. 
--- .../lib/ai-streamer/progress-handlers.js | 2 +- config/locales/server.en.yml | 25 +- db/fixtures/personas/603_ai_personas.rb | 2 +- lib/ai_bot/chat_streamer.rb | 18 +- lib/ai_bot/playground.rb | 36 +-- lib/completions/cancel_manager.rb | 109 ++++++++ lib/completions/endpoints/base.rb | 59 ++-- lib/completions/endpoints/canned_response.rb | 3 +- lib/completions/endpoints/fake.rb | 3 +- lib/completions/endpoints/open_ai.rb | 1 + lib/completions/llm.rb | 4 +- lib/completions/prompt_messages_builder.rb | 4 + lib/discord/bot/persona_replier.rb | 2 +- lib/inference/open_ai_image_generator.rb | 56 +++- .../artifact_update_strategies/base.rb | 15 +- lib/personas/bot.rb | 30 +- lib/personas/bot_context.rb | 8 +- lib/personas/forum_researcher.rb | 52 ++++ lib/personas/persona.rb | 6 + lib/personas/short_summarizer.rb | 6 +- lib/personas/summarizer.rb | 8 +- lib/personas/tools/create_artifact.rb | 7 +- lib/personas/tools/create_image.rb | 1 + lib/personas/tools/edit_image.rb | 1 + lib/personas/tools/researcher.rb | 181 ++++++++++++ lib/personas/tools/tool.rb | 5 +- lib/personas/tools/update_artifact.rb | 1 + lib/summarization/fold_content.rb | 10 +- lib/utils/research/filter.rb | 263 ++++++++++++++++++ lib/utils/research/llm_formatter.rb | 205 ++++++++++++++ spec/lib/completions/cancel_manager_spec.rb | 106 +++++++ .../endpoints/endpoint_compliance.rb | 12 +- spec/lib/modules/ai_bot/playground_spec.rb | 13 +- spec/lib/personas/persona_spec.rb | 42 ++- spec/lib/personas/tools/researcher_spec.rb | 109 ++++++++ spec/lib/utils/research/filter_spec.rb | 142 ++++++++++ spec/lib/utils/research/llm_formatter_spec.rb | 74 +++++ 37 files changed, 1489 insertions(+), 132 deletions(-) create mode 100644 lib/completions/cancel_manager.rb create mode 100644 lib/personas/forum_researcher.rb create mode 100644 lib/personas/tools/researcher.rb create mode 100644 lib/utils/research/filter.rb create mode 100644 lib/utils/research/llm_formatter.rb create mode 100644 
spec/lib/completions/cancel_manager_spec.rb create mode 100644 spec/lib/personas/tools/researcher_spec.rb create mode 100644 spec/lib/utils/research/filter_spec.rb create mode 100644 spec/lib/utils/research/llm_formatter_spec.rb diff --git a/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js b/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js index b06c525c..d5dea6f4 100644 --- a/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js +++ b/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js @@ -2,7 +2,7 @@ import { later } from "@ember/runloop"; import PostUpdater from "./updaters/post-updater"; const PROGRESS_INTERVAL = 40; -const GIVE_UP_INTERVAL = 60000; +const GIVE_UP_INTERVAL = 600000; // 10 minutes which is our max thinking time for now export const MIN_LETTERS_PER_INTERVAL = 6; const MAX_FLUSH_TIME = 800; diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 00a71cd3..0875eba9 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -296,6 +296,9 @@ en: designer: name: Designer description: "AI Bot specialized in generating and editing images" + forum_researcher: + name: Forum Researcher + description: "AI Bot specialized in deep research for the forum" sql_helper: name: SQL Helper description: "AI Bot specialized in helping craft SQL queries on this Discourse instance" @@ -303,8 +306,8 @@ en: name: Settings Explorer description: "AI Bot specialized in helping explore Discourse site settings" researcher: - name: Researcher - description: "AI Bot with Google access that can research information for you" + name: Web Researcher + description: "AI Bot with Google access that can both search and read web pages" creative: name: Creative description: "AI Bot with no external integrations specialized in creative tasks" @@ -327,6 +330,16 @@ en: summarizing: "Summarizing topic" searching: "Searching for: '%{query}'" tool_options: + researcher: + max_results: + name: 
"Maximum number of results" + description: "Maximum number of results to include in a filter" + include_private: + name: "Include private" + description: "Include private topics in the filters" + max_tokens_per_post: + name: "Maximum tokens per post" + description: "Maximum number of tokens to use for each post in the filter" create_artifact: creator_llm: name: "LLM" @@ -385,6 +398,7 @@ en: javascript_evaluator: "Evaluate JavaScript" create_image: "Creating image" edit_image: "Editing image" + researcher: "Researching" tool_help: read_artifact: "Read a web artifact using the AI Bot" update_artifact: "Update a web artifact using the AI Bot" @@ -411,6 +425,7 @@ en: dall_e: "Generate image using DALL-E 3" search_meta_discourse: "Search Meta Discourse" javascript_evaluator: "Evaluate JavaScript" + researcher: "Research forum information using the AI Bot" tool_description: read_artifact: "Read a web artifact using the AI Bot" update_artifact: "Updated a web artifact using the AI Bot" @@ -445,6 +460,12 @@ en: other: "Found %{count} results for '%{query}'" setting_context: "Reading context for: %{setting_name}" schema: "%{tables}" + researcher_dry_run: + one: "Proposed research: %{goals}\n\nFound %{count} result for '%{filter}'" + other: "Proposed research: %{goals}\n\nFound %{count} results for '%{filter}'" + researcher: + one: "Researching: %{goals}\n\nFound %{count} result for '%{filter}'" + other: "Researching: %{goals}\n\nFound %{count} results for '%{filter}'" search_settings: one: "Found %{count} result for '%{query}'" other: "Found %{count} results for '%{query}'" diff --git a/db/fixtures/personas/603_ai_personas.rb b/db/fixtures/personas/603_ai_personas.rb index 27e6d479..c2c121de 100644 --- a/db/fixtures/personas/603_ai_personas.rb +++ b/db/fixtures/personas/603_ai_personas.rb @@ -33,7 +33,7 @@ DiscourseAi::Personas::Persona.system_personas.each do |persona_class, id| persona.allowed_group_ids = [Group::AUTO_GROUPS[:trust_level_0]] end - persona.enabled = 
!summarization_personas.include?(persona_class) + persona.enabled = persona_class.default_enabled persona.priority = true if persona_class == DiscourseAi::Personas::General end diff --git a/lib/ai_bot/chat_streamer.rb b/lib/ai_bot/chat_streamer.rb index 139a6c7f..06357e0e 100644 --- a/lib/ai_bot/chat_streamer.rb +++ b/lib/ai_bot/chat_streamer.rb @@ -6,16 +6,23 @@ module DiscourseAi module AiBot class ChatStreamer - attr_accessor :cancel attr_reader :reply, :guardian, :thread_id, :force_thread, :in_reply_to_id, :channel, - :cancelled + :cancel_manager - def initialize(message:, channel:, guardian:, thread_id:, in_reply_to_id:, force_thread:) + def initialize( + message:, + channel:, + guardian:, + thread_id:, + in_reply_to_id:, + force_thread:, + cancel_manager: nil + ) @message = message @channel = channel @guardian = guardian @@ -35,6 +42,8 @@ module DiscourseAi guardian: guardian, thread_id: thread_id, ) + + @cancel_manager = cancel_manager end def <<(partial) @@ -111,8 +120,7 @@ module DiscourseAi streaming = ChatSDK::Message.stream(message_id: reply.id, raw: buffer, guardian: guardian) if !streaming - cancel.call - @cancelled = true + @cancel_manager.cancel! 
if @cancel_manager end end end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 63aa8967..07c4984b 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -331,6 +331,7 @@ module DiscourseAi ), user: message.user, skip_tool_details: true, + cancel_manager: DiscourseAi::Completions::CancelManager.new, ) reply = nil @@ -347,15 +348,14 @@ module DiscourseAi thread_id: message.thread_id, in_reply_to_id: in_reply_to_id, force_thread: force_thread, + cancel_manager: context.cancel_manager, ) new_prompts = - bot.reply(context) do |partial, cancel, placeholder, type| + bot.reply(context) do |partial, placeholder, type| # no support for tools or thinking by design next if type == :thinking || type == :tool_details || type == :partial_tool - streamer.cancel = cancel streamer << partial - break if streamer.cancelled end reply = streamer.reply @@ -383,6 +383,7 @@ module DiscourseAi auto_set_title: true, silent_mode: false, feature_name: nil, + cancel_manager: nil, &blk ) # this is a multithreading issue @@ -471,16 +472,26 @@ module DiscourseAi redis_stream_key = "gpt_cancel:#{reply_post.id}" Discourse.redis.setex(redis_stream_key, MAX_STREAM_DELAY_SECONDS, 1) + + cancel_manager ||= DiscourseAi::Completions::CancelManager.new + context.cancel_manager = cancel_manager + context + .cancel_manager + .start_monitor(delay: 0.2) do + context.cancel_manager.cancel! if !Discourse.redis.get(redis_stream_key) + end + + context.cancel_manager.add_callback( + lambda { reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) }, + ) end context.skip_tool_details ||= !bot.persona.class.tool_details - post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply - started_thinking = false new_custom_prompts = - bot.reply(context) do |partial, cancel, placeholder, type| + bot.reply(context) do |partial, placeholder, type| if type == :thinking && !started_thinking reply << "
#{I18n.t("discourse_ai.ai_bot.thinking")}" started_thinking = true @@ -499,15 +510,6 @@ module DiscourseAi blk.call(partial) end - if stream_reply && !Discourse.redis.get(redis_stream_key) - cancel&.call - reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) - # we do not break out, cause if we do - # we will not get results from bot - # leading to broken context - # we need to trust it to cancel at the endpoint - end - if post_streamer post_streamer.run_later do Discourse.redis.expire(redis_stream_key, MAX_STREAM_DELAY_SECONDS) @@ -568,6 +570,8 @@ module DiscourseAi end raise e ensure + context.cancel_manager.stop_monitor if context&.cancel_manager + # since we are skipping validations and jobs we # may need to fix participant count if reply_post && reply_post.topic && reply_post.topic.private_message? && @@ -649,7 +653,7 @@ module DiscourseAi payload, user_ids: bot_reply_post.topic.allowed_user_ids, max_backlog_size: 2, - max_backlog_age: 60, + max_backlog_age: MAX_STREAM_DELAY_SECONDS, ) end end diff --git a/lib/completions/cancel_manager.rb b/lib/completions/cancel_manager.rb new file mode 100644 index 00000000..78c3ee5b --- /dev/null +++ b/lib/completions/cancel_manager.rb @@ -0,0 +1,109 @@ +# frozen_string_literal: true + +# special object that can be used to cancel completions and http requests +module DiscourseAi + module Completions + class CancelManager + attr_reader :cancelled + attr_reader :callbacks + + def initialize + @cancelled = false + @callbacks = Concurrent::Array.new + @mutex = Mutex.new + @monitor_thread = nil + end + + def monitor_thread + @mutex.synchronize { @monitor_thread } + end + + def start_monitor(delay: 0.5, &block) + @mutex.synchronize do + raise "Already monitoring" if @monitor_thread + raise "Expected a block" if !block + + db = RailsMultisite::ConnectionManagement.current_db + @stop_monitor = false + + @monitor_thread = + Thread.new do + begin + loop do + done = false + @mutex.synchronize { done = true if @stop_monitor 
} + break if done + sleep delay + @mutex.synchronize { done = true if @stop_monitor } + @mutex.synchronize { done = true if cancelled? } + break if done + + should_cancel = false + RailsMultisite::ConnectionManagement.with_connection(db) do + should_cancel = block.call + end + + @mutex.synchronize { cancel! if should_cancel } + + break if cancelled? + end + ensure + @mutex.synchronize { @monitor_thread = nil } + end + end + end + end + + def stop_monitor + monitor_thread = nil + + @mutex.synchronize { monitor_thread = @monitor_thread } + + if monitor_thread + @mutex.synchronize { @stop_monitor = true } + # so we do not deadlock + monitor_thread.wakeup + monitor_thread.join(2) + # should not happen + if monitor_thread.alive? + Rails.logger.warn("DiscourseAI: CancelManager monitor thread did not stop in time") + monitor_thread.kill if monitor_thread.alive? + end + @monitor_thread = nil + end + end + + def cancelled? + @cancelled + end + + def add_callback(cb) + @callbacks << cb + end + + def remove_callback(cb) + @callbacks.delete(cb) + end + + def cancel! + @cancelled = true + monitor_thread = @monitor_thread + if monitor_thread && monitor_thread != Thread.current + monitor_thread.wakeup + monitor_thread.join(2) + if monitor_thread.alive? + Rails.logger.warn("DiscourseAI: CancelManager monitor thread did not stop in time") + monitor_thread.kill if monitor_thread.alive? 
+ end + end + @callbacks.each do |cb| + begin + cb.call + rescue StandardError + # ignore cause this may have already been cancelled + end + end + end + end + end +end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 3a12c4c7..0356597e 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -68,11 +68,17 @@ module DiscourseAi feature_context: nil, partial_tool_calls: false, output_thinking: false, + cancel_manager: nil, &blk ) LlmQuota.check_quotas!(@llm_model, user) start_time = Time.now + if cancel_manager && cancel_manager.cancelled? + # nothing to do + return + end + @forced_json_through_prefill = false @partial_tool_calls = partial_tool_calls @output_thinking = output_thinking @@ -90,15 +96,14 @@ module DiscourseAi feature_context: feature_context, partial_tool_calls: partial_tool_calls, output_thinking: output_thinking, + cancel_manager: cancel_manager, ) wrapped = result wrapped = [result] if !result.is_a?(Array) - cancelled_by_caller = false - cancel_proc = -> { cancelled_by_caller = true } wrapped.each do |partial| - blk.call(partial, cancel_proc) - break if cancelled_by_caller + blk.call(partial) + break if cancel_manager&.cancelled? 
end return result end @@ -118,6 +123,9 @@ module DiscourseAi end end + cancel_manager_callback = nil + cancelled = false + FinalDestination::HTTP.start( model_uri.host, model_uri.port, @@ -126,6 +134,14 @@ module DiscourseAi open_timeout: TIMEOUT, write_timeout: TIMEOUT, ) do |http| + if cancel_manager + cancel_manager_callback = + lambda do + cancelled = true + http.finish + end + cancel_manager.add_callback(cancel_manager_callback) + end response_data = +"" response_raw = +"" @@ -158,7 +174,7 @@ module DiscourseAi if @streaming_mode blk = - lambda do |partial, cancel| + lambda do |partial| if partial.is_a?(String) partial = xml_stripper << partial if xml_stripper @@ -167,7 +183,7 @@ module DiscourseAi partial = structured_output end end - orig_blk.call(partial, cancel) if partial + orig_blk.call(partial) if partial end end @@ -196,14 +212,6 @@ module DiscourseAi end begin - cancelled = false - cancel = -> do - cancelled = true - http.finish - end - - break if cancelled - response.read_body do |chunk| break if cancelled @@ -216,16 +224,11 @@ module DiscourseAi partials = [partial] if xml_tool_processor && partial.is_a?(String) partials = (xml_tool_processor << partial) - if xml_tool_processor.should_cancel? - cancel.call - break - end + break if xml_tool_processor.should_cancel? 
end - partials.each { |inner_partial| blk.call(inner_partial, cancel) } + partials.each { |inner_partial| blk.call(inner_partial) } end end - rescue IOError, StandardError - raise if !cancelled end if xml_stripper stripped = xml_stripper.finish @@ -233,13 +236,11 @@ module DiscourseAi response_data << stripped result = [] result = (xml_tool_processor << stripped) if xml_tool_processor - result.each { |partial| blk.call(partial, cancel) } + result.each { |partial| blk.call(partial) } end end - if xml_tool_processor - xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) } - end - decode_chunk_finish.each { |partial| blk.call(partial, cancel) } + xml_tool_processor.finish.each { |partial| blk.call(partial) } if xml_tool_processor + decode_chunk_finish.each { |partial| blk.call(partial) } return response_data ensure if log @@ -293,6 +294,12 @@ module DiscourseAi end end end + rescue IOError, StandardError + raise if !cancelled + ensure + if cancel_manager && cancel_manager_callback + cancel_manager.remove_callback(cancel_manager_callback) + end end def final_log_update(log) diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index be156aea..b31b2bb4 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -30,7 +30,8 @@ module DiscourseAi feature_name: nil, feature_context: nil, partial_tool_calls: false, - output_thinking: false + output_thinking: false, + cancel_manager: nil ) @dialect = dialect @model_params = model_params diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index d307ada5..71d0ee4b 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -122,7 +122,8 @@ module DiscourseAi feature_name: nil, feature_context: nil, partial_tool_calls: false, - output_thinking: false + output_thinking: false, + cancel_manager: nil ) last_call = { dialect: dialect, user: user, 
model_params: model_params } self.class.last_call = last_call diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index f97a337d..070d391c 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -46,6 +46,7 @@ module DiscourseAi feature_context: nil, partial_tool_calls: false, output_thinking: false, + cancel_manager: nil, &blk ) @disable_native_tools = dialect.disable_native_tools? diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index f90de091..04d517cb 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -307,7 +307,7 @@ module DiscourseAi # @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output. # @param [Experimental] extra_model_params { Hash - Optional } - Other params that are not available accross models. e.g. response_format JSON schema. # - # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. + # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response. # # @returns String | ToolCall - Completion result. # if multiple tools or a tool and a message come back, the result will be an array of ToolCall / String objects. @@ -325,6 +325,7 @@ module DiscourseAi output_thinking: false, response_format: nil, extra_model_params: nil, + cancel_manager: nil, &partial_read_blk ) self.class.record_prompt( @@ -378,6 +379,7 @@ module DiscourseAi feature_context: feature_context, partial_tool_calls: partial_tool_calls, output_thinking: output_thinking, + cancel_manager: cancel_manager, &partial_read_blk ) end diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb index 4d302ea1..2864846a 100644 --- a/lib/completions/prompt_messages_builder.rb +++ b/lib/completions/prompt_messages_builder.rb @@ -247,6 +247,10 @@ module DiscourseAi # 3. 
ensures we always interleave user and model messages last_type = nil messages.each do |message| + if message[:type] == :model && !message[:content] + message[:content] = "Reply cancelled by user." + end + next if !last_type && message[:type] != :user if last_type == :tool_call && message[:type] != :tool diff --git a/lib/discord/bot/persona_replier.rb b/lib/discord/bot/persona_replier.rb index 51e03475..b64af15c 100644 --- a/lib/discord/bot/persona_replier.rb +++ b/lib/discord/bot/persona_replier.rb @@ -24,7 +24,7 @@ module DiscourseAi full_reply = @bot.reply( { conversation_context: [{ type: :user, content: @query }], skip_tool_details: true }, - ) do |partial, _cancel, _something| + ) do |partial, _something| reply << partial next if reply.blank? diff --git a/lib/inference/open_ai_image_generator.rb b/lib/inference/open_ai_image_generator.rb index f4ff6754..83501989 100644 --- a/lib/inference/open_ai_image_generator.rb +++ b/lib/inference/open_ai_image_generator.rb @@ -21,7 +21,8 @@ module ::DiscourseAi moderation: "low", output_compression: nil, output_format: nil, - title: nil + title: nil, + cancel_manager: nil ) # Get the API responses in parallel threads api_responses = @@ -38,6 +39,7 @@ module ::DiscourseAi moderation: moderation, output_compression: output_compression, output_format: output_format, + cancel_manager: cancel_manager, ) raise api_responses[0] if api_responses.all? 
{ |resp| resp.is_a?(StandardError) } @@ -58,7 +60,8 @@ module ::DiscourseAi user_id:, for_private_message: false, n: 1, - quality: nil + quality: nil, + cancel_manager: nil ) api_response = edit_images( @@ -70,6 +73,7 @@ module ::DiscourseAi api_url: api_url, n: n, quality: quality, + cancel_manager: cancel_manager, ) create_uploads_from_responses([api_response], user_id, for_private_message).first @@ -124,7 +128,8 @@ module ::DiscourseAi background:, moderation:, output_compression:, - output_format: + output_format:, + cancel_manager: ) prompts = [prompts] unless prompts.is_a?(Array) prompts = prompts.take(4) # Limit to 4 prompts max @@ -152,18 +157,21 @@ module ::DiscourseAi moderation: moderation, output_compression: output_compression, output_format: output_format, + cancel_manager: cancel_manager, ) rescue => e attempts += 1 # to keep tests speedy - if !Rails.env.test? + if !Rails.env.test? && !cancel_manager&.cancelled? retry if attempts < 3 end - Discourse.warn_exception( - e, - message: "Failed to generate image for prompt #{prompt}\n", - ) - puts "Error generating image for prompt: #{prompt} #{e}" if Rails.env.development? + if !cancel_manager&.cancelled? + Discourse.warn_exception( + e, + message: "Failed to generate image for prompt #{prompt}\n", + ) + puts "Error generating image for prompt: #{prompt} #{e}" if Rails.env.development? + end e end end @@ -181,7 +189,8 @@ module ::DiscourseAi api_key: nil, api_url: nil, n: 1, - quality: nil + quality: nil, + cancel_manager: nil ) images = [images] if !images.is_a?(Array) @@ -209,8 +218,10 @@ module ::DiscourseAi api_url: api_url, n: n, quality: quality, + cancel_manager: cancel_manager, ) rescue => e + raise e if cancel_manager&.cancelled? attempts += 1 if !Rails.env.test? 
sleep 2 @@ -238,7 +249,8 @@ module ::DiscourseAi background: nil, moderation: nil, output_compression: nil, - output_format: nil + output_format: nil, + cancel_manager: nil ) api_key ||= SiteSetting.ai_openai_api_key api_url ||= SiteSetting.ai_openai_image_generation_url @@ -276,6 +288,7 @@ module ::DiscourseAi # Store original prompt for upload metadata original_prompt = prompt + cancel_manager_callback = nil FinalDestination::HTTP.start( uri.host, @@ -288,6 +301,11 @@ module ::DiscourseAi request = Net::HTTP::Post.new(uri, headers) request.body = payload.to_json + if cancel_manager + cancel_manager_callback = lambda { http.finish } + cancel_manager.add_callback(cancel_manager_callback) + end + json = nil http.request(request) do |response| if response.code.to_i != 200 @@ -300,6 +318,10 @@ module ::DiscourseAi end json end + ensure + if cancel_manager && cancel_manager_callback + cancel_manager.remove_callback(cancel_manager_callback) + end end def self.perform_edit_api_call!( @@ -310,7 +332,8 @@ module ::DiscourseAi api_key:, api_url:, n: 1, - quality: nil + quality: nil, + cancel_manager: nil ) uri = URI(api_url) @@ -403,6 +426,7 @@ module ::DiscourseAi # Store original prompt for upload metadata original_prompt = prompt + cancel_manager_callback = nil FinalDestination::HTTP.start( uri.host, @@ -415,6 +439,11 @@ module ::DiscourseAi request = Net::HTTP::Post.new(uri.path, headers) request.body = body.join + if cancel_manager + cancel_manager_callback = lambda { http.finish } + cancel_manager.add_callback(cancel_manager_callback) + end + json = nil http.request(request) do |response| if response.code.to_i != 200 @@ -428,6 +457,9 @@ module ::DiscourseAi json end ensure + if cancel_manager && cancel_manager_callback + cancel_manager.remove_callback(cancel_manager_callback) + end if files_to_delete.present? 
files_to_delete.each { |file| File.delete(file) if File.exist?(file) } end diff --git a/lib/personas/artifact_update_strategies/base.rb b/lib/personas/artifact_update_strategies/base.rb index 68e0b761..624830f1 100644 --- a/lib/personas/artifact_update_strategies/base.rb +++ b/lib/personas/artifact_update_strategies/base.rb @@ -5,15 +5,24 @@ module DiscourseAi class InvalidFormatError < StandardError end class Base - attr_reader :post, :user, :artifact, :artifact_version, :instructions, :llm + attr_reader :post, :user, :artifact, :artifact_version, :instructions, :llm, :cancel_manager - def initialize(llm:, post:, user:, artifact:, artifact_version:, instructions:) + def initialize( + llm:, + post:, + user:, + artifact:, + artifact_version:, + instructions:, + cancel_manager: nil + ) @llm = llm @post = post @user = user @artifact = artifact @artifact_version = artifact_version @instructions = instructions + @cancel_manager = cancel_manager end def apply(&progress) @@ -26,7 +35,7 @@ module DiscourseAi def generate_changes(&progress) response = +"" - llm.generate(build_prompt, user: user) do |partial| + llm.generate(build_prompt, user: user, cancel_manager: cancel_manager) do |partial| progress.call(partial) if progress response << partial end diff --git a/lib/personas/bot.rb b/lib/personas/bot.rb index 2c9b5a3b..3686edac 100644 --- a/lib/personas/bot.rb +++ b/lib/personas/bot.rb @@ -55,6 +55,7 @@ module DiscourseAi unless context.is_a?(BotContext) raise ArgumentError, "context must be an instance of BotContext" end + context.cancel_manager ||= DiscourseAi::Completions::CancelManager.new current_llm = llm prompt = persona.craft_prompt(context, llm: current_llm) @@ -91,8 +92,9 @@ module DiscourseAi feature_name: context.feature_name, partial_tool_calls: allow_partial_tool_calls, output_thinking: true, + cancel_manager: context.cancel_manager, **llm_kwargs, - ) do |partial, cancel| + ) do |partial| tool = persona.find_tool( partial, @@ -109,7 +111,7 @@ module 
DiscourseAi if tool_call.partial? if tool.class.allow_partial_tool_calls? tool.partial_invoke - update_blk.call("", cancel, tool.custom_raw, :partial_tool) + update_blk.call("", tool.custom_raw, :partial_tool) end next end @@ -117,7 +119,7 @@ module DiscourseAi tool_found = true # a bit hacky, but extra newlines do no harm if needs_newlines - update_blk.call("\n\n", cancel) + update_blk.call("\n\n") needs_newlines = false end @@ -125,7 +127,6 @@ module DiscourseAi tool: tool, raw_context: raw_context, current_llm: current_llm, - cancel: cancel, update_blk: update_blk, prompt: prompt, context: context, @@ -144,7 +145,7 @@ module DiscourseAi else if partial.is_a?(DiscourseAi::Completions::Thinking) if partial.partial? && partial.message.present? - update_blk.call(partial.message, cancel, nil, :thinking) + update_blk.call(partial.message, nil, :thinking) end if !partial.partial? # this will be dealt with later @@ -152,9 +153,9 @@ module DiscourseAi current_thinking << partial end elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput) - update_blk.call(partial, cancel, nil, :structured_output) + update_blk.call(partial, nil, :structured_output) else - update_blk.call(partial, cancel) + update_blk.call(partial) end end end @@ -215,14 +216,13 @@ module DiscourseAi tool:, raw_context:, current_llm:, - cancel:, update_blk:, prompt:, context:, current_thinking: ) tool_call_id = tool.tool_call_id - invocation_result_json = invoke_tool(tool, cancel, context, &update_blk).to_json + invocation_result_json = invoke_tool(tool, context, &update_blk).to_json tool_call_message = { type: :tool_call, @@ -256,27 +256,27 @@ module DiscourseAi raw_context << [invocation_result_json, tool_call_id, "tool", tool.name] end - def invoke_tool(tool, cancel, context, &update_blk) + def invoke_tool(tool, context, &update_blk) show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls? 
- update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder + update_blk.call("", build_placeholder(tool.summary, "")) if show_placeholder result = tool.invoke do |progress, render_raw| if render_raw - update_blk.call("", cancel, tool.custom_raw, :partial_invoke) + update_blk.call("", tool.custom_raw, :partial_invoke) show_placeholder = false elsif show_placeholder placeholder = build_placeholder(tool.summary, progress) - update_blk.call("", cancel, placeholder) + update_blk.call("", placeholder) end end if show_placeholder tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw) - update_blk.call(tool_details, cancel, nil, :tool_details) + update_blk.call(tool_details, nil, :tool_details) elsif tool.custom_raw.present? - update_blk.call(tool.custom_raw, cancel, nil, :custom_raw) + update_blk.call(tool.custom_raw, nil, :custom_raw) end result diff --git a/lib/personas/bot_context.rb b/lib/personas/bot_context.rb index 5f7dd99e..69d86669 100644 --- a/lib/personas/bot_context.rb +++ b/lib/personas/bot_context.rb @@ -16,7 +16,8 @@ module DiscourseAi :channel_id, :context_post_ids, :feature_name, - :resource_url + :resource_url, + :cancel_manager def initialize( post: nil, @@ -33,7 +34,8 @@ module DiscourseAi channel_id: nil, context_post_ids: nil, feature_name: "bot", - resource_url: nil + resource_url: nil, + cancel_manager: nil ) @participants = participants @user = user @@ -54,6 +56,8 @@ module DiscourseAi @feature_name = feature_name @resource_url = resource_url + @cancel_manager = cancel_manager + if post @post_id = post.id @topic_id = post.topic_id diff --git a/lib/personas/forum_researcher.rb b/lib/personas/forum_researcher.rb new file mode 100644 index 00000000..eb0f82cb --- /dev/null +++ b/lib/personas/forum_researcher.rb @@ -0,0 +1,52 @@ +#frozen_string_literal: true + +module DiscourseAi + module Personas + class ForumResearcher < Persona + def self.default_enabled + false + end + + def tools + 
[Tools::Researcher] + end + + def system_prompt + <<~PROMPT + You are a helpful Discourse assistant specializing in forum research. + You _understand_ and **generate** Discourse Markdown. + + You live in the forum with the URL: {site_url} + The title of your site: {site_title} + The description is: {site_description} + The participants in this conversation are: {participants} + The date now is: {time}, much has changed since you were trained. + + As a forum researcher, guide users through a structured research process: + 1. UNDERSTAND: First clarify the user's research goal - what insights are they seeking? + 2. PLAN: Design an appropriate research approach with specific filters + 3. TEST: Always begin with dry_run:true to gauge the scope of results + 4. REFINE: If results are too broad/narrow, suggest filter adjustments + 5. EXECUTE: Run the final analysis only when filters are well-tuned + 6. SUMMARIZE: Present findings with links to supporting evidence + + BE MINDFUL: specify all research goals in one request to avoid multiple processing runs. + + REMEMBER: Different filters serve different purposes: + - Use post date filters (after/before) for analyzing specific posts + - Use topic date filters (topic_after/topic_before) for analyzing entire topics + - Combine user/group filters with categories/tags to find specialized contributions + + Always ground your analysis with links to original posts on the forum. + + Research workflow best practices: + 1. Start with a dry_run to gauge the scope (set dry_run:true) + 2. If results are too numerous (>1000), add more specific filters + 3. If results are too few (<5), broaden your filters + 4. For temporal analysis, specify explicit date ranges + 5. 
For user behavior analysis, combine @username with categories or tags + PROMPT + end + end + end +end diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb index a4443dcc..62426f77 100644 --- a/lib/personas/persona.rb +++ b/lib/personas/persona.rb @@ -4,6 +4,10 @@ module DiscourseAi module Personas class Persona class << self + def default_enabled + true + end + def rag_conversation_chunks 10 end @@ -47,6 +51,7 @@ module DiscourseAi Summarizer => -11, ShortSummarizer => -12, Designer => -13, + ForumResearcher => -14, } end @@ -99,6 +104,7 @@ module DiscourseAi Tools::GithubSearchFiles, Tools::WebBrowser, Tools::JavascriptEvaluator, + Tools::Researcher, ] if SiteSetting.ai_artifact_security.in?(%w[lax strict]) diff --git a/lib/personas/short_summarizer.rb b/lib/personas/short_summarizer.rb index e7cef54a..5b9b5195 100644 --- a/lib/personas/short_summarizer.rb +++ b/lib/personas/short_summarizer.rb @@ -3,6 +3,10 @@ module DiscourseAi module Personas class ShortSummarizer < Persona + def self.default_enabled + false + end + def system_prompt <<~PROMPT.strip You are an advanced summarization bot. Analyze a given conversation and produce a concise, @@ -23,7 +27,7 @@ module DiscourseAi {"summary": "xx"} - + Where "xx" is replaced by the summary. PROMPT end diff --git a/lib/personas/summarizer.rb b/lib/personas/summarizer.rb index c1eefe89..b8b0a95a 100644 --- a/lib/personas/summarizer.rb +++ b/lib/personas/summarizer.rb @@ -3,6 +3,10 @@ module DiscourseAi module Personas class Summarizer < Persona + def self.default_enabled + false + end + def system_prompt <<~PROMPT.strip You are an advanced summarization bot that generates concise, coherent summaries of provided text. 
@@ -18,13 +22,13 @@ module DiscourseAi - Example: link to the 6th post by jane: [agreed with]({resource_url}/6) - Example: link to the 13th post by joe: [joe]({resource_url}/13) - When formatting usernames either use @USERNAME OR [USERNAME]({resource_url}/POST_NUMBER) - + Format your response as a JSON object with a single key named "summary", which has the summary as the value. Your output should be in the following format: {"summary": "xx"} - + Where "xx" is replaced by the summary. PROMPT end diff --git a/lib/personas/tools/create_artifact.rb b/lib/personas/tools/create_artifact.rb index c12577ff..548fb9aa 100644 --- a/lib/personas/tools/create_artifact.rb +++ b/lib/personas/tools/create_artifact.rb @@ -151,7 +151,12 @@ module DiscourseAi LlmModel.find_by(id: options[:creator_llm].to_i)&.to_llm ) || self.llm - llm.generate(prompt, user: user, feature_name: "create_artifact") do |partial_response| + llm.generate( + prompt, + user: user, + feature_name: "create_artifact", + cancel_manager: context.cancel_manager, + ) do |partial_response| response << partial_response yield partial_response end diff --git a/lib/personas/tools/create_image.rb b/lib/personas/tools/create_image.rb index f57c83ff..8e2971fa 100644 --- a/lib/personas/tools/create_image.rb +++ b/lib/personas/tools/create_image.rb @@ -48,6 +48,7 @@ module DiscourseAi max_prompts, model: "gpt-image-1", user_id: bot_user.id, + cancel_manager: context.cancel_manager, ) rescue => e @error = e diff --git a/lib/personas/tools/edit_image.rb b/lib/personas/tools/edit_image.rb index b28afd0c..b9e3249a 100644 --- a/lib/personas/tools/edit_image.rb +++ b/lib/personas/tools/edit_image.rb @@ -60,6 +60,7 @@ module DiscourseAi uploads, prompt, user_id: bot_user.id, + cancel_manager: context.cancel_manager, ) rescue => e @error = e diff --git a/lib/personas/tools/researcher.rb b/lib/personas/tools/researcher.rb new file mode 100644 index 00000000..3ab4c8e0 --- /dev/null +++ b/lib/personas/tools/researcher.rb @@ -0,0 
+1,181 @@ +# frozen_string_literal: true + +module DiscourseAi + module Personas + module Tools + class Researcher < Tool + attr_reader :filter, :result_count, :goals, :dry_run + + class << self + def signature + { + name: name, + description: + "Analyze and extract information from content across the forum based on specified filters", + parameters: [ + { name: "filter", description: filter_description, type: "string" }, + { + name: "goals", + description: + "The specific information you want to extract or analyze from the filtered content, you may specify multiple goals", + type: "string", + }, + { + name: "dry_run", + description: "When true, only count matching items without processing data", + type: "boolean", + }, + ], + } + end + + def filter_description + <<~TEXT + Filter string to target specific content. + - Supports user (@username) + - date ranges (after:YYYY-MM-DD, before:YYYY-MM-DD for posts; topic_after:YYYY-MM-DD, topic_before:YYYY-MM-DD for topics) + - categories (category:category1,category2) + - tags (tag:tag1,tag2) + - groups (group:group1,group2). + - status (status:open, status:closed, status:archived, status:noreplies, status:single_user) + - keywords (keywords:keyword1,keyword2) - specific words to search for in posts + - max_results (max_results:10) the maximum number of results to return (optional) + - order (order:latest, order:oldest, order:latest_topic, order:oldest_topic) - the order of the results (optional) + + If multiple tags or categories are specified, they are treated as OR conditions. + + Multiple filters can be combined with spaces. 
Example: '@sam after:2023-01-01 tag:feature' + TEXT + end + + def name + "researcher" + end + + def accepted_options + [ + option(:max_results, type: :integer), + option(:include_private, type: :boolean), + option(:max_tokens_per_post, type: :integer), + ] + end + end + + def invoke(&blk) + max_results = options[:max_results] || 1000 + + @filter = parameters[:filter] || "" + @goals = parameters[:goals] || "" + @dry_run = parameters[:dry_run].nil? ? false : parameters[:dry_run] + + post = Post.find_by(id: context.post_id) + goals = parameters[:goals] || "" + dry_run = parameters[:dry_run].nil? ? false : parameters[:dry_run] + + return { error: "No goals provided" } if goals.blank? + return { error: "No filter provided" } if @filter.blank? + + guardian = nil + guardian = Guardian.new(context.user) if options[:include_private] + + filter = + DiscourseAi::Utils::Research::Filter.new( + @filter, + limit: max_results, + guardian: guardian, + ) + @result_count = filter.search.count + + blk.call details + + if dry_run + { dry_run: true, goals: goals, filter: @filter, number_of_results: @result_count } + else + process_filter(filter, goals, post, &blk) + end + end + + def details + if @dry_run + I18n.t("discourse_ai.ai_bot.tool_description.researcher_dry_run", description_args) + else + I18n.t("discourse_ai.ai_bot.tool_description.researcher", description_args) + end + end + + def description_args + { count: @result_count || 0, filter: @filter || "", goals: @goals || "" } + end + + protected + + MIN_TOKENS_FOR_RESEARCH = 8000 + def process_filter(filter, goals, post, &blk) + if llm.max_prompt_tokens < MIN_TOKENS_FOR_RESEARCH + raise ArgumentError, + "LLM max tokens too low for research. Minimum is #{MIN_TOKENS_FOR_RESEARCH}." 
+ end + formatter = + DiscourseAi::Utils::Research::LlmFormatter.new( + filter, + max_tokens_per_batch: llm.max_prompt_tokens - 2000, + tokenizer: llm.tokenizer, + max_tokens_per_post: options[:max_tokens_per_post] || 2000, + ) + + results = [] + + formatter.each_chunk { |chunk| results << run_inference(chunk[:text], goals, post, &blk) } + { dry_run: false, goals: goals, filter: @filter, results: results } + end + + def run_inference(chunk_text, goals, post, &blk) + system_prompt = goal_system_prompt(goals) + user_prompt = goal_user_prompt(goals, chunk_text) + + prompt = + DiscourseAi::Completions::Prompt.new( + system_prompt, + messages: [{ type: :user, content: user_prompt }], + post_id: post.id, + topic_id: post.topic_id, + ) + + results = [] + llm.generate( + prompt, + user: post.user, + feature_name: context.feature_name, + cancel_manager: context.cancel_manager, + ) { |partial| results << partial } + + @progress_dots ||= 0 + @progress_dots += 1 + blk.call(details + "\n\n#{"." * @progress_dots}") + results.join + end + + def goal_system_prompt(goals) + <<~TEXT + You are a researcher tool designed to analyze and extract information from forum content. + Your task is to process the provided content and extract relevant information based on the specified goal. 
+ + Your goal is: #{goals} + TEXT + end + + def goal_user_prompt(goals, chunk_text) + <<~TEXT + Here is the content to analyze: + + {{{ + #{chunk_text} + }}} + + Your goal is: #{goals} + TEXT + end + end + end + end +end diff --git a/lib/personas/tools/tool.rb b/lib/personas/tools/tool.rb index 4bafe0af..540204c2 100644 --- a/lib/personas/tools/tool.rb +++ b/lib/personas/tools/tool.rb @@ -47,8 +47,9 @@ module DiscourseAi end end - attr_accessor :custom_raw, :parameters - attr_reader :tool_call_id, :persona_options, :bot_user, :llm, :context + # llm being public makes it a bit easier to test + attr_accessor :custom_raw, :parameters, :llm + attr_reader :tool_call_id, :persona_options, :bot_user, :context def initialize( parameters, diff --git a/lib/personas/tools/update_artifact.rb b/lib/personas/tools/update_artifact.rb index ccf39a8d..46038d7a 100644 --- a/lib/personas/tools/update_artifact.rb +++ b/lib/personas/tools/update_artifact.rb @@ -159,6 +159,7 @@ module DiscourseAi artifact: artifact, artifact_version: artifact_version, instructions: instructions, + cancel_manager: context.cancel_manager, ) .apply do |progress| partial_response << progress diff --git a/lib/summarization/fold_content.rb b/lib/summarization/fold_content.rb index 15800087..8dbcda38 100644 --- a/lib/summarization/fold_content.rb +++ b/lib/summarization/fold_content.rb @@ -18,7 +18,7 @@ module DiscourseAi attr_reader :bot, :strategy # @param user { User } - User object used for auditing usage. - # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. + # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response. # Note: The block is only called with results of the final summary, not intermediate summaries. # # This method doesn't care if we already have an up to date summary. It always regenerate. 
@@ -77,7 +77,7 @@ module DiscourseAi # @param items { Array } - Content to summarize. Structure will be: { poster: who wrote the content, id: a way to order content, text: content } # @param user { User } - User object used for auditing usage. - # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. + # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response. # Note: The block is only called with results of the final summary, not intermediate summaries. # # The summarization algorithm. @@ -112,7 +112,7 @@ module DiscourseAi summary = +"" buffer_blk = - Proc.new do |partial, cancel, _, type| + Proc.new do |partial, _, type| if type == :structured_output json_summary_schema_key = bot.persona.response_format&.first.to_h partial_summary = @@ -120,12 +120,12 @@ module DiscourseAi if partial_summary.present? summary << partial_summary - on_partial_blk.call(partial_summary, cancel) if on_partial_blk + on_partial_blk.call(partial_summary) if on_partial_blk end elsif type.blank? # Assume response is a regular completion. 
summary << partial - on_partial_blk.call(partial, cancel) if on_partial_blk + on_partial_blk.call(partial) if on_partial_blk end end diff --git a/lib/utils/research/filter.rb b/lib/utils/research/filter.rb new file mode 100644 index 00000000..734943e7 --- /dev/null +++ b/lib/utils/research/filter.rb @@ -0,0 +1,263 @@ +# frozen_string_literal: true + +module DiscourseAi + module Utils + module Research + class Filter + # Stores custom filter handlers + def self.register_filter(matcher, &block) + (@registered_filters ||= {})[matcher] = block + end + + def self.registered_filters + @registered_filters ||= {} + end + + def self.word_to_date(str) + ::Search.word_to_date(str) + end + + attr_reader :term, :filters, :order, :guardian, :limit, :offset + + # Define all filters at class level + register_filter(/\Astatus:open\z/i) do |relation, _, _| + relation.where("topics.closed = false AND topics.archived = false") + end + + register_filter(/\Astatus:closed\z/i) do |relation, _, _| + relation.where("topics.closed = true") + end + + register_filter(/\Astatus:archived\z/i) do |relation, _, _| + relation.where("topics.archived = true") + end + + register_filter(/\Astatus:noreplies\z/i) do |relation, _, _| + relation.where("topics.posts_count = 1") + end + + register_filter(/\Astatus:single_user\z/i) do |relation, _, _| + relation.where("topics.participant_count = 1") + end + + # Date filters + register_filter(/\Abefore:(.*)\z/i) do |relation, date_str, _| + if date = Filter.word_to_date(date_str) + relation.where("posts.created_at < ?", date) + else + relation + end + end + + register_filter(/\Aafter:(.*)\z/i) do |relation, date_str, _| + if date = Filter.word_to_date(date_str) + relation.where("posts.created_at > ?", date) + else + relation + end + end + + register_filter(/\Atopic_before:(.*)\z/i) do |relation, date_str, _| + if date = Filter.word_to_date(date_str) + relation.where("topics.created_at < ?", date) + else + relation + end + end + + 
register_filter(/\Atopic_after:(.*)\z/i) do |relation, date_str, _| + if date = Filter.word_to_date(date_str) + relation.where("topics.created_at > ?", date) + else + relation + end + end + + register_filter(/\A(?:tags?|tag):(.*)\z/i) do |relation, tag_param, _| + if tag_param.include?(",") + tag_names = tag_param.split(",").map(&:strip) + tag_ids = Tag.where(name: tag_names).pluck(:id) + return relation.where("1 = 0") if tag_ids.empty? + relation.where(topic_id: TopicTag.where(tag_id: tag_ids).select(:topic_id)) + else + if tag = Tag.find_by(name: tag_param) + relation.where(topic_id: TopicTag.where(tag_id: tag.id).select(:topic_id)) + else + relation.where("1 = 0") + end + end + end + + register_filter(/\Akeywords?:(.*)\z/i) do |relation, keywords_param, _| + if keywords_param.blank? + relation + else + keywords = keywords_param.split(",").map(&:strip).reject(&:blank?) + if keywords.empty? + relation + else + # Build a ts_query string joined by | (OR) + ts_query = keywords.map { |kw| kw.gsub(/['\\]/, " ") }.join(" | ") + relation = + relation.joins("JOIN post_search_data ON post_search_data.post_id = posts.id") + relation.where( + "post_search_data.search_data @@ to_tsquery(?, ?)", + ::Search.ts_config, + ts_query, + ) + end + end + end + + register_filter(/\A(?:categories?|category):(.*)\z/i) do |relation, category_param, _| + if category_param.include?(",") + category_names = category_param.split(",").map(&:strip) + + found_category_ids = [] + category_names.each do |name| + category = Category.find_by(slug: name) || Category.find_by(name: name) + found_category_ids << category.id if category + end + + return relation.where("1 = 0") if found_category_ids.empty? 
+ relation.where(topic_id: Topic.where(category_id: found_category_ids).select(:id)) + else + if category = + Category.find_by(slug: category_param) || Category.find_by(name: category_param) + relation.where(topic_id: Topic.where(category_id: category.id).select(:id)) + else + relation.where("1 = 0") + end + end + end + + register_filter(/\A\@(\w+)\z/i) do |relation, username, filter| + user = User.find_by(username_lower: username.downcase) + if user + relation.where("posts.user_id = ?", user.id) + else + relation.where("1 = 0") # No results if user doesn't exist + end + end + + register_filter(/\Ain:posted\z/i) do |relation, _, filter| + if filter.guardian.user + relation.where("posts.user_id = ?", filter.guardian.user.id) + else + relation.where("1 = 0") # No results if not logged in + end + end + + register_filter(/\Agroup:([a-zA-Z0-9_\-]+)\z/i) do |relation, name, filter| + group = Group.find_by("name ILIKE ?", name) + if group + relation.where( + "posts.user_id IN ( + SELECT gu.user_id FROM group_users gu + WHERE gu.group_id = ? 
+ )", + group.id, + ) + else + relation.where("1 = 0") # No results if group doesn't exist + end + end + + register_filter(/\Amax_results:(\d+)\z/i) do |relation, limit_str, filter| + filter.limit_by_user!(limit_str.to_i) + relation + end + + register_filter(/\Aorder:latest\z/i) do |relation, order_str, filter| + filter.set_order!(:latest_post) + relation + end + + register_filter(/\Aorder:oldest\z/i) do |relation, order_str, filter| + filter.set_order!(:oldest_post) + relation + end + + register_filter(/\Aorder:latest_topic\z/i) do |relation, order_str, filter| + filter.set_order!(:latest_topic) + relation + end + + register_filter(/\Aorder:oldest_topic\z/i) do |relation, order_str, filter| + filter.set_order!(:oldest_topic) + relation + end + + def initialize(term, guardian: nil, limit: nil, offset: nil) + @term = term.to_s + @guardian = guardian || Guardian.new + @limit = limit + @offset = offset + @filters = [] + @valid = true + @order = :latest_post + + @term = process_filters(@term) + end + + def set_order!(order) + @order = order + end + + def limit_by_user!(limit) + @limit = limit if limit.to_i < @limit.to_i || @limit.nil? + end + + def search + filtered = Post.secured(@guardian).joins(:topic).merge(Topic.secured(@guardian)) + + @filters.each do |filter_block, match_data| + filtered = filter_block.call(filtered, match_data, self) + end + + filtered = filtered.limit(@limit) if @limit.to_i > 0 + filtered = filtered.offset(@offset) if @offset.to_i > 0 + + if @order == :latest_post + filtered = filtered.order("posts.created_at DESC") + elsif @order == :oldest_post + filtered = filtered.order("posts.created_at ASC") + elsif @order == :latest_topic + filtered = filtered.order("topics.created_at DESC, posts.post_number DESC") + elsif @order == :oldest_topic + filtered = filtered.order("topics.created_at ASC, posts.post_number ASC") + end + + filtered + end + + private + + def process_filters(term) + return "" if term.blank? 
+ + term + .to_s + .scan(/(([^" \t\n\x0B\f\r]+)?(("[^"]+")?))/) + .to_a + .map do |(word, _)| + next if word.blank? + + found = false + self.class.registered_filters.each do |matcher, block| + if word =~ matcher + @filters << [block, $1] + found = true + break + end + end + + found ? nil : word + end + .compact + .join(" ") + end + end + end + end +end diff --git a/lib/utils/research/llm_formatter.rb b/lib/utils/research/llm_formatter.rb new file mode 100644 index 00000000..a762f2dd --- /dev/null +++ b/lib/utils/research/llm_formatter.rb @@ -0,0 +1,205 @@ +# frozen_string_literal: true + +module DiscourseAi + module Utils + module Research + class LlmFormatter + def initialize(filter, max_tokens_per_batch:, tokenizer:, max_tokens_per_post:) + @filter = filter + @max_tokens_per_batch = max_tokens_per_batch + @tokenizer = tokenizer + @max_tokens_per_post = max_tokens_per_post + @to_process = filter_to_hash + end + + def each_chunk + return nil if @to_process.empty? + + result = { post_count: 0, topic_count: 0, text: +"" } + estimated_tokens = 0 + + @to_process.each do |topic_id, topic_data| + topic = Topic.find_by(id: topic_id) + next unless topic + + topic_text, topic_tokens, post_count = format_topic(topic, topic_data[:posts]) + + # If this single topic exceeds our token limit and we haven't added anything yet, + # we need to include at least this one topic (partial content) + if estimated_tokens == 0 && topic_tokens > @max_tokens_per_batch + offset = 0 + while offset < topic_text.length + chunk = +"" + chunk_tokens = 0 + lines = topic_text[offset..].lines + lines.each do |line| + line_tokens = estimate_tokens(line) + break if chunk_tokens + line_tokens > @max_tokens_per_batch + chunk << line + chunk_tokens += line_tokens + end + break if chunk.empty? 
+ yield( + { + text: chunk, + post_count: post_count, # This may overcount if split mid-topic, but preserves original logic + topic_count: 1, + } + ) + offset += chunk.length + end + + next + end + + # If adding this topic would exceed the token limit and we already have content, flush the batch first + if estimated_tokens > 0 && estimated_tokens + topic_tokens > @max_tokens_per_batch + yield result if result[:text].present? + estimated_tokens = 0 + result = { post_count: 0, topic_count: 0, text: +"" } + end + # Add this topic to the (possibly fresh) batch + result[:text] << topic_text + result[:post_count] += post_count + result[:topic_count] += 1 + estimated_tokens += topic_tokens + end + yield result if result[:text].present? + + @to_process.clear + end + + private + + def filter_to_hash + hash = {} + @filter + .search + .pluck(:topic_id, :id, :post_number) + .each do |topic_id, post_id, post_number| + hash[topic_id] ||= { posts: [] } + hash[topic_id][:posts] << [post_id, post_number] + end + + hash.each_value { |topic| topic[:posts].sort_by!
{ |_, post_number| post_number } } + hash + end + + def format_topic(topic, posts_data) + text = "" + total_tokens = 0 + post_count = 0 + + # Add topic header + text += format_topic_header(topic) + total_tokens += estimate_tokens(text) + + # Get all post numbers in this topic + all_post_numbers = topic.posts.pluck(:post_number).sort + + # Format posts with omitted information + first_post_number = posts_data.first[1] + last_post_number = posts_data.last[1] + + # Handle posts before our selection + if first_post_number > 1 + omitted_before = first_post_number - 1 + text += format_omitted_posts(omitted_before, "before") + total_tokens += estimate_tokens(format_omitted_posts(omitted_before, "before")) + end + + # Format each post + posts_data.each do |post_id, post_number| + post = Post.find_by(id: post_id) + next unless post + + text += format_post(post) + total_tokens += estimate_tokens(format_post(post)) + post_count += 1 + end + + # Handle posts after our selection + if last_post_number < all_post_numbers.last + omitted_after = all_post_numbers.last - last_post_number + text += format_omitted_posts(omitted_after, "after") + total_tokens += estimate_tokens(format_omitted_posts(omitted_after, "after")) + end + + [text, total_tokens, post_count] + end + + def format_topic_header(topic) + header = +"# #{topic.title}\n" + + # Add category + header << "Category: #{topic.category.name}\n" if topic.category + + # Add tags + header << "Tags: #{topic.tags.map(&:name).join(", ")}\n" if topic.tags.present? + + # Add creation date + header << "Created: #{format_date(topic.created_at)}\n" + header << "Topic url: /t/#{topic.id}\n" + header << "Status: #{format_topic_status(topic)}\n\n" + + header + end + + def format_topic_status(topic) + solved = topic.respond_to?(:solved) && topic.solved.present? + solved_text = solved ? " (solved)" : "" + if topic.archived? + "Archived#{solved_text}" + elsif topic.closed? 
+ "Closed#{solved_text}" + else + "Open#{solved_text}" + end + end + + def format_post(post) + text = +"---\n" + text << "## Post by #{post.user&.username} - #{format_date(post.created_at)}\n\n" + text << "#{truncate_if_needed(post.raw)}\n" + text << "Likes: #{post.like_count}\n" if post.like_count.to_i > 0 + text << "Post url: /t/-/#{post.topic_id}/#{post.post_number}\n\n" + text + end + + def truncate_if_needed(content) + tokens_count = estimate_tokens(content) + + return content if tokens_count <= @max_tokens_per_post + + half_limit = @max_tokens_per_post / 2 + token_ids = @tokenizer.encode(content) + + first_half_ids = token_ids[0...half_limit] + last_half_ids = token_ids[-half_limit..-1] + + first_text = @tokenizer.decode(first_half_ids) + last_text = @tokenizer.decode(last_half_ids) + + "#{first_text}\n\n... elided #{tokens_count - @max_tokens_per_post} tokens ...\n\n#{last_text}" + end + + def format_omitted_posts(count, position) + if position == "before" + "#{count} earlier #{count == 1 ? "post" : "posts"} omitted\n\n" + else + "#{count} later #{count == 1 ? 
"post" : "posts"} omitted\n\n" + end + end + + def format_date(date) + date.strftime("%Y-%m-%d %H:%M") + end + + def estimate_tokens(text) + @tokenizer.tokenize(text).length + end + end + end + end +end diff --git a/spec/lib/completions/cancel_manager_spec.rb b/spec/lib/completions/cancel_manager_spec.rb new file mode 100644 index 00000000..57961826 --- /dev/null +++ b/spec/lib/completions/cancel_manager_spec.rb @@ -0,0 +1,106 @@ +# frozen_string_literal: true + +describe DiscourseAi::Completions::CancelManager do + fab!(:model) { Fabricate(:anthropic_model, name: "test-model") } + + it "can stop monitoring for cancellation cleanly" do + cancel_manager = DiscourseAi::Completions::CancelManager.new + cancel_manager.start_monitor(delay: 100) { false } + expect(cancel_manager.monitor_thread).not_to be_nil + cancel_manager.stop_monitor + expect(cancel_manager.cancelled?).to eq(false) + expect(cancel_manager.monitor_thread).to be_nil + end + + it "can monitor for cancellation" do + cancel_manager = DiscourseAi::Completions::CancelManager.new + results = [true, false, false] + + cancel_manager.start_monitor(delay: 0) { results.pop } + + wait_for { cancel_manager.cancelled? == true } + wait_for { cancel_manager.monitor_thread.nil? } + + expect(cancel_manager.cancelled?).to eq(true) + expect(cancel_manager.monitor_thread).to be_nil + end + + it "should do nothing when cancel manager is already cancelled" do + cancel_manager = DiscourseAi::Completions::CancelManager.new + cancel_manager.cancel! 
+ + llm = model.to_llm + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a test bot", + messages: [{ type: :user, content: "hello" }], + ) + + result = llm.generate(prompt, user: Discourse.system_user, cancel_manager: cancel_manager) + expect(result).to be_nil + end + + it "should be able to cancel a completion" do + # Start an HTTP server that hangs indefinitely + server = TCPServer.new("127.0.0.1", 0) + port = server.addr[1] + + begin + thread = + Thread.new do + loop do + begin + _client = server.accept + sleep(30) # Hold the connection longer than the test will run + break + rescue StandardError + # Server closed + break + end + end + end + + # Create a model that points to our hanging server + model.update!(url: "http://127.0.0.1:#{port}") + + cancel_manager = DiscourseAi::Completions::CancelManager.new + + completion_thread = + Thread.new do + llm = model.to_llm + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a test bot", + messages: [{ type: :user, content: "hello" }], + ) + + result = llm.generate(prompt, user: Discourse.system_user, cancel_manager: cancel_manager) + expect(result).to be_nil + expect(cancel_manager.cancelled).to eq(true) + end + + wait_for { cancel_manager.callbacks.size == 1 } + + cancel_manager.cancel! 
+ completion_thread.join(2) + + expect(completion_thread).not_to be_alive + ensure + begin + server.close + rescue StandardError + nil + end + begin + thread.kill + rescue StandardError + nil + end + begin + completion_thread&.kill + rescue StandardError + nil + end + end + end +end diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb index 130c735b..8ad95b95 100644 --- a/spec/lib/completions/endpoints/endpoint_compliance.rb +++ b/spec/lib/completions/endpoints/endpoint_compliance.rb @@ -188,9 +188,11 @@ class EndpointsCompliance mock.stub_streamed_simple_call(dialect.translate) do completion_response = +"" - endpoint.perform_completion!(dialect, user) do |partial, cancel| + cancel_manager = DiscourseAi::Completions::CancelManager.new + + endpoint.perform_completion!(dialect, user, cancel_manager: cancel_manager) do |partial| completion_response << partial - cancel.call if completion_response.split(" ").length == 2 + cancel_manager.cancel! 
if completion_response.split(" ").length == 2 end expect(AiApiAuditLog.count).to eq(1) @@ -212,12 +214,14 @@ class EndpointsCompliance prompt = generic_prompt(tools: [mock.tool]) a_dialect = dialect(prompt: prompt) + cancel_manager = DiscourseAi::Completions::CancelManager.new + mock.stub_streamed_tool_call(a_dialect.translate) do buffered_partial = [] - endpoint.perform_completion!(a_dialect, user) do |partial, cancel| + endpoint.perform_completion!(a_dialect, user, cancel_manager: cancel_manager) do |partial| buffered_partial << partial - cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall) + cancel_manager.cancel! if partial.is_a?(DiscourseAi::Completions::ToolCall) end expect(buffered_partial).to eq([mock.invocation_response]) diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 4d21c94d..e9ad0ac3 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -1136,14 +1136,13 @@ RSpec.describe DiscourseAi::AiBot::Playground do split = body.split("|") + cancel_manager = DiscourseAi::Completions::CancelManager.new + count = 0 DiscourseAi::AiBot::PostStreamer.on_callback = proc do |callback| count += 1 - if count == 2 - last_post = third_post.topic.posts.order(:id).last - Discourse.redis.del("gpt_cancel:#{last_post.id}") - end + cancel_manager.cancel!
if count == 2 raise "this should not happen" if count > 2 end @@ -1155,13 +1154,13 @@ RSpec.describe DiscourseAi::AiBot::Playground do ) # we are going to need to use real data here cause we want to trigger the # base endpoint to cancel part way through - playground.reply_to(third_post) + playground.reply_to(third_post, cancel_manager: cancel_manager) end last_post = third_post.topic.posts.order(:id).last - # not Hello123, we cancelled at 1 which means we may get 2 and then be done - expect(last_post.raw).to eq("Hello12") + # not Hello123, we cancelled at 1 + expect(last_post.raw).to eq("Hello1") end end diff --git a/spec/lib/personas/persona_spec.rb b/spec/lib/personas/persona_spec.rb index e3c96e34..fa81876e 100644 --- a/spec/lib/personas/persona_spec.rb +++ b/spec/lib/personas/persona_spec.rb @@ -218,32 +218,28 @@ RSpec.describe DiscourseAi::Personas::Persona do SiteSetting.ai_google_custom_search_cx = "abc123" # should be ordered by priority and then alpha - expect(DiscourseAi::Personas::Persona.all(user: user).map(&:superclass)).to eq( - [ - DiscourseAi::Personas::General, - DiscourseAi::Personas::Artist, - DiscourseAi::Personas::Creative, - DiscourseAi::Personas::DiscourseHelper, - DiscourseAi::Personas::GithubHelper, - DiscourseAi::Personas::Researcher, - DiscourseAi::Personas::SettingsExplorer, - DiscourseAi::Personas::SqlHelper, - ], + expect(DiscourseAi::Personas::Persona.all(user: user).map(&:superclass)).to contain_exactly( + DiscourseAi::Personas::General, + DiscourseAi::Personas::Artist, + DiscourseAi::Personas::Creative, + DiscourseAi::Personas::DiscourseHelper, + DiscourseAi::Personas::GithubHelper, + DiscourseAi::Personas::Researcher, + DiscourseAi::Personas::SettingsExplorer, + DiscourseAi::Personas::SqlHelper, ) # it should allow staff access to WebArtifactCreator - expect(DiscourseAi::Personas::Persona.all(user: admin).map(&:superclass)).to eq( - [ - DiscourseAi::Personas::General, - DiscourseAi::Personas::Artist, - 
DiscourseAi::Personas::Creative, - DiscourseAi::Personas::DiscourseHelper, - DiscourseAi::Personas::GithubHelper, - DiscourseAi::Personas::Researcher, - DiscourseAi::Personas::SettingsExplorer, - DiscourseAi::Personas::SqlHelper, - DiscourseAi::Personas::WebArtifactCreator, - ], + expect(DiscourseAi::Personas::Persona.all(user: admin).map(&:superclass)).to contain_exactly( + DiscourseAi::Personas::General, + DiscourseAi::Personas::Artist, + DiscourseAi::Personas::Creative, + DiscourseAi::Personas::DiscourseHelper, + DiscourseAi::Personas::GithubHelper, + DiscourseAi::Personas::Researcher, + DiscourseAi::Personas::SettingsExplorer, + DiscourseAi::Personas::SqlHelper, + DiscourseAi::Personas::WebArtifactCreator, ) # omits personas if key is missing diff --git a/spec/lib/personas/tools/researcher_spec.rb b/spec/lib/personas/tools/researcher_spec.rb new file mode 100644 index 00000000..51227001 --- /dev/null +++ b/spec/lib/personas/tools/researcher_spec.rb @@ -0,0 +1,109 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Personas::Tools::Researcher do + before { SearchIndexer.enable } + after { SearchIndexer.disable } + + fab!(:llm_model) + let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) } + let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") } + let(:progress_blk) { Proc.new {} } + + fab!(:admin) + fab!(:user) + fab!(:category) { Fabricate(:category, name: "research-category") } + fab!(:tag_research) { Fabricate(:tag, name: "research") } + fab!(:tag_data) { Fabricate(:tag, name: "data") } + + fab!(:topic_with_tags) { Fabricate(:topic, category: category, tags: [tag_research, tag_data]) } + fab!(:post) { Fabricate(:post, topic: topic_with_tags) } + + before { SiteSetting.ai_bot_enabled = true } + + describe "#invoke" do + it "returns filter information and result count" do + researcher = + described_class.new( + { filter: "tag:research after:2023", goals: "analyze post patterns", dry_run: true }, 
+ bot_user: bot_user, + llm: llm, + context: DiscourseAi::Personas::BotContext.new(user: user, post: post), + ) + + results = researcher.invoke(&progress_blk) + + expect(results[:filter]).to eq("tag:research after:2023") + expect(results[:goals]).to eq("analyze post patterns") + expect(results[:dry_run]).to eq(true) + expect(results[:number_of_results]).to be > 0 + expect(researcher.filter).to eq("tag:research after:2023") + expect(researcher.result_count).to be > 0 + end + + it "handles empty filters" do + researcher = + described_class.new({ goals: "analyze all content" }, bot_user: bot_user, llm: llm) + + results = researcher.invoke(&progress_blk) + + expect(results[:error]).to eq("No filter provided") + end + + it "accepts max_results option" do + researcher = + described_class.new( + { filter: "category:research-category" }, + persona_options: { + "max_results" => "50", + }, + bot_user: bot_user, + llm: llm, + ) + + expect(researcher.options[:max_results]).to eq(50) + end + + it "returns correct results for non-dry-run with filtered posts" do + # Stage 2 topics, each with 2 posts + topics = Array.new(2) { Fabricate(:topic, category: category, tags: [tag_research]) } + topics.flat_map do |topic| + [ + Fabricate(:post, topic: topic, raw: "Relevant content 1", user: user), + Fabricate(:post, topic: topic, raw: "Relevant content 2", user: admin), + ] + end + + # Filter to posts by user in research-category + researcher = + described_class.new( + { + filter: "category:research-category @#{user.username}", + goals: "find relevant content", + dry_run: false, + }, + bot_user: bot_user, + llm: llm, + context: DiscourseAi::Personas::BotContext.new(user: user, post: post), + ) + + responses = 10.times.map { |i| ["Found: Relevant content #{i + 1}"] } + results = nil + + last_progress = nil + progress_blk = Proc.new { |response| last_progress = response } + + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do + researcher.llm = llm_model.to_llm + results = 
researcher.invoke(&progress_blk) + end + + expect(last_progress).to include("find relevant content") + expect(last_progress).to include("category:research-category") + + expect(results[:dry_run]).to eq(false) + expect(results[:goals]).to eq("find relevant content") + expect(results[:filter]).to eq("category:research-category @#{user.username}") + expect(results[:results].first).to include("Found: Relevant content 1") + end + end +end diff --git a/spec/lib/utils/research/filter_spec.rb b/spec/lib/utils/research/filter_spec.rb new file mode 100644 index 00000000..655d1705 --- /dev/null +++ b/spec/lib/utils/research/filter_spec.rb @@ -0,0 +1,142 @@ +# frozen_string_literal: true + +describe DiscourseAi::Utils::Research::Filter do + describe "integration tests" do + before_all { SiteSetting.min_topic_title_length = 3 } + + fab!(:user) + + fab!(:feature_tag) { Fabricate(:tag, name: "feature") } + fab!(:bug_tag) { Fabricate(:tag, name: "bug") } + + fab!(:announcement_category) { Fabricate(:category, name: "Announcements") } + fab!(:feedback_category) { Fabricate(:category, name: "Feedback") } + + fab!(:feature_topic) do + Fabricate( + :topic, + user: user, + tags: [feature_tag], + category: announcement_category, + title: "New Feature Discussion", + ) + end + + fab!(:bug_topic) do + Fabricate( + :topic, + tags: [bug_tag], + user: user, + category: announcement_category, + title: "Bug Report", + ) + end + + fab!(:feature_bug_topic) do + Fabricate( + :topic, + tags: [feature_tag, bug_tag], + user: user, + category: feedback_category, + title: "Feature with Bug", + ) + end + + fab!(:no_tag_topic) do + Fabricate(:topic, user: user, category: feedback_category, title: "General Discussion") + end + + fab!(:feature_post) { Fabricate(:post, topic: feature_topic, user: user) } + fab!(:bug_post) { Fabricate(:post, topic: bug_topic, user: user) } + fab!(:feature_bug_post) { Fabricate(:post, topic: feature_bug_topic, user: user) } + fab!(:no_tag_post) { Fabricate(:post, topic: 
no_tag_topic, user: user) } + + describe "tag filtering" do + it "correctly filters posts by tags" do + filter = described_class.new("tag:feature") + expect(filter.search.pluck(:id)).to contain_exactly(feature_post.id, feature_bug_post.id) + + filter = described_class.new("tag:feature,bug") + expect(filter.search.pluck(:id)).to contain_exactly( + feature_bug_post.id, + bug_post.id, + feature_post.id, + ) + + filter = described_class.new("tags:bug") + expect(filter.search.pluck(:id)).to contain_exactly(bug_post.id, feature_bug_post.id) + + filter = described_class.new("tag:nonexistent") + expect(filter.search.count).to eq(0) + end + end + + describe "category filtering" do + it "correctly filters posts by categories" do + filter = described_class.new("category:Announcements") + expect(filter.search.pluck(:id)).to contain_exactly(feature_post.id, bug_post.id) + + filter = described_class.new("category:Announcements,Feedback") + expect(filter.search.pluck(:id)).to contain_exactly( + feature_post.id, + bug_post.id, + feature_bug_post.id, + no_tag_post.id, + ) + + filter = described_class.new("categories:Feedback") + expect(filter.search.pluck(:id)).to contain_exactly(feature_bug_post.id, no_tag_post.id) + + filter = described_class.new("category:Feedback tag:feature") + expect(filter.search.pluck(:id)).to contain_exactly(feature_bug_post.id) + end + end + + it "can limit number of results" do + filter = described_class.new("category:Feedback max_results:1", limit: 5) + expect(filter.search.pluck(:id).length).to eq(1) + end + + describe "full text keyword searching" do + before_all { SearchIndexer.enable } + fab!(:post_with_apples) do + Fabricate(:post, raw: "This post contains apples", topic: feature_topic, user: user) + end + + fab!(:post_with_bananas) do + Fabricate(:post, raw: "This post mentions bananas", topic: bug_topic, user: user) + end + + fab!(:post_with_both) do + Fabricate( + :post, + raw: "This post has apples and bananas", + topic: feature_bug_topic, + 
user: user, + ) + end + + fab!(:post_with_none) do + Fabricate(:post, raw: "No fruits here", topic: no_tag_topic, user: user) + end + + it "correctly filters posts by full text keywords" do + filter = described_class.new("keywords:apples") + expect(filter.search.pluck(:id)).to contain_exactly(post_with_apples.id, post_with_both.id) + + filter = described_class.new("keywords:bananas") + expect(filter.search.pluck(:id)).to contain_exactly(post_with_bananas.id, post_with_both.id) + + filter = described_class.new("keywords:apples,bananas") + expect(filter.search.pluck(:id)).to contain_exactly( + post_with_apples.id, + post_with_bananas.id, + post_with_both.id, + ) + + filter = described_class.new("keywords:oranges") + expect(filter.search.count).to eq(0) + end + end + end +end diff --git a/spec/lib/utils/research/llm_formatter_spec.rb b/spec/lib/utils/research/llm_formatter_spec.rb new file mode 100644 index 00000000..edc10363 --- /dev/null +++ b/spec/lib/utils/research/llm_formatter_spec.rb @@ -0,0 +1,74 @@ +# frozen_string_literal: true +# +describe DiscourseAi::Utils::Research::LlmFormatter do + fab!(:user) { Fabricate(:user, username: "test_user") } + fab!(:topic) { Fabricate(:topic, title: "This is a Test Topic", user: user) } + fab!(:post) { Fabricate(:post, topic: topic, user: user) } + let(:tokenizer) { DiscourseAi::Tokenizer::OpenAiTokenizer } + let(:filter) { DiscourseAi::Utils::Research::Filter.new("@#{user.username}") } + + describe "#truncate_if_needed" do + it "returns original content when under token limit" do + formatter = + described_class.new( + filter, + max_tokens_per_batch: 1000, + tokenizer: tokenizer, + max_tokens_per_post: 100, + ) + + short_text = "This is a short post" + expect(formatter.send(:truncate_if_needed, short_text)).to eq(short_text) + end + + it "truncates content when over token limit" do + # Create a post with content that will exceed our token limit + long_text = ("word " * 200).strip + + formatter = + described_class.new( + 
filter, + max_tokens_per_batch: 1000, + tokenizer: tokenizer, + max_tokens_per_post: 50, + ) + + truncated = formatter.send(:truncate_if_needed, long_text) + + expect(truncated).to include("... elided 150 tokens ...") + expect(truncated).to_not eq(long_text) + + # Should have roughly 25 words before and 25 after (half of max_tokens_per_post) + first_chunk = truncated.split("\n\n")[0] + expect(first_chunk.split(" ").length).to be_within(5).of(25) + + last_chunk = truncated.split("\n\n")[2] + expect(last_chunk.split(" ").length).to be_within(5).of(25) + end + end + + describe "#format_post" do + it "formats posts with truncation for long content" do + # Set up a post with long content + long_content = ("word " * 200).strip + long_post = Fabricate(:post, raw: long_content, topic: topic, user: user) + + formatter = + described_class.new( + filter, + max_tokens_per_batch: 1000, + tokenizer: tokenizer, + max_tokens_per_post: 50, + ) + + formatted = formatter.send(:format_post, long_post) + + # Should have standard formatting elements + expect(formatted).to include("## Post by #{user.username}") + expect(formatted).to include("Post url: /t/-/#{long_post.topic_id}/#{long_post.post_number}") + + # Should include truncation marker + expect(formatted).to include("... elided 150 tokens ...") + end + end +end