diff --git a/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js b/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js index b06c525c..d5dea6f4 100644 --- a/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js +++ b/assets/javascripts/discourse/lib/ai-streamer/progress-handlers.js @@ -2,7 +2,7 @@ import { later } from "@ember/runloop"; import PostUpdater from "./updaters/post-updater"; const PROGRESS_INTERVAL = 40; -const GIVE_UP_INTERVAL = 60000; +const GIVE_UP_INTERVAL = 600000; // 10 minutes which is our max thinking time for now export const MIN_LETTERS_PER_INTERVAL = 6; const MAX_FLUSH_TIME = 800; diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 00a71cd3..0875eba9 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -296,6 +296,9 @@ en: designer: name: Designer description: "AI Bot specialized in generating and editing images" + forum_researcher: + name: Forum Researcher + description: "AI Bot specialized in deep research for the forum" sql_helper: name: SQL Helper description: "AI Bot specialized in helping craft SQL queries on this Discourse instance" @@ -303,8 +306,8 @@ en: name: Settings Explorer description: "AI Bot specialized in helping explore Discourse site settings" researcher: - name: Researcher - description: "AI Bot with Google access that can research information for you" + name: Web Researcher + description: "AI Bot with Google access that can both search and read web pages" creative: name: Creative description: "AI Bot with no external integrations specialized in creative tasks" @@ -327,6 +330,16 @@ en: summarizing: "Summarizing topic" searching: "Searching for: '%{query}'" tool_options: + researcher: + max_results: + name: "Maximum number of results" + description: "Maximum number of results to include in a filter" + include_private: + name: "Include private" + description: "Include private topics in the filters" + max_tokens_per_post: + name: 
"Maximum tokens per post" + description: "Maximum number of tokens to use for each post in the filter" create_artifact: creator_llm: name: "LLM" @@ -385,6 +398,7 @@ en: javascript_evaluator: "Evaluate JavaScript" create_image: "Creating image" edit_image: "Editing image" + researcher: "Researching" tool_help: read_artifact: "Read a web artifact using the AI Bot" update_artifact: "Update a web artifact using the AI Bot" @@ -411,6 +425,7 @@ en: dall_e: "Generate image using DALL-E 3" search_meta_discourse: "Search Meta Discourse" javascript_evaluator: "Evaluate JavaScript" + researcher: "Research forum information using the AI Bot" tool_description: read_artifact: "Read a web artifact using the AI Bot" update_artifact: "Updated a web artifact using the AI Bot" @@ -445,6 +460,12 @@ en: other: "Found %{count} results for '%{query}'" setting_context: "Reading context for: %{setting_name}" schema: "%{tables}" + researcher_dry_run: + one: "Proposed research: %{goals}\n\nFound %{count} result for '%{filter}'" + other: "Proposed research: %{goals}\n\nFound %{count} results for '%{filter}'" + researcher: + one: "Researching: %{goals}\n\nFound %{count} result for '%{filter}'" + other: "Researching: %{goals}\n\nFound %{count} results for '%{filter}'" search_settings: one: "Found %{count} result for '%{query}'" other: "Found %{count} results for '%{query}'" diff --git a/db/fixtures/personas/603_ai_personas.rb b/db/fixtures/personas/603_ai_personas.rb index 27e6d479..c2c121de 100644 --- a/db/fixtures/personas/603_ai_personas.rb +++ b/db/fixtures/personas/603_ai_personas.rb @@ -33,7 +33,7 @@ DiscourseAi::Personas::Persona.system_personas.each do |persona_class, id| persona.allowed_group_ids = [Group::AUTO_GROUPS[:trust_level_0]] end - persona.enabled = !summarization_personas.include?(persona_class) + persona.enabled = persona_class.default_enabled persona.priority = true if persona_class == DiscourseAi::Personas::General end diff --git a/lib/ai_bot/chat_streamer.rb 
b/lib/ai_bot/chat_streamer.rb index 139a6c7f..06357e0e 100644 --- a/lib/ai_bot/chat_streamer.rb +++ b/lib/ai_bot/chat_streamer.rb @@ -6,16 +6,23 @@ module DiscourseAi module AiBot class ChatStreamer - attr_accessor :cancel attr_reader :reply, :guardian, :thread_id, :force_thread, :in_reply_to_id, :channel, - :cancelled + :cancel_manager - def initialize(message:, channel:, guardian:, thread_id:, in_reply_to_id:, force_thread:) + def initialize( + message:, + channel:, + guardian:, + thread_id:, + in_reply_to_id:, + force_thread:, + cancel_manager: nil + ) @message = message @channel = channel @guardian = guardian @@ -35,6 +42,8 @@ module DiscourseAi guardian: guardian, thread_id: thread_id, ) + + @cancel_manager = cancel_manager end def <<(partial) @@ -111,8 +120,7 @@ module DiscourseAi streaming = ChatSDK::Message.stream(message_id: reply.id, raw: buffer, guardian: guardian) if !streaming - cancel.call - @cancelled = true + @cancel_manager.cancel! if @cancel_manager end end end diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 63aa8967..07c4984b 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -331,6 +331,7 @@ module DiscourseAi ), user: message.user, skip_tool_details: true, + cancel_manager: DiscourseAi::Completions::CancelManager.new, ) reply = nil @@ -347,15 +348,14 @@ module DiscourseAi thread_id: message.thread_id, in_reply_to_id: in_reply_to_id, force_thread: force_thread, + cancel_manager: context.cancel_manager, ) new_prompts = - bot.reply(context) do |partial, cancel, placeholder, type| + bot.reply(context) do |partial, placeholder, type| # no support for tools or thinking by design next if type == :thinking || type == :tool_details || type == :partial_tool - streamer.cancel = cancel streamer << partial - break if streamer.cancelled end reply = streamer.reply @@ -383,6 +383,7 @@ module DiscourseAi auto_set_title: true, silent_mode: false, feature_name: nil, + cancel_manager: nil, &blk ) # this is a 
multithreading issue @@ -471,16 +472,26 @@ module DiscourseAi redis_stream_key = "gpt_cancel:#{reply_post.id}" Discourse.redis.setex(redis_stream_key, MAX_STREAM_DELAY_SECONDS, 1) + + cancel_manager ||= DiscourseAi::Completions::CancelManager.new + context.cancel_manager = cancel_manager + context + .cancel_manager + .start_monitor(delay: 0.2) do + context.cancel_manager.cancel! if !Discourse.redis.get(redis_stream_key) + end + + context.cancel_manager.add_callback( + lambda { reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) }, + ) end context.skip_tool_details ||= !bot.persona.class.tool_details - post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply - started_thinking = false new_custom_prompts = - bot.reply(context) do |partial, cancel, placeholder, type| + bot.reply(context) do |partial, placeholder, type| if type == :thinking && !started_thinking reply << "
#{I18n.t("discourse_ai.ai_bot.thinking")}" started_thinking = true @@ -499,15 +510,6 @@ module DiscourseAi blk.call(partial) end - if stream_reply && !Discourse.redis.get(redis_stream_key) - cancel&.call - reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) - # we do not break out, cause if we do - # we will not get results from bot - # leading to broken context - # we need to trust it to cancel at the endpoint - end - if post_streamer post_streamer.run_later do Discourse.redis.expire(redis_stream_key, MAX_STREAM_DELAY_SECONDS) @@ -568,6 +570,8 @@ module DiscourseAi end raise e ensure + context.cancel_manager.stop_monitor if context&.cancel_manager + # since we are skipping validations and jobs we # may need to fix participant count if reply_post && reply_post.topic && reply_post.topic.private_message? && @@ -649,7 +653,7 @@ module DiscourseAi payload, user_ids: bot_reply_post.topic.allowed_user_ids, max_backlog_size: 2, - max_backlog_age: 60, + max_backlog_age: MAX_STREAM_DELAY_SECONDS, ) end end diff --git a/lib/completions/cancel_manager.rb b/lib/completions/cancel_manager.rb new file mode 100644 index 00000000..78c3ee5b --- /dev/null +++ b/lib/completions/cancel_manager.rb @@ -0,0 +1,109 @@ +# frozen_string_literal: true + +# special object that can be used to cancel completions and http requests +module DiscourseAi + module Completions + class CancelManager + attr_reader :cancelled + attr_reader :callbacks + + def initialize + @cancelled = false + @callbacks = Concurrent::Array.new + @mutex = Mutex.new + @monitor_thread = nil + end + + def monitor_thread + @mutex.synchronize { @monitor_thread } + end + + def start_monitor(delay: 0.5, &block) + @mutex.synchronize do + raise "Already monitoring" if @monitor_thread + raise "Expected a block" if !block + + db = RailsMultisite::ConnectionManagement.current_db + @stop_monitor = false + + @monitor_thread = + Thread.new do + begin + loop do + done = false + @mutex.synchronize { done = true if @stop_monitor 
} + break if done + sleep delay + @mutex.synchronize { done = true if @stop_monitor } + @mutex.synchronize { done = true if cancelled? } + break if done + + should_cancel = false + RailsMultisite::ConnectionManagement.with_connection(db) do + should_cancel = block.call + end + + @mutex.synchronize { cancel! if should_cancel } + + break if cancelled? + end + ensure + @mutex.synchronize { @monitor_thread = nil } + end + end + end + end + + def stop_monitor + monitor_thread = nil + + @mutex.synchronize { monitor_thread = @monitor_thread } + + if monitor_thread + @mutex.synchronize { @stop_monitor = true } + # so we do not deadlock + monitor_thread.wakeup + monitor_thread.join(2) + # should not happen + if monitor_thread.alive? + Rails.logger.warn("DiscourseAI: CancelManager monitor thread did not stop in time") + monitor_thread.kill if monitor_thread.alive? + end + @monitor_thread = nil + end + end + + def cancelled? + @cancelled + end + + def add_callback(cb) + @callbacks << cb + end + + def remove_callback(cb) + @callbacks.delete(cb) + end + + def cancel! + @cancelled = true + monitor_thread = @monitor_thread + if monitor_thread && monitor_thread != Thread.current + monitor_thread.wakeup + monitor_thread.join(2) + if monitor_thread.alive? + Rails.logger.warn("DiscourseAI: CancelManager monitor thread did not stop in time") + monitor_thread.kill if monitor_thread.alive? 
+ end + end + @callbacks.each do |cb| + begin + cb.call + rescue StandardError + # ignore cause this may have already been cancelled + end + end + end + end + end +end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 3a12c4c7..0356597e 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -68,11 +68,17 @@ module DiscourseAi feature_context: nil, partial_tool_calls: false, output_thinking: false, + cancel_manager: nil, &blk ) LlmQuota.check_quotas!(@llm_model, user) start_time = Time.now + if cancel_manager && cancel_manager.cancelled? + # nothing to do + return + end + @forced_json_through_prefill = false @partial_tool_calls = partial_tool_calls @output_thinking = output_thinking @@ -90,15 +96,14 @@ module DiscourseAi feature_context: feature_context, partial_tool_calls: partial_tool_calls, output_thinking: output_thinking, + cancel_manager: cancel_manager, ) wrapped = result wrapped = [result] if !result.is_a?(Array) - cancelled_by_caller = false - cancel_proc = -> { cancelled_by_caller = true } wrapped.each do |partial| - blk.call(partial, cancel_proc) - break if cancelled_by_caller + blk.call(partial) + break cancel_manager&.cancelled? 
end return result end @@ -118,6 +123,9 @@ module DiscourseAi end end + cancel_manager_callback = nil + cancelled = false + FinalDestination::HTTP.start( model_uri.host, model_uri.port, @@ -126,6 +134,14 @@ module DiscourseAi open_timeout: TIMEOUT, write_timeout: TIMEOUT, ) do |http| + if cancel_manager + cancel_manager_callback = + lambda do + cancelled = true + http.finish + end + cancel_manager.add_callback(cancel_manager_callback) + end response_data = +"" response_raw = +"" @@ -158,7 +174,7 @@ module DiscourseAi if @streaming_mode blk = - lambda do |partial, cancel| + lambda do |partial| if partial.is_a?(String) partial = xml_stripper << partial if xml_stripper @@ -167,7 +183,7 @@ module DiscourseAi partial = structured_output end end - orig_blk.call(partial, cancel) if partial + orig_blk.call(partial) if partial end end @@ -196,14 +212,6 @@ module DiscourseAi end begin - cancelled = false - cancel = -> do - cancelled = true - http.finish - end - - break if cancelled - response.read_body do |chunk| break if cancelled @@ -216,16 +224,11 @@ module DiscourseAi partials = [partial] if xml_tool_processor && partial.is_a?(String) partials = (xml_tool_processor << partial) - if xml_tool_processor.should_cancel? - cancel.call - break - end + break if xml_tool_processor.should_cancel? 
end - partials.each { |inner_partial| blk.call(inner_partial, cancel) } + partials.each { |inner_partial| blk.call(inner_partial) } end end - rescue IOError, StandardError - raise if !cancelled end if xml_stripper stripped = xml_stripper.finish @@ -233,13 +236,11 @@ module DiscourseAi response_data << stripped result = [] result = (xml_tool_processor << stripped) if xml_tool_processor - result.each { |partial| blk.call(partial, cancel) } + result.each { |partial| blk.call(partial) } end end - if xml_tool_processor - xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) } - end - decode_chunk_finish.each { |partial| blk.call(partial, cancel) } + xml_tool_processor.finish.each { |partial| blk.call(partial) } if xml_tool_processor + decode_chunk_finish.each { |partial| blk.call(partial) } return response_data ensure if log @@ -293,6 +294,12 @@ module DiscourseAi end end end + rescue IOError, StandardError + raise if !cancelled + ensure + if cancel_manager && cancel_manager_callback + cancel_manager.remove_callback(cancel_manager_callback) + end end def final_log_update(log) diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index be156aea..b31b2bb4 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -30,7 +30,8 @@ module DiscourseAi feature_name: nil, feature_context: nil, partial_tool_calls: false, - output_thinking: false + output_thinking: false, + cancel_manager: nil ) @dialect = dialect @model_params = model_params diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index d307ada5..71d0ee4b 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -122,7 +122,8 @@ module DiscourseAi feature_name: nil, feature_context: nil, partial_tool_calls: false, - output_thinking: false + output_thinking: false, + cancel_manager: nil ) last_call = { dialect: dialect, user: user, 
model_params: model_params } self.class.last_call = last_call diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index f97a337d..070d391c 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -46,6 +46,7 @@ module DiscourseAi feature_context: nil, partial_tool_calls: false, output_thinking: false, + cancel_manager: nil, &blk ) @disable_native_tools = dialect.disable_native_tools? diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index f90de091..04d517cb 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -307,7 +307,7 @@ module DiscourseAi # @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output. # @param [Experimental] extra_model_params { Hash - Optional } - Other params that are not available accross models. e.g. response_format JSON schema. # - # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. + # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response. # # @returns String | ToolCall - Completion result. # if multiple tools or a tool and a message come back, the result will be an array of ToolCall / String objects. @@ -325,6 +325,7 @@ module DiscourseAi output_thinking: false, response_format: nil, extra_model_params: nil, + cancel_manager: nil, &partial_read_blk ) self.class.record_prompt( @@ -378,6 +379,7 @@ module DiscourseAi feature_context: feature_context, partial_tool_calls: partial_tool_calls, output_thinking: output_thinking, + cancel_manager: cancel_manager, &partial_read_blk ) end diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb index 4d302ea1..2864846a 100644 --- a/lib/completions/prompt_messages_builder.rb +++ b/lib/completions/prompt_messages_builder.rb @@ -247,6 +247,10 @@ module DiscourseAi # 3. 
ensures we always interleave user and model messages last_type = nil messages.each do |message| + if message[:type] == :model && !message[:content] + message[:content] = "Reply cancelled by user." + end + next if !last_type && message[:type] != :user if last_type == :tool_call && message[:type] != :tool diff --git a/lib/discord/bot/persona_replier.rb b/lib/discord/bot/persona_replier.rb index 51e03475..b64af15c 100644 --- a/lib/discord/bot/persona_replier.rb +++ b/lib/discord/bot/persona_replier.rb @@ -24,7 +24,7 @@ module DiscourseAi full_reply = @bot.reply( { conversation_context: [{ type: :user, content: @query }], skip_tool_details: true }, - ) do |partial, _cancel, _something| + ) do |partial, _something| reply << partial next if reply.blank? diff --git a/lib/inference/open_ai_image_generator.rb b/lib/inference/open_ai_image_generator.rb index f4ff6754..83501989 100644 --- a/lib/inference/open_ai_image_generator.rb +++ b/lib/inference/open_ai_image_generator.rb @@ -21,7 +21,8 @@ module ::DiscourseAi moderation: "low", output_compression: nil, output_format: nil, - title: nil + title: nil, + cancel_manager: nil ) # Get the API responses in parallel threads api_responses = @@ -38,6 +39,7 @@ module ::DiscourseAi moderation: moderation, output_compression: output_compression, output_format: output_format, + cancel_manager: cancel_manager, ) raise api_responses[0] if api_responses.all? 
{ |resp| resp.is_a?(StandardError) } @@ -58,7 +60,8 @@ module ::DiscourseAi user_id:, for_private_message: false, n: 1, - quality: nil + quality: nil, + cancel_manager: nil ) api_response = edit_images( @@ -70,6 +73,7 @@ module ::DiscourseAi api_url: api_url, n: n, quality: quality, + cancel_manager: cancel_manager, ) create_uploads_from_responses([api_response], user_id, for_private_message).first @@ -124,7 +128,8 @@ module ::DiscourseAi background:, moderation:, output_compression:, - output_format: + output_format:, + cancel_manager: ) prompts = [prompts] unless prompts.is_a?(Array) prompts = prompts.take(4) # Limit to 4 prompts max @@ -152,18 +157,21 @@ module ::DiscourseAi moderation: moderation, output_compression: output_compression, output_format: output_format, + cancel_manager: cancel_manager, ) rescue => e attempts += 1 # to keep tests speedy - if !Rails.env.test? + if !Rails.env.test? && !cancel_manager&.cancelled? retry if attempts < 3 end - Discourse.warn_exception( - e, - message: "Failed to generate image for prompt #{prompt}\n", - ) - puts "Error generating image for prompt: #{prompt} #{e}" if Rails.env.development? + if !cancel_manager&.cancelled? + Discourse.warn_exception( + e, + message: "Failed to generate image for prompt #{prompt}\n", + ) + puts "Error generating image for prompt: #{prompt} #{e}" if Rails.env.development? + end e end end @@ -181,7 +189,8 @@ module ::DiscourseAi api_key: nil, api_url: nil, n: 1, - quality: nil + quality: nil, + cancel_manager: nil ) images = [images] if !images.is_a?(Array) @@ -209,8 +218,10 @@ module ::DiscourseAi api_url: api_url, n: n, quality: quality, + cancel_manager: cancel_manager, ) rescue => e + raise e if cancel_manager&.cancelled? attempts += 1 if !Rails.env.test? 
sleep 2 @@ -238,7 +249,8 @@ module ::DiscourseAi background: nil, moderation: nil, output_compression: nil, - output_format: nil + output_format: nil, + cancel_manager: nil ) api_key ||= SiteSetting.ai_openai_api_key api_url ||= SiteSetting.ai_openai_image_generation_url @@ -276,6 +288,7 @@ module ::DiscourseAi # Store original prompt for upload metadata original_prompt = prompt + cancel_manager_callback = nil FinalDestination::HTTP.start( uri.host, @@ -288,6 +301,11 @@ module ::DiscourseAi request = Net::HTTP::Post.new(uri, headers) request.body = payload.to_json + if cancel_manager + cancel_manager_callback = lambda { http.finish } + cancel_manager.add_callback(cancel_manager_callback) + end + json = nil http.request(request) do |response| if response.code.to_i != 200 @@ -300,6 +318,10 @@ module ::DiscourseAi end json end + ensure + if cancel_manager && cancel_manager_callback + cancel_manager.remove_callback(cancel_manager_callback) + end end def self.perform_edit_api_call!( @@ -310,7 +332,8 @@ module ::DiscourseAi api_key:, api_url:, n: 1, - quality: nil + quality: nil, + cancel_manager: nil ) uri = URI(api_url) @@ -403,6 +426,7 @@ module ::DiscourseAi # Store original prompt for upload metadata original_prompt = prompt + cancel_manager_callback = nil FinalDestination::HTTP.start( uri.host, @@ -415,6 +439,11 @@ module ::DiscourseAi request = Net::HTTP::Post.new(uri.path, headers) request.body = body.join + if cancel_manager + cancel_manager_callback = lambda { http.finish } + cancel_manager.add_callback(cancel_manager_callback) + end + json = nil http.request(request) do |response| if response.code.to_i != 200 @@ -428,6 +457,9 @@ module ::DiscourseAi json end ensure + if cancel_manager && cancel_manager_callback + cancel_manager.remove_callback(cancel_manager_callback) + end if files_to_delete.present? 
files_to_delete.each { |file| File.delete(file) if File.exist?(file) } end diff --git a/lib/personas/artifact_update_strategies/base.rb b/lib/personas/artifact_update_strategies/base.rb index 68e0b761..624830f1 100644 --- a/lib/personas/artifact_update_strategies/base.rb +++ b/lib/personas/artifact_update_strategies/base.rb @@ -5,15 +5,24 @@ module DiscourseAi class InvalidFormatError < StandardError end class Base - attr_reader :post, :user, :artifact, :artifact_version, :instructions, :llm + attr_reader :post, :user, :artifact, :artifact_version, :instructions, :llm, :cancel_manager - def initialize(llm:, post:, user:, artifact:, artifact_version:, instructions:) + def initialize( + llm:, + post:, + user:, + artifact:, + artifact_version:, + instructions:, + cancel_manager: nil + ) @llm = llm @post = post @user = user @artifact = artifact @artifact_version = artifact_version @instructions = instructions + @cancel_manager = cancel_manager end def apply(&progress) @@ -26,7 +35,7 @@ module DiscourseAi def generate_changes(&progress) response = +"" - llm.generate(build_prompt, user: user) do |partial| + llm.generate(build_prompt, user: user, cancel_manager: cancel_manager) do |partial| progress.call(partial) if progress response << partial end diff --git a/lib/personas/bot.rb b/lib/personas/bot.rb index 2c9b5a3b..3686edac 100644 --- a/lib/personas/bot.rb +++ b/lib/personas/bot.rb @@ -55,6 +55,7 @@ module DiscourseAi unless context.is_a?(BotContext) raise ArgumentError, "context must be an instance of BotContext" end + context.cancel_manager ||= DiscourseAi::Completions::CancelManager.new current_llm = llm prompt = persona.craft_prompt(context, llm: current_llm) @@ -91,8 +92,9 @@ module DiscourseAi feature_name: context.feature_name, partial_tool_calls: allow_partial_tool_calls, output_thinking: true, + cancel_manager: context.cancel_manager, **llm_kwargs, - ) do |partial, cancel| + ) do |partial| tool = persona.find_tool( partial, @@ -109,7 +111,7 @@ module 
DiscourseAi if tool_call.partial? if tool.class.allow_partial_tool_calls? tool.partial_invoke - update_blk.call("", cancel, tool.custom_raw, :partial_tool) + update_blk.call("", tool.custom_raw, :partial_tool) end next end @@ -117,7 +119,7 @@ module DiscourseAi tool_found = true # a bit hacky, but extra newlines do no harm if needs_newlines - update_blk.call("\n\n", cancel) + update_blk.call("\n\n") needs_newlines = false end @@ -125,7 +127,6 @@ module DiscourseAi tool: tool, raw_context: raw_context, current_llm: current_llm, - cancel: cancel, update_blk: update_blk, prompt: prompt, context: context, @@ -144,7 +145,7 @@ module DiscourseAi else if partial.is_a?(DiscourseAi::Completions::Thinking) if partial.partial? && partial.message.present? - update_blk.call(partial.message, cancel, nil, :thinking) + update_blk.call(partial.message, nil, :thinking) end if !partial.partial? # this will be dealt with later @@ -152,9 +153,9 @@ module DiscourseAi current_thinking << partial end elsif partial.is_a?(DiscourseAi::Completions::StructuredOutput) - update_blk.call(partial, cancel, nil, :structured_output) + update_blk.call(partial, nil, :structured_output) else - update_blk.call(partial, cancel) + update_blk.call(partial) end end end @@ -215,14 +216,13 @@ module DiscourseAi tool:, raw_context:, current_llm:, - cancel:, update_blk:, prompt:, context:, current_thinking: ) tool_call_id = tool.tool_call_id - invocation_result_json = invoke_tool(tool, cancel, context, &update_blk).to_json + invocation_result_json = invoke_tool(tool, context, &update_blk).to_json tool_call_message = { type: :tool_call, @@ -256,27 +256,27 @@ module DiscourseAi raw_context << [invocation_result_json, tool_call_id, "tool", tool.name] end - def invoke_tool(tool, cancel, context, &update_blk) + def invoke_tool(tool, context, &update_blk) show_placeholder = !context.skip_tool_details && !tool.class.allow_partial_tool_calls? 
- update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder + update_blk.call("", build_placeholder(tool.summary, "")) if show_placeholder result = tool.invoke do |progress, render_raw| if render_raw - update_blk.call("", cancel, tool.custom_raw, :partial_invoke) + update_blk.call("", tool.custom_raw, :partial_invoke) show_placeholder = false elsif show_placeholder placeholder = build_placeholder(tool.summary, progress) - update_blk.call("", cancel, placeholder) + update_blk.call("", placeholder) end end if show_placeholder tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw) - update_blk.call(tool_details, cancel, nil, :tool_details) + update_blk.call(tool_details, nil, :tool_details) elsif tool.custom_raw.present? - update_blk.call(tool.custom_raw, cancel, nil, :custom_raw) + update_blk.call(tool.custom_raw, nil, :custom_raw) end result diff --git a/lib/personas/bot_context.rb b/lib/personas/bot_context.rb index 5f7dd99e..69d86669 100644 --- a/lib/personas/bot_context.rb +++ b/lib/personas/bot_context.rb @@ -16,7 +16,8 @@ module DiscourseAi :channel_id, :context_post_ids, :feature_name, - :resource_url + :resource_url, + :cancel_manager def initialize( post: nil, @@ -33,7 +34,8 @@ module DiscourseAi channel_id: nil, context_post_ids: nil, feature_name: "bot", - resource_url: nil + resource_url: nil, + cancel_manager: nil ) @participants = participants @user = user @@ -54,6 +56,8 @@ module DiscourseAi @feature_name = feature_name @resource_url = resource_url + @cancel_manager = cancel_manager + if post @post_id = post.id @topic_id = post.topic_id diff --git a/lib/personas/forum_researcher.rb b/lib/personas/forum_researcher.rb new file mode 100644 index 00000000..eb0f82cb --- /dev/null +++ b/lib/personas/forum_researcher.rb @@ -0,0 +1,52 @@ +#frozen_string_literal: true + +module DiscourseAi + module Personas + class ForumResearcher < Persona + def self.default_enabled + false + end + + def tools + 
[Tools::Researcher] + end + + def system_prompt + <<~PROMPT + You are a helpful Discourse assistant specializing in forum research. + You _understand_ and **generate** Discourse Markdown. + + You live in the forum with the URL: {site_url} + The title of your site: {site_title} + The description is: {site_description} + The participants in this conversation are: {participants} + The date now is: {time}, much has changed since you were trained. + + As a forum researcher, guide users through a structured research process: + 1. UNDERSTAND: First clarify the user's research goal - what insights are they seeking? + 2. PLAN: Design an appropriate research approach with specific filters + 3. TEST: Always begin with dry_run:true to gauge the scope of results + 4. REFINE: If results are too broad/narrow, suggest filter adjustments + 5. EXECUTE: Run the final analysis only when filters are well-tuned + 6. SUMMARIZE: Present findings with links to supporting evidence + + BE MINDFUL: specify all research goals in one request to avoid multiple processing runs. + + REMEMBER: Different filters serve different purposes: + - Use post date filters (after/before) for analyzing specific posts + - Use topic date filters (topic_after/topic_before) for analyzing entire topics + - Combine user/group filters with categories/tags to find specialized contributions + + Always ground your analysis with links to original posts on the forum. + + Research workflow best practices: + 1. Start with a dry_run to gauge the scope (set dry_run:true) + 2. If results are too numerous (>1000), add more specific filters + 3. If results are too few (<5), broaden your filters + 4. For temporal analysis, specify explicit date ranges + 5. 
For user behavior analysis, combine @username with categories or tags + PROMPT + end + end + end +end diff --git a/lib/personas/persona.rb b/lib/personas/persona.rb index a4443dcc..62426f77 100644 --- a/lib/personas/persona.rb +++ b/lib/personas/persona.rb @@ -4,6 +4,10 @@ module DiscourseAi module Personas class Persona class << self + def default_enabled + true + end + def rag_conversation_chunks 10 end @@ -47,6 +51,7 @@ module DiscourseAi Summarizer => -11, ShortSummarizer => -12, Designer => -13, + ForumResearcher => -14, } end @@ -99,6 +104,7 @@ module DiscourseAi Tools::GithubSearchFiles, Tools::WebBrowser, Tools::JavascriptEvaluator, + Tools::Researcher, ] if SiteSetting.ai_artifact_security.in?(%w[lax strict]) diff --git a/lib/personas/short_summarizer.rb b/lib/personas/short_summarizer.rb index e7cef54a..5b9b5195 100644 --- a/lib/personas/short_summarizer.rb +++ b/lib/personas/short_summarizer.rb @@ -3,6 +3,10 @@ module DiscourseAi module Personas class ShortSummarizer < Persona + def self.default_enabled + false + end + def system_prompt <<~PROMPT.strip You are an advanced summarization bot. Analyze a given conversation and produce a concise, @@ -23,7 +27,7 @@ module DiscourseAi {"summary": "xx"} - + Where "xx" is replaced by the summary. PROMPT end diff --git a/lib/personas/summarizer.rb b/lib/personas/summarizer.rb index c1eefe89..b8b0a95a 100644 --- a/lib/personas/summarizer.rb +++ b/lib/personas/summarizer.rb @@ -3,6 +3,10 @@ module DiscourseAi module Personas class Summarizer < Persona + def self.default_enabled + false + end + def system_prompt <<~PROMPT.strip You are an advanced summarization bot that generates concise, coherent summaries of provided text. 
@@ -18,13 +22,13 @@ module DiscourseAi - Example: link to the 6th post by jane: [agreed with]({resource_url}/6) - Example: link to the 13th post by joe: [joe]({resource_url}/13) - When formatting usernames either use @USERNAME OR [USERNAME]({resource_url}/POST_NUMBER) - + Format your response as a JSON object with a single key named "summary", which has the summary as the value. Your output should be in the following format: {"summary": "xx"} - + Where "xx" is replaced by the summary. PROMPT end diff --git a/lib/personas/tools/create_artifact.rb b/lib/personas/tools/create_artifact.rb index c12577ff..548fb9aa 100644 --- a/lib/personas/tools/create_artifact.rb +++ b/lib/personas/tools/create_artifact.rb @@ -151,7 +151,12 @@ module DiscourseAi LlmModel.find_by(id: options[:creator_llm].to_i)&.to_llm ) || self.llm - llm.generate(prompt, user: user, feature_name: "create_artifact") do |partial_response| + llm.generate( + prompt, + user: user, + feature_name: "create_artifact", + cancel_manager: context.cancel_manager, + ) do |partial_response| response << partial_response yield partial_response end diff --git a/lib/personas/tools/create_image.rb b/lib/personas/tools/create_image.rb index f57c83ff..8e2971fa 100644 --- a/lib/personas/tools/create_image.rb +++ b/lib/personas/tools/create_image.rb @@ -48,6 +48,7 @@ module DiscourseAi max_prompts, model: "gpt-image-1", user_id: bot_user.id, + cancel_manager: context.cancel_manager, ) rescue => e @error = e diff --git a/lib/personas/tools/edit_image.rb b/lib/personas/tools/edit_image.rb index b28afd0c..b9e3249a 100644 --- a/lib/personas/tools/edit_image.rb +++ b/lib/personas/tools/edit_image.rb @@ -60,6 +60,7 @@ module DiscourseAi uploads, prompt, user_id: bot_user.id, + cancel_manager: context.cancel_manager, ) rescue => e @error = e diff --git a/lib/personas/tools/researcher.rb b/lib/personas/tools/researcher.rb new file mode 100644 index 00000000..3ab4c8e0 --- /dev/null +++ b/lib/personas/tools/researcher.rb @@ -0,0 
+1,181 @@ +# frozen_string_literal: true + +module DiscourseAi + module Personas + module Tools + class Researcher < Tool + attr_reader :filter, :result_count, :goals, :dry_run + + class << self + def signature + { + name: name, + description: + "Analyze and extract information from content across the forum based on specified filters", + parameters: [ + { name: "filter", description: filter_description, type: "string" }, + { + name: "goals", + description: + "The specific information you want to extract or analyze from the filtered content, you may specify multiple goals", + type: "string", + }, + { + name: "dry_run", + description: "When true, only count matching items without processing data", + type: "boolean", + }, + ], + } + end + + def filter_description + <<~TEXT + Filter string to target specific content. + - Supports user (@username) + - date ranges (after:YYYY-MM-DD, before:YYYY-MM-DD for posts; topic_after:YYYY-MM-DD, topic_before:YYYY-MM-DD for topics) + - categories (category:category1,category2) + - tags (tag:tag1,tag2) + - groups (group:group1,group2). + - status (status:open, status:closed, status:archived, status:noreplies, status:single_user) + - keywords (keywords:keyword1,keyword2) - specific words to search for in posts + - max_results (max_results:10) the maximum number of results to return (optional) + - order (order:latest, order:oldest, order:latest_topic, order:oldest_topic) - the order of the results (optional) + + If multiple tags or categories are specified, they are treated as OR conditions. + + Multiple filters can be combined with spaces. 
Example: '@sam after:2023-01-01 tag:feature' + TEXT + end + + def name + "researcher" + end + + def accepted_options + [ + option(:max_results, type: :integer), + option(:include_private, type: :boolean), + option(:max_tokens_per_post, type: :integer), + ] + end + end + + def invoke(&blk) + max_results = options[:max_results] || 1000 + + @filter = parameters[:filter] || "" + @goals = parameters[:goals] || "" + @dry_run = parameters[:dry_run].nil? ? false : parameters[:dry_run] + + post = Post.find_by(id: context.post_id) + goals = parameters[:goals] || "" + dry_run = parameters[:dry_run].nil? ? false : parameters[:dry_run] + + return { error: "No goals provided" } if goals.blank? + return { error: "No filter provided" } if @filter.blank? + + guardian = nil + guardian = Guardian.new(context.user) if options[:include_private] + + filter = + DiscourseAi::Utils::Research::Filter.new( + @filter, + limit: max_results, + guardian: guardian, + ) + @result_count = filter.search.count + + blk.call details + + if dry_run + { dry_run: true, goals: goals, filter: @filter, number_of_results: @result_count } + else + process_filter(filter, goals, post, &blk) + end + end + + def details + if @dry_run + I18n.t("discourse_ai.ai_bot.tool_description.researcher_dry_run", description_args) + else + I18n.t("discourse_ai.ai_bot.tool_description.researcher", description_args) + end + end + + def description_args + { count: @result_count || 0, filter: @filter || "", goals: @goals || "" } + end + + protected + + MIN_TOKENS_FOR_RESEARCH = 8000 + def process_filter(filter, goals, post, &blk) + if llm.max_prompt_tokens < MIN_TOKENS_FOR_RESEARCH + raise ArgumentError, + "LLM max tokens too low for research. Minimum is #{MIN_TOKENS_FOR_RESEARCH}." 
+ end + formatter = + DiscourseAi::Utils::Research::LlmFormatter.new( + filter, + max_tokens_per_batch: llm.max_prompt_tokens - 2000, + tokenizer: llm.tokenizer, + max_tokens_per_post: options[:max_tokens_per_post] || 2000, + ) + + results = [] + + formatter.each_chunk { |chunk| results << run_inference(chunk[:text], goals, post, &blk) } + { dry_run: false, goals: goals, filter: @filter, results: results } + end + + def run_inference(chunk_text, goals, post, &blk) + system_prompt = goal_system_prompt(goals) + user_prompt = goal_user_prompt(goals, chunk_text) + + prompt = + DiscourseAi::Completions::Prompt.new( + system_prompt, + messages: [{ type: :user, content: user_prompt }], + post_id: post.id, + topic_id: post.topic_id, + ) + + results = [] + llm.generate( + prompt, + user: post.user, + feature_name: context.feature_name, + cancel_manager: context.cancel_manager, + ) { |partial| results << partial } + + @progress_dots ||= 0 + @progress_dots += 1 + blk.call(details + "\n\n#{"." * @progress_dots}") + results.join + end + + def goal_system_prompt(goals) + <<~TEXT + You are a researcher tool designed to analyze and extract information from forum content. + Your task is to process the provided content and extract relevant information based on the specified goal. 
+ + Your goal is: #{goals} + TEXT + end + + def goal_user_prompt(goals, chunk_text) + <<~TEXT + Here is the content to analyze: + + {{{ + #{chunk_text} + }}} + + Your goal is: #{goals} + TEXT + end + end + end + end +end diff --git a/lib/personas/tools/tool.rb b/lib/personas/tools/tool.rb index 4bafe0af..540204c2 100644 --- a/lib/personas/tools/tool.rb +++ b/lib/personas/tools/tool.rb @@ -47,8 +47,9 @@ module DiscourseAi end end - attr_accessor :custom_raw, :parameters - attr_reader :tool_call_id, :persona_options, :bot_user, :llm, :context + # llm being public makes it a bit easier to test + attr_accessor :custom_raw, :parameters, :llm + attr_reader :tool_call_id, :persona_options, :bot_user, :context def initialize( parameters, diff --git a/lib/personas/tools/update_artifact.rb b/lib/personas/tools/update_artifact.rb index ccf39a8d..46038d7a 100644 --- a/lib/personas/tools/update_artifact.rb +++ b/lib/personas/tools/update_artifact.rb @@ -159,6 +159,7 @@ module DiscourseAi artifact: artifact, artifact_version: artifact_version, instructions: instructions, + cancel_manager: context.cancel_manager, ) .apply do |progress| partial_response << progress diff --git a/lib/summarization/fold_content.rb b/lib/summarization/fold_content.rb index 15800087..8dbcda38 100644 --- a/lib/summarization/fold_content.rb +++ b/lib/summarization/fold_content.rb @@ -18,7 +18,7 @@ module DiscourseAi attr_reader :bot, :strategy # @param user { User } - User object used for auditing usage. - # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. + # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response. # Note: The block is only called with results of the final summary, not intermediate summaries. # # This method doesn't care if we already have an up to date summary. It always regenerate. 
@@ -77,7 +77,7 @@ module DiscourseAi # @param items { Array } - Content to summarize. Structure will be: { poster: who wrote the content, id: a way to order content, text: content } # @param user { User } - User object used for auditing usage. - # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. + # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response. # Note: The block is only called with results of the final summary, not intermediate summaries. # # The summarization algorithm. @@ -112,7 +112,7 @@ module DiscourseAi summary = +"" buffer_blk = - Proc.new do |partial, cancel, _, type| + Proc.new do |partial, _, type| if type == :structured_output json_summary_schema_key = bot.persona.response_format&.first.to_h partial_summary = @@ -120,12 +120,12 @@ module DiscourseAi if partial_summary.present? summary << partial_summary - on_partial_blk.call(partial_summary, cancel) if on_partial_blk + on_partial_blk.call(partial_summary) if on_partial_blk end elsif type.blank? # Assume response is a regular completion. 
summary << partial - on_partial_blk.call(partial, cancel) if on_partial_blk + on_partial_blk.call(partial) if on_partial_blk end end diff --git a/lib/utils/research/filter.rb b/lib/utils/research/filter.rb new file mode 100644 index 00000000..734943e7 --- /dev/null +++ b/lib/utils/research/filter.rb @@ -0,0 +1,263 @@ +# frozen_string_literal: true + +module DiscourseAi + module Utils + module Research + class Filter + # Stores custom filter handlers + def self.register_filter(matcher, &block) + (@registered_filters ||= {})[matcher] = block + end + + def self.registered_filters + @registered_filters ||= {} + end + + def self.word_to_date(str) + ::Search.word_to_date(str) + end + + attr_reader :term, :filters, :order, :guardian, :limit, :offset + + # Define all filters at class level + register_filter(/\Astatus:open\z/i) do |relation, _, _| + relation.where("topics.closed = false AND topics.archived = false") + end + + register_filter(/\Astatus:closed\z/i) do |relation, _, _| + relation.where("topics.closed = true") + end + + register_filter(/\Astatus:archived\z/i) do |relation, _, _| + relation.where("topics.archived = true") + end + + register_filter(/\Astatus:noreplies\z/i) do |relation, _, _| + relation.where("topics.posts_count = 1") + end + + register_filter(/\Astatus:single_user\z/i) do |relation, _, _| + relation.where("topics.participant_count = 1") + end + + # Date filters + register_filter(/\Abefore:(.*)\z/i) do |relation, date_str, _| + if date = Filter.word_to_date(date_str) + relation.where("posts.created_at < ?", date) + else + relation + end + end + + register_filter(/\Aafter:(.*)\z/i) do |relation, date_str, _| + if date = Filter.word_to_date(date_str) + relation.where("posts.created_at > ?", date) + else + relation + end + end + + register_filter(/\Atopic_before:(.*)\z/i) do |relation, date_str, _| + if date = Filter.word_to_date(date_str) + relation.where("topics.created_at < ?", date) + else + relation + end + end + + 
register_filter(/\Atopic_after:(.*)\z/i) do |relation, date_str, _|
+          if date = Filter.word_to_date(date_str)
+            relation.where("topics.created_at > ?", date)
+          else
+            relation
+          end
+        end
+
+        register_filter(/\A(?:tags?|tag):(.*)\z/i) do |relation, tag_param, _|
+          if tag_param.include?(",")
+            tag_names = tag_param.split(",").map(&:strip)
+            tag_ids = Tag.where(name: tag_names).pluck(:id)
+            next relation.where("1 = 0") if tag_ids.empty?
+            relation.where(topic_id: TopicTag.where(tag_id: tag_ids).select(:topic_id))
+          else
+            if tag = Tag.find_by(name: tag_param)
+              relation.where(topic_id: TopicTag.where(tag_id: tag.id).select(:topic_id))
+            else
+              relation.where("1 = 0")
+            end
+          end
+        end
+
+        register_filter(/\Akeywords?:(.*)\z/i) do |relation, keywords_param, _|
+          if keywords_param.blank?
+            relation
+          else
+            keywords = keywords_param.split(",").map(&:strip).reject(&:blank?)
+            if keywords.empty?
+              relation
+            else
+              # Build a ts_query string joined by | (OR)
+              ts_query = keywords.map { |kw| kw.gsub(/['\\]/, " ") }.join(" | ")
+              relation =
+                relation.joins("JOIN post_search_data ON post_search_data.post_id = posts.id")
+              relation.where(
+                "post_search_data.search_data @@ to_tsquery(?, ?)",
+                ::Search.ts_config,
+                ts_query,
+              )
+            end
+          end
+        end
+
+        register_filter(/\A(?:categories?|category):(.*)\z/i) do |relation, category_param, _|
+          if category_param.include?(",")
+            category_names = category_param.split(",").map(&:strip)
+
+            found_category_ids = []
+            category_names.each do |name|
+              category = Category.find_by(slug: name) || Category.find_by(name: name)
+              found_category_ids << category.id if category
+            end
+
+            next relation.where("1 = 0") if found_category_ids.empty?
+ relation.where(topic_id: Topic.where(category_id: found_category_ids).select(:id)) + else + if category = + Category.find_by(slug: category_param) || Category.find_by(name: category_param) + relation.where(topic_id: Topic.where(category_id: category.id).select(:id)) + else + relation.where("1 = 0") + end + end + end + + register_filter(/\A\@(\w+)\z/i) do |relation, username, filter| + user = User.find_by(username_lower: username.downcase) + if user + relation.where("posts.user_id = ?", user.id) + else + relation.where("1 = 0") # No results if user doesn't exist + end + end + + register_filter(/\Ain:posted\z/i) do |relation, _, filter| + if filter.guardian.user + relation.where("posts.user_id = ?", filter.guardian.user.id) + else + relation.where("1 = 0") # No results if not logged in + end + end + + register_filter(/\Agroup:([a-zA-Z0-9_\-]+)\z/i) do |relation, name, filter| + group = Group.find_by("name ILIKE ?", name) + if group + relation.where( + "posts.user_id IN ( + SELECT gu.user_id FROM group_users gu + WHERE gu.group_id = ? 
+ )", + group.id, + ) + else + relation.where("1 = 0") # No results if group doesn't exist + end + end + + register_filter(/\Amax_results:(\d+)\z/i) do |relation, limit_str, filter| + filter.limit_by_user!(limit_str.to_i) + relation + end + + register_filter(/\Aorder:latest\z/i) do |relation, order_str, filter| + filter.set_order!(:latest_post) + relation + end + + register_filter(/\Aorder:oldest\z/i) do |relation, order_str, filter| + filter.set_order!(:oldest_post) + relation + end + + register_filter(/\Aorder:latest_topic\z/i) do |relation, order_str, filter| + filter.set_order!(:latest_topic) + relation + end + + register_filter(/\Aorder:oldest_topic\z/i) do |relation, order_str, filter| + filter.set_order!(:oldest_topic) + relation + end + + def initialize(term, guardian: nil, limit: nil, offset: nil) + @term = term.to_s + @guardian = guardian || Guardian.new + @limit = limit + @offset = offset + @filters = [] + @valid = true + @order = :latest_post + + @term = process_filters(@term) + end + + def set_order!(order) + @order = order + end + + def limit_by_user!(limit) + @limit = limit if limit.to_i < @limit.to_i || @limit.nil? + end + + def search + filtered = Post.secured(@guardian).joins(:topic).merge(Topic.secured(@guardian)) + + @filters.each do |filter_block, match_data| + filtered = filter_block.call(filtered, match_data, self) + end + + filtered = filtered.limit(@limit) if @limit.to_i > 0 + filtered = filtered.offset(@offset) if @offset.to_i > 0 + + if @order == :latest_post + filtered = filtered.order("posts.created_at DESC") + elsif @order == :oldest_post + filtered = filtered.order("posts.created_at ASC") + elsif @order == :latest_topic + filtered = filtered.order("topics.created_at DESC, posts.post_number DESC") + elsif @order == :oldest_topic + filtered = filtered.order("topics.created_at ASC, posts.post_number ASC") + end + + filtered + end + + private + + def process_filters(term) + return "" if term.blank? 
+ + term + .to_s + .scan(/(([^" \t\n\x0B\f\r]+)?(("[^"]+")?))/) + .to_a + .map do |(word, _)| + next if word.blank? + + found = false + self.class.registered_filters.each do |matcher, block| + if word =~ matcher + @filters << [block, $1] + found = true + break + end + end + + found ? nil : word + end + .compact + .join(" ") + end + end + end + end +end diff --git a/lib/utils/research/llm_formatter.rb b/lib/utils/research/llm_formatter.rb new file mode 100644 index 00000000..a762f2dd --- /dev/null +++ b/lib/utils/research/llm_formatter.rb @@ -0,0 +1,205 @@ +# frozen_string_literal: true + +module DiscourseAi + module Utils + module Research + class LlmFormatter + def initialize(filter, max_tokens_per_batch:, tokenizer:, max_tokens_per_post:) + @filter = filter + @max_tokens_per_batch = max_tokens_per_batch + @tokenizer = tokenizer + @max_tokens_per_post = max_tokens_per_post + @to_process = filter_to_hash + end + + def each_chunk + return nil if @to_process.empty? + + result = { post_count: 0, topic_count: 0, text: +"" } + estimated_tokens = 0 + + @to_process.each do |topic_id, topic_data| + topic = Topic.find_by(id: topic_id) + next unless topic + + topic_text, topic_tokens, post_count = format_topic(topic, topic_data[:posts]) + + # If this single topic exceeds our token limit and we haven't added anything yet, + # we need to include at least this one topic (partial content) + if estimated_tokens == 0 && topic_tokens > @max_tokens_per_batch + offset = 0 + while offset < topic_text.length + chunk = +"" + chunk_tokens = 0 + lines = topic_text[offset..].lines + lines.each do |line| + line_tokens = estimate_tokens(line) + break if chunk_tokens + line_tokens > @max_tokens_per_batch + chunk << line + chunk_tokens += line_tokens + end + break if chunk.empty? 
+              yield(
+                {
+                  text: chunk,
+                  post_count: post_count, # This may overcount if split mid-topic, but preserves original logic
+                  topic_count: 1,
+                }
+              )
+              offset += chunk.length
+            end
+
+            next
+          end
+
+          # If adding this topic would exceed our token limit and we already have content, flush first
+          if estimated_tokens > 0 && estimated_tokens + topic_tokens > @max_tokens_per_batch
+            yield result if result[:text].present?
+            estimated_tokens = 0
+            result = { post_count: 0, topic_count: 0, text: +"" }
+          end
+
+          # Add this topic to the result (including the topic that triggered the flush)
+          result[:text] << topic_text
+          result[:post_count] += post_count
+          result[:topic_count] += 1
+          estimated_tokens += topic_tokens
+        end
+        yield result if result[:text].present?
+
+        @to_process.clear
+      end
+
+      private
+
+      def filter_to_hash
+        hash = {}
+        @filter
+          .search
+          .pluck(:topic_id, :id, :post_number)
+          .each do |topic_id, post_id, post_number|
+            hash[topic_id] ||= { posts: [] }
+            hash[topic_id][:posts] << [post_id, post_number]
+          end
+
+        hash.each_value { |topic| topic[:posts].sort_by!
{ |_, post_number| post_number } } + hash + end + + def format_topic(topic, posts_data) + text = "" + total_tokens = 0 + post_count = 0 + + # Add topic header + text += format_topic_header(topic) + total_tokens += estimate_tokens(text) + + # Get all post numbers in this topic + all_post_numbers = topic.posts.pluck(:post_number).sort + + # Format posts with omitted information + first_post_number = posts_data.first[1] + last_post_number = posts_data.last[1] + + # Handle posts before our selection + if first_post_number > 1 + omitted_before = first_post_number - 1 + text += format_omitted_posts(omitted_before, "before") + total_tokens += estimate_tokens(format_omitted_posts(omitted_before, "before")) + end + + # Format each post + posts_data.each do |post_id, post_number| + post = Post.find_by(id: post_id) + next unless post + + text += format_post(post) + total_tokens += estimate_tokens(format_post(post)) + post_count += 1 + end + + # Handle posts after our selection + if last_post_number < all_post_numbers.last + omitted_after = all_post_numbers.last - last_post_number + text += format_omitted_posts(omitted_after, "after") + total_tokens += estimate_tokens(format_omitted_posts(omitted_after, "after")) + end + + [text, total_tokens, post_count] + end + + def format_topic_header(topic) + header = +"# #{topic.title}\n" + + # Add category + header << "Category: #{topic.category.name}\n" if topic.category + + # Add tags + header << "Tags: #{topic.tags.map(&:name).join(", ")}\n" if topic.tags.present? + + # Add creation date + header << "Created: #{format_date(topic.created_at)}\n" + header << "Topic url: /t/#{topic.id}\n" + header << "Status: #{format_topic_status(topic)}\n\n" + + header + end + + def format_topic_status(topic) + solved = topic.respond_to?(:solved) && topic.solved.present? + solved_text = solved ? " (solved)" : "" + if topic.archived? + "Archived#{solved_text}" + elsif topic.closed? 
+ "Closed#{solved_text}" + else + "Open#{solved_text}" + end + end + + def format_post(post) + text = +"---\n" + text << "## Post by #{post.user&.username} - #{format_date(post.created_at)}\n\n" + text << "#{truncate_if_needed(post.raw)}\n" + text << "Likes: #{post.like_count}\n" if post.like_count.to_i > 0 + text << "Post url: /t/-/#{post.topic_id}/#{post.post_number}\n\n" + text + end + + def truncate_if_needed(content) + tokens_count = estimate_tokens(content) + + return content if tokens_count <= @max_tokens_per_post + + half_limit = @max_tokens_per_post / 2 + token_ids = @tokenizer.encode(content) + + first_half_ids = token_ids[0...half_limit] + last_half_ids = token_ids[-half_limit..-1] + + first_text = @tokenizer.decode(first_half_ids) + last_text = @tokenizer.decode(last_half_ids) + + "#{first_text}\n\n... elided #{tokens_count - @max_tokens_per_post} tokens ...\n\n#{last_text}" + end + + def format_omitted_posts(count, position) + if position == "before" + "#{count} earlier #{count == 1 ? "post" : "posts"} omitted\n\n" + else + "#{count} later #{count == 1 ? 
"post" : "posts"} omitted\n\n" + end + end + + def format_date(date) + date.strftime("%Y-%m-%d %H:%M") + end + + def estimate_tokens(text) + @tokenizer.tokenize(text).length + end + end + end + end +end diff --git a/spec/lib/completions/cancel_manager_spec.rb b/spec/lib/completions/cancel_manager_spec.rb new file mode 100644 index 00000000..57961826 --- /dev/null +++ b/spec/lib/completions/cancel_manager_spec.rb @@ -0,0 +1,106 @@ +# frozen_string_literal: true + +describe DiscourseAi::Completions::CancelManager do + fab!(:model) { Fabricate(:anthropic_model, name: "test-model") } + + it "can stop monitoring for cancellation cleanly" do + cancel_manager = DiscourseAi::Completions::CancelManager.new + cancel_manager.start_monitor(delay: 100) { false } + expect(cancel_manager.monitor_thread).not_to be_nil + cancel_manager.stop_monitor + expect(cancel_manager.cancelled?).to eq(false) + expect(cancel_manager.monitor_thread).to be_nil + end + + it "can monitor for cancellation" do + cancel_manager = DiscourseAi::Completions::CancelManager.new + results = [true, false, false] + + cancel_manager.start_monitor(delay: 0) { results.pop } + + wait_for { cancel_manager.cancelled? == true } + wait_for { cancel_manager.monitor_thread.nil? } + + expect(cancel_manager.cancelled?).to eq(true) + expect(cancel_manager.monitor_thread).to be_nil + end + + it "should do nothing when cancel manager is already cancelled" do + cancel_manager = DiscourseAi::Completions::CancelManager.new + cancel_manager.cancel! 
+ + llm = model.to_llm + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a test bot", + messages: [{ type: :user, content: "hello" }], + ) + + result = llm.generate(prompt, user: Discourse.system_user, cancel_manager: cancel_manager) + expect(result).to be_nil + end + + it "should be able to cancel a completion" do + # Start an HTTP server that hangs indefinitely + server = TCPServer.new("127.0.0.1", 0) + port = server.addr[1] + + begin + thread = + Thread.new do + loop do + begin + _client = server.accept + sleep(30) # Hold the connection longer than the test will run + break + rescue StandardError + # Server closed + break + end + end + end + + # Create a model that points to our hanging server + model.update!(url: "http://127.0.0.1:#{port}") + + cancel_manager = DiscourseAi::Completions::CancelManager.new + + completion_thread = + Thread.new do + llm = model.to_llm + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a test bot", + messages: [{ type: :user, content: "hello" }], + ) + + result = llm.generate(prompt, user: Discourse.system_user, cancel_manager: cancel_manager) + expect(result).to be_nil + expect(cancel_manager.cancelled).to eq(true) + end + + wait_for { cancel_manager.callbacks.size == 1 } + + cancel_manager.cancel! 
+ completion_thread.join(2) + + expect(completion_thread).not_to be_alive + ensure + begin + server.close + rescue StandardError + nil + end + begin + thread.kill + rescue StandardError + nil + end + begin + completion_thread&.kill + rescue StandardError + nil + end + end + end +end diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb index 130c735b..8ad95b95 100644 --- a/spec/lib/completions/endpoints/endpoint_compliance.rb +++ b/spec/lib/completions/endpoints/endpoint_compliance.rb @@ -188,9 +188,11 @@ class EndpointsCompliance mock.stub_streamed_simple_call(dialect.translate) do completion_response = +"" - endpoint.perform_completion!(dialect, user) do |partial, cancel| + cancel_manager = DiscourseAi::Completions::CancelManager.new + + endpoint.perform_completion!(dialect, user, cancel_manager: cancel_manager) do |partial| completion_response << partial - cancel.call if completion_response.split(" ").length == 2 + cancel_manager.cancel! 
if completion_response.split(" ").length == 2
       end
 
       expect(AiApiAuditLog.count).to eq(1)
@@ -212,12 +214,14 @@ class EndpointsCompliance
     prompt = generic_prompt(tools: [mock.tool])
     a_dialect = dialect(prompt: prompt)
 
+    cancel_manager = DiscourseAi::Completions::CancelManager.new
+
     mock.stub_streamed_tool_call(a_dialect.translate) do
       buffered_partial = []
 
-      endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
+      endpoint.perform_completion!(a_dialect, user, cancel_manager: cancel_manager) do |partial|
         buffered_partial << partial
-        cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
+        cancel_manager.cancel! if partial.is_a?(DiscourseAi::Completions::ToolCall)
       end
 
       expect(buffered_partial).to eq([mock.invocation_response])
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 4d21c94d..e9ad0ac3 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -1136,14 +1136,13 @@ RSpec.describe DiscourseAi::AiBot::Playground do
 
         split = body.split("|")
 
+        cancel_manager = DiscourseAi::Completions::CancelManager.new
+
         count = 0
         DiscourseAi::AiBot::PostStreamer.on_callback = proc do |callback|
             count += 1
-            if count == 2
-              last_post = third_post.topic.posts.order(:id).last
-              Discourse.redis.del("gpt_cancel:#{last_post.id}")
-            end
+            cancel_manager.cancel!
if count == 2 raise "this should not happen" if count > 2 end @@ -1155,13 +1154,13 @@ RSpec.describe DiscourseAi::AiBot::Playground do ) # we are going to need to use real data here cause we want to trigger the # base endpoint to cancel part way through - playground.reply_to(third_post) + playground.reply_to(third_post, cancel_manager: cancel_manager) end last_post = third_post.topic.posts.order(:id).last - # not Hello123, we cancelled at 1 which means we may get 2 and then be done - expect(last_post.raw).to eq("Hello12") + # not Hello123, we cancelled at 1 + expect(last_post.raw).to eq("Hello1") end end diff --git a/spec/lib/personas/persona_spec.rb b/spec/lib/personas/persona_spec.rb index e3c96e34..fa81876e 100644 --- a/spec/lib/personas/persona_spec.rb +++ b/spec/lib/personas/persona_spec.rb @@ -218,32 +218,28 @@ RSpec.describe DiscourseAi::Personas::Persona do SiteSetting.ai_google_custom_search_cx = "abc123" # should be ordered by priority and then alpha - expect(DiscourseAi::Personas::Persona.all(user: user).map(&:superclass)).to eq( - [ - DiscourseAi::Personas::General, - DiscourseAi::Personas::Artist, - DiscourseAi::Personas::Creative, - DiscourseAi::Personas::DiscourseHelper, - DiscourseAi::Personas::GithubHelper, - DiscourseAi::Personas::Researcher, - DiscourseAi::Personas::SettingsExplorer, - DiscourseAi::Personas::SqlHelper, - ], + expect(DiscourseAi::Personas::Persona.all(user: user).map(&:superclass)).to contain_exactly( + DiscourseAi::Personas::General, + DiscourseAi::Personas::Artist, + DiscourseAi::Personas::Creative, + DiscourseAi::Personas::DiscourseHelper, + DiscourseAi::Personas::GithubHelper, + DiscourseAi::Personas::Researcher, + DiscourseAi::Personas::SettingsExplorer, + DiscourseAi::Personas::SqlHelper, ) # it should allow staff access to WebArtifactCreator - expect(DiscourseAi::Personas::Persona.all(user: admin).map(&:superclass)).to eq( - [ - DiscourseAi::Personas::General, - DiscourseAi::Personas::Artist, - 
DiscourseAi::Personas::Creative, - DiscourseAi::Personas::DiscourseHelper, - DiscourseAi::Personas::GithubHelper, - DiscourseAi::Personas::Researcher, - DiscourseAi::Personas::SettingsExplorer, - DiscourseAi::Personas::SqlHelper, - DiscourseAi::Personas::WebArtifactCreator, - ], + expect(DiscourseAi::Personas::Persona.all(user: admin).map(&:superclass)).to contain_exactly( + DiscourseAi::Personas::General, + DiscourseAi::Personas::Artist, + DiscourseAi::Personas::Creative, + DiscourseAi::Personas::DiscourseHelper, + DiscourseAi::Personas::GithubHelper, + DiscourseAi::Personas::Researcher, + DiscourseAi::Personas::SettingsExplorer, + DiscourseAi::Personas::SqlHelper, + DiscourseAi::Personas::WebArtifactCreator, ) # omits personas if key is missing diff --git a/spec/lib/personas/tools/researcher_spec.rb b/spec/lib/personas/tools/researcher_spec.rb new file mode 100644 index 00000000..51227001 --- /dev/null +++ b/spec/lib/personas/tools/researcher_spec.rb @@ -0,0 +1,109 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Personas::Tools::Researcher do + before { SearchIndexer.enable } + after { SearchIndexer.disable } + + fab!(:llm_model) + let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) } + let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") } + let(:progress_blk) { Proc.new {} } + + fab!(:admin) + fab!(:user) + fab!(:category) { Fabricate(:category, name: "research-category") } + fab!(:tag_research) { Fabricate(:tag, name: "research") } + fab!(:tag_data) { Fabricate(:tag, name: "data") } + + fab!(:topic_with_tags) { Fabricate(:topic, category: category, tags: [tag_research, tag_data]) } + fab!(:post) { Fabricate(:post, topic: topic_with_tags) } + + before { SiteSetting.ai_bot_enabled = true } + + describe "#invoke" do + it "returns filter information and result count" do + researcher = + described_class.new( + { filter: "tag:research after:2023", goals: "analyze post patterns", dry_run: true }, 
+ bot_user: bot_user, + llm: llm, + context: DiscourseAi::Personas::BotContext.new(user: user, post: post), + ) + + results = researcher.invoke(&progress_blk) + + expect(results[:filter]).to eq("tag:research after:2023") + expect(results[:goals]).to eq("analyze post patterns") + expect(results[:dry_run]).to eq(true) + expect(results[:number_of_results]).to be > 0 + expect(researcher.filter).to eq("tag:research after:2023") + expect(researcher.result_count).to be > 0 + end + + it "handles empty filters" do + researcher = + described_class.new({ goals: "analyze all content" }, bot_user: bot_user, llm: llm) + + results = researcher.invoke(&progress_blk) + + expect(results[:error]).to eq("No filter provided") + end + + it "accepts max_results option" do + researcher = + described_class.new( + { filter: "category:research-category" }, + persona_options: { + "max_results" => "50", + }, + bot_user: bot_user, + llm: llm, + ) + + expect(researcher.options[:max_results]).to eq(50) + end + + it "returns correct results for non-dry-run with filtered posts" do + # Stage 2 topics, each with 2 posts + topics = Array.new(2) { Fabricate(:topic, category: category, tags: [tag_research]) } + topics.flat_map do |topic| + [ + Fabricate(:post, topic: topic, raw: "Relevant content 1", user: user), + Fabricate(:post, topic: topic, raw: "Relevant content 2", user: admin), + ] + end + + # Filter to posts by user in research-category + researcher = + described_class.new( + { + filter: "category:research-category @#{user.username}", + goals: "find relevant content", + dry_run: false, + }, + bot_user: bot_user, + llm: llm, + context: DiscourseAi::Personas::BotContext.new(user: user, post: post), + ) + + responses = 10.times.map { |i| ["Found: Relevant content #{i + 1}"] } + results = nil + + last_progress = nil + progress_blk = Proc.new { |response| last_progress = response } + + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do + researcher.llm = llm_model.to_llm + results = 
researcher.invoke(&progress_blk) + end + + expect(last_progress).to include("find relevant content") + expect(last_progress).to include("category:research-category") + + expect(results[:dry_run]).to eq(false) + expect(results[:goals]).to eq("find relevant content") + expect(results[:filter]).to eq("category:research-category @#{user.username}") + expect(results[:results].first).to include("Found: Relevant content 1") + end + end +end diff --git a/spec/lib/utils/research/filter_spec.rb b/spec/lib/utils/research/filter_spec.rb new file mode 100644 index 00000000..655d1705 --- /dev/null +++ b/spec/lib/utils/research/filter_spec.rb @@ -0,0 +1,142 @@ +# frozen_string_literal: true + +describe DiscourseAi::Utils::Research::Filter do + describe "integration tests" do + before_all { SiteSetting.min_topic_title_length = 3 } + + fab!(:user) + + fab!(:feature_tag) { Fabricate(:tag, name: "feature") } + fab!(:bug_tag) { Fabricate(:tag, name: "bug") } + + fab!(:announcement_category) { Fabricate(:category, name: "Announcements") } + fab!(:feedback_category) { Fabricate(:category, name: "Feedback") } + + fab!(:feature_topic) do + Fabricate( + :topic, + user: user, + tags: [feature_tag], + category: announcement_category, + title: "New Feature Discussion", + ) + end + + fab!(:bug_topic) do + Fabricate( + :topic, + tags: [bug_tag], + user: user, + category: announcement_category, + title: "Bug Report", + ) + end + + fab!(:feature_bug_topic) do + Fabricate( + :topic, + tags: [feature_tag, bug_tag], + user: user, + category: feedback_category, + title: "Feature with Bug", + ) + end + + fab!(:no_tag_topic) do + Fabricate(:topic, user: user, category: feedback_category, title: "General Discussion") + end + + fab!(:feature_post) { Fabricate(:post, topic: feature_topic, user: user) } + fab!(:bug_post) { Fabricate(:post, topic: bug_topic, user: user) } + fab!(:feature_bug_post) { Fabricate(:post, topic: feature_bug_topic, user: user) } + fab!(:no_tag_post) { Fabricate(:post, topic: 
no_tag_topic, user: user) } + + describe "tag filtering" do + it "correctly filters posts by tags" do + filter = described_class.new("tag:feature") + expect(filter.search.pluck(:id)).to contain_exactly(feature_post.id, feature_bug_post.id) + + filter = described_class.new("tag:feature,bug") + expect(filter.search.pluck(:id)).to contain_exactly( + feature_bug_post.id, + bug_post.id, + feature_post.id, + ) + + filter = described_class.new("tags:bug") + expect(filter.search.pluck(:id)).to contain_exactly(bug_post.id, feature_bug_post.id) + + filter = described_class.new("tag:nonexistent") + expect(filter.search.count).to eq(0) + end + end + + describe "category filtering" do + it "correctly filters posts by categories" do + filter = described_class.new("category:Announcements") + expect(filter.search.pluck(:id)).to contain_exactly(feature_post.id, bug_post.id) + + filter = described_class.new("category:Announcements,Feedback") + expect(filter.search.pluck(:id)).to contain_exactly( + feature_post.id, + bug_post.id, + feature_bug_post.id, + no_tag_post.id, + ) + + filter = described_class.new("categories:Feedback") + expect(filter.search.pluck(:id)).to contain_exactly(feature_bug_post.id, no_tag_post.id) + + filter = described_class.new("category:Feedback tag:feature") + expect(filter.search.pluck(:id)).to contain_exactly(feature_bug_post.id) + end + end + + it "can limit number of results" do + filter = described_class.new("category:Feedback max_results:1", limit: 5) + expect(filter.search.pluck(:id).length).to eq(1) + end + + describe "full text keyword searching" do + before_all { SearchIndexer.enable } + fab!(:post_with_apples) do + Fabricate(:post, raw: "This post contains apples", topic: feature_topic, user: user) + end + + fab!(:post_with_bananas) do + Fabricate(:post, raw: "This post mentions bananas", topic: bug_topic, user: user) + end + + fab!(:post_with_both) do + Fabricate( + :post, + raw: "This post has apples and bananas", + topic: feature_bug_topic, + 
user: user, + ) + end + + fab!(:post_with_none) do + Fabricate(:post, raw: "No fruits here", topic: no_tag_topic, user: user) + end + + it "correctly filters posts by full text keywords" do + filter = described_class.new("keywords:apples") + expect(filter.search.pluck(:id)).to contain_exactly(post_with_apples.id, post_with_both.id) + + filter = described_class.new("keywords:bananas") + expect(filter.search.pluck(:id)).to contain_exactly(post_with_bananas.id, post_with_both.id) + + filter = described_class.new("keywords:apples,bananas") + expect(filter.search.pluck(:id)).to contain_exactly( + post_with_apples.id, + post_with_bananas.id, + post_with_both.id, + ) + + filter = described_class.new("keywords:oranges") + expect(filter.search.count).to eq(0) + end + end + end +end diff --git a/spec/lib/utils/research/llm_formatter_spec.rb b/spec/lib/utils/research/llm_formatter_spec.rb new file mode 100644 index 00000000..edc10363 --- /dev/null +++ b/spec/lib/utils/research/llm_formatter_spec.rb @@ -0,0 +1,74 @@ +# frozen_string_literal: true +# +describe DiscourseAi::Utils::Research::LlmFormatter do + fab!(:user) { Fabricate(:user, username: "test_user") } + fab!(:topic) { Fabricate(:topic, title: "This is a Test Topic", user: user) } + fab!(:post) { Fabricate(:post, topic: topic, user: user) } + let(:tokenizer) { DiscourseAi::Tokenizer::OpenAiTokenizer } + let(:filter) { DiscourseAi::Utils::Research::Filter.new("@#{user.username}") } + + describe "#truncate_if_needed" do + it "returns original content when under token limit" do + formatter = + described_class.new( + filter, + max_tokens_per_batch: 1000, + tokenizer: tokenizer, + max_tokens_per_post: 100, + ) + + short_text = "This is a short post" + expect(formatter.send(:truncate_if_needed, short_text)).to eq(short_text) + end + + it "truncates content when over token limit" do + # Create a post with content that will exceed our token limit + long_text = ("word " * 200).strip + + formatter = + described_class.new( + 
filter, + max_tokens_per_batch: 1000, + tokenizer: tokenizer, + max_tokens_per_post: 50, + ) + + truncated = formatter.send(:truncate_if_needed, long_text) + + expect(truncated).to include("... elided 150 tokens ...") + expect(truncated).to_not eq(long_text) + + # Should have roughly 25 words before and 25 after (half of max_tokens_per_post) + first_chunk = truncated.split("\n\n")[0] + expect(first_chunk.split(" ").length).to be_within(5).of(25) + + last_chunk = truncated.split("\n\n")[2] + expect(last_chunk.split(" ").length).to be_within(5).of(25) + end + end + + describe "#format_post" do + it "formats posts with truncation for long content" do + # Set up a post with long content + long_content = ("word " * 200).strip + long_post = Fabricate(:post, raw: long_content, topic: topic, user: user) + + formatter = + described_class.new( + filter, + max_tokens_per_batch: 1000, + tokenizer: tokenizer, + max_tokens_per_post: 50, + ) + + formatted = formatter.send(:format_post, long_post) + + # Should have standard formatting elements + expect(formatted).to include("## Post by #{user.username}") + expect(formatted).to include("Post url: /t/-/#{long_post.topic_id}/#{long_post.post_number}") + + # Should include truncation marker + expect(formatted).to include("... elided 150 tokens ...") + end + end +end