diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 8a3b327f..1ce5d3ec 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -162,7 +162,7 @@ module DiscourseAi end end - def self.reply_to_post(post:, user: nil, persona_id: nil, whisper: nil) + def self.reply_to_post(post:, user: nil, persona_id: nil, whisper: nil, add_user_to_pm: false) ai_persona = AiPersona.find_by(id: persona_id) raise Discourse::InvalidParameters.new(:persona_id) if !ai_persona persona_class = ai_persona.class_instance @@ -173,7 +173,12 @@ module DiscourseAi bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona) playground = DiscourseAi::AiBot::Playground.new(bot) - playground.reply_to(post, whisper: whisper, context_style: :topic) + playground.reply_to( + post, + whisper: whisper, + context_style: :topic, + add_user_to_pm: add_user_to_pm, + ) end def initialize(bot) @@ -433,7 +438,14 @@ module DiscourseAi result end - def reply_to(post, custom_instructions: nil, whisper: nil, context_style: nil, &blk) + def reply_to( + post, + custom_instructions: nil, + whisper: nil, + context_style: nil, + add_user_to_pm: true, + &blk + ) # this is a multithreading issue # post custom prompt is needed and it may not # be properly loaded, ensure it is loaded @@ -470,7 +482,7 @@ module DiscourseAi stream_reply = post.topic.private_message? # we need to ensure persona user is allowed to reply to the pm - if post.topic.private_message? + if post.topic.private_message? && add_user_to_pm if !post.topic.topic_allowed_users.exists?(user_id: reply_user.id) post.topic.topic_allowed_users.create!(user_id: reply_user.id) end @@ -485,6 +497,7 @@ module DiscourseAi skip_validations: true, skip_jobs: true, post_type: post_type, + skip_guardian: true, ) publish_update(reply_post, { raw: reply_post.cooked }) @@ -560,6 +573,7 @@ module DiscourseAi raw: reply, skip_validations: true, post_type: post_type, + skip_guardian: true, ) end diff --git a/lib/ai_bot/tool_runner.rb b/lib/ai_bot/tool_runner.rb index 2a3899ef..14af53e2 100644 --- a/lib/ai_bot/tool_runner.rb +++ b/lib/ai_bot/tool_runner.rb @@ -73,6 +73,9 @@ module DiscourseAi }; const discourse = { + search: function(params) { + return _discourse_search(params); + }, getPost: _discourse_get_post, getUser: _discourse_get_user, getPersona: function(name) { @@ -341,6 +344,21 @@ module DiscourseAi end end, ) + + mini_racer_context.attach( + "_discourse_search", + ->(params) do + in_attached_function do + search_params = params.symbolize_keys + if search_params.delete(:with_private) + search_params[:current_user] = Discourse.system_user + end + search_params[:result_style] = :detailed + results = DiscourseAi::Utils::Search.perform_search(**search_params) + recursive_as_json(results) + end + end, + ) end def attach_upload(mini_racer_context) diff --git a/lib/ai_bot/tools/search.rb b/lib/ai_bot/tools/search.rb index 5ddf78c3..a14e77d6 100644 --- a/lib/ai_bot/tools/search.rb +++ b/lib/ai_bot/tools/search.rb @@ -34,7 +34,7 @@ module DiscourseAi enum: %w[latest latest_topic oldest views likes], }, { - name: "limit", + name: "max_results", description: "limit number of results returned (generally prefer to just keep to default)", type: "integer", @@ -103,102 +103,38 @@ module DiscourseAi def invoke search_terms = [] - search_terms << options[:base_query] if options[:base_query].present? - search_terms << search_query.strip if search_query.present? + search_terms << search_query if search_query.present? search_args.each { |key, value| search_terms << "#{key}:#{value}" if value.present? } - guardian = nil - if options[:search_private] && context[:user] - guardian = Guardian.new(context[:user]) - else - guardian = Guardian.new - search_terms << "status:public" - end + @last_query = search_terms.join(" ").to_s - search_string = search_terms.join(" ").to_s - @last_query = search_string - - yield(I18n.t("discourse_ai.ai_bot.searching", query: search_string)) - - results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian) + yield(I18n.t("discourse_ai.ai_bot.searching", query: @last_query)) max_results = calculate_max_results(llm) - results_limit = parameters[:limit] || max_results - results_limit = max_results if parameters[:limit].to_i > max_results - - should_try_semantic_search = - SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present? - - max_semantic_results = max_results / 4 - results_limit = results_limit - max_semantic_results if should_try_semantic_search - - posts = results&.posts || [] - posts = posts[0..results_limit.to_i - 1] - - if should_try_semantic_search - semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian) - topic_ids = Set.new(posts.map(&:topic_id)) - - search = ::Search.new(search_string, guardian: guardian) - - results = nil - begin - results = semantic_search.search_for_topics(search.term) - rescue => e - Discourse.warn_exception(e, message: "Semantic search failed") - end - - if results - results = search.apply_filters(results) - - results.each do |post| - next if topic_ids.include?(post.topic_id) - - topic_ids << post.topic_id - posts << post - - break if posts.length >= max_results - end - end + if parameters[:max_results].to_i > 0 + max_results = [parameters[:max_results].to_i, max_results].min end - @last_num_results = posts.length - # this is the general pattern from core - # if there are millions of hidden tags it may fail - hidden_tags = nil + search_query_with_base = [options[:base_query], search_query].compact.join(" ").strip - if posts.blank? - { args: parameters, rows: [], instruction: "nothing was found, expand your search" } - else - format_results(posts, args: parameters) do |post| - category_names = [ - post.topic.category&.parent_category&.name, - post.topic.category&.name, - ].compact.join(" > ") - row = { - title: post.topic.title, - url: Discourse.base_path + post.url, - username: post.user&.username, - excerpt: post.excerpt, - created: post.created_at, - category: category_names, - likes: post.like_count, - topic_views: post.topic.views, - topic_likes: post.topic.like_count, - topic_replies: post.topic.posts_count - 1, - } + results = + DiscourseAi::Utils::Search.perform_search( + search_query: search_query_with_base, + category: parameters[:category], + user: parameters[:user], + order: parameters[:order], + max_posts: parameters[:max_posts], + tags: parameters[:tags], + before: parameters[:before], + after: parameters[:after], + status: parameters[:status], + max_results: max_results, + current_user: options[:search_private] ? context[:user] : nil, + ) - if SiteSetting.tagging_enabled - hidden_tags ||= DiscourseTagging.hidden_tag_names - # using map over pluck to avoid n+1 (assuming caller preloading) - tags = post.topic.tags.map(&:name) - hidden_tags - row[:tags] = tags.join(", ") if tags.present? - end - - row - end - end + @last_num_results = results[:rows]&.length || 0 + results end protected diff --git a/lib/utils/search.rb b/lib/utils/search.rb new file mode 100644 index 00000000..6e124b5a --- /dev/null +++ b/lib/utils/search.rb @@ -0,0 +1,151 @@ +# frozen_string_literal: true + +module DiscourseAi + module Utils + class Search + def self.perform_search( + search_query: nil, + category: nil, + user: nil, + order: nil, + max_posts: nil, + tags: nil, + before: nil, + after: nil, + status: nil, + hyde: true, + max_results: 20, + current_user: nil, + result_style: :compact + ) + search_terms = [] + + search_terms << search_query.strip if search_query.present? + search_terms << "category:#{category}" if category.present? + search_terms << "user:#{user}" if user.present? + search_terms << "order:#{order}" if order.present? + search_terms << "max_posts:#{max_posts}" if max_posts.present? + search_terms << "tags:#{tags}" if tags.present? + search_terms << "before:#{before}" if before.present? + search_terms << "after:#{after}" if after.present? + search_terms << "status:#{status}" if status.present? + + guardian = Guardian.new(current_user) + + search_string = search_terms.join(" ").to_s + + results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian) + results_limit = max_results + + should_try_semantic_search = + SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present? + + max_semantic_results = max_results / 4 + results_limit = results_limit - max_semantic_results if should_try_semantic_search + + posts = results&.posts || [] + posts = posts[0..results_limit.to_i - 1] + + if should_try_semantic_search + semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian) + topic_ids = Set.new(posts.map(&:topic_id)) + + search = ::Search.new(search_string, guardian: guardian) + + semantic_results = nil + begin + semantic_results = semantic_search.search_for_topics(search.term, hyde: hyde) + rescue => e + Discourse.warn_exception(e, message: "Semantic search failed") + end + + if semantic_results + semantic_results = search.apply_filters(semantic_results) + + semantic_results.each do |post| + next if topic_ids.include?(post.topic_id) + + topic_ids << post.topic_id + posts << post + + break if posts.length >= max_results + end + end + end + + hidden_tags = nil + + # Construct search_args hash for consistent return format + search_args = { + search_query: search_query, + category: category, + user: user, + order: order, + max_posts: max_posts, + tags: tags, + before: before, + after: after, + status: status, + max_results: max_results, + }.compact + + if posts.blank? + { args: search_args, rows: [], instruction: "nothing was found, expand your search" } + else + format_results(posts, args: search_args, result_style: result_style) do |post| + category_names = [ + post.topic.category&.parent_category&.name, + post.topic.category&.name, + ].compact.join(" > ") + row = { + title: post.topic.title, + url: Discourse.base_path + post.url, + username: post.user&.username, + excerpt: post.excerpt, + created: post.created_at, + category: category_names, + likes: post.like_count, + topic_views: post.topic.views, + topic_likes: post.topic.like_count, + topic_replies: post.topic.posts_count - 1, + } + + if SiteSetting.tagging_enabled + hidden_tags ||= DiscourseTagging.hidden_tag_names + tags = post.topic.tags.map(&:name) - hidden_tags + row[:tags] = tags.join(", ") if tags.present? + end + + row + end + end + end + + private + + def self.format_results(rows, args: nil, result_style:) + rows = rows&.map { |row| yield row } if block_given? + + if result_style == :compact + index = -1 + column_indexes = {} + + rows = + rows&.map do |data| + new_row = [] + data.each do |key, value| + found_index = column_indexes[key.to_s] ||= (index += 1) + new_row[found_index] = value + end + new_row + end + column_names = column_indexes.keys + end + + result = { column_names: column_names, rows: rows } + result[:args] = args if args + result + end + end + end +end diff --git a/spec/lib/discourse_automation/llm_persona_triage_spec.rb b/spec/lib/discourse_automation/llm_persona_triage_spec.rb index 37e40193..9e1f2a9b 100644 --- a/spec/lib/discourse_automation/llm_persona_triage_spec.rb +++ b/spec/lib/discourse_automation/llm_persona_triage_spec.rb @@ -165,7 +165,7 @@ describe DiscourseAi::Automation::LlmPersonaTriage do expect(context).to include("support") end - it "passes private message metadata in context when responding to PM" do + it "interacts correctly with PMs" do # Create a private message topic pm_topic = Fabricate(:private_message_topic, user: user, title: "Important PM") @@ -190,6 +190,8 @@ describe DiscourseAi::Automation::LlmPersonaTriage do # Capture the prompt sent to the LLM prompt = nil + original_user_ids = pm_topic.topic_allowed_users.pluck(:user_id) + DiscourseAi::Completions::Llm.with_prepared_responses( ["I've received your private message"], ) do |_, _, _prompts| @@ -204,5 +206,13 @@ describe DiscourseAi::Automation::LlmPersonaTriage do expect(context).to include("Important PM") expect(context).to include(pm_post.raw) expect(context).to include(pm_post2.raw) + + reply = pm_topic.posts.order(:post_number).last + expect(reply.raw).to eq("I've received your private message") + + topic = reply.topic + + # should not inject persona into allowed users + expect(topic.topic_allowed_users.pluck(:user_id).sort).to eq(original_user_ids.sort) end end diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb index 976d9c42..94dba3c7 100644 --- a/spec/lib/modules/ai_bot/tools/search_spec.rb +++ b/spec/lib/modules/ai_bot/tools/search_spec.rb @@ -98,7 +98,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do results = search.invoke(&progress_blk) - expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" }) + expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake", max_results: 60 }) expect(results[:rows]).to eq([]) end @@ -131,7 +131,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do search.invoke(&progress_blk) end - expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" }) + expect(results[:args]).to eq( + { max_results: 60, search_query: "hello world, sam", status: "public" }, + ) expect(results[:rows].length).to eq(1) # it also works with no query @@ -174,6 +176,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do [param[:name], "test"] end end + .compact .to_h .symbolize_keys diff --git a/spec/lib/utils/search_spec.rb b/spec/lib/utils/search_spec.rb new file mode 100644 index 00000000..d8dfdbf9 --- /dev/null +++ b/spec/lib/utils/search_spec.rb @@ -0,0 +1,198 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Utils::Search do + before { SearchIndexer.enable } + after { SearchIndexer.disable } + + fab!(:admin) + fab!(:user) + fab!(:group) + fab!(:parent_category) { Fabricate(:category, name: "animals") } + fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") } + fab!(:tag_funny) { Fabricate(:tag, name: "funny") } + fab!(:tag_sad) { Fabricate(:tag, name: "sad") } + fab!(:tag_hidden) { Fabricate(:tag, name: "hidden") } + fab!(:staff_tag_group) do + tag_group = Fabricate.build(:tag_group, name: "Staff only", tag_names: ["hidden"]) + + tag_group.permissions = [ + [Group::AUTO_GROUPS[:staff], TagGroupPermission.permission_types[:full]], + ] + tag_group.save! + tag_group + end + + fab!(:topic_with_tags) do + Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden]) + end + + fab!(:private_category) do + c = Fabricate(:category_with_definition) + c.set_permissions(group => :readonly) + c.save + c + end + + describe ".perform_search" do + it "returns search results with correct format" do + post = Fabricate(:post, topic: topic_with_tags) + + results = + described_class.perform_search( + search_query: post.raw, + user: post.user.username, + current_user: admin, + ) + + expect(results).to have_key(:args) + expect(results).to have_key(:rows) + expect(results).to have_key(:column_names) + expect(results[:rows].length).to eq(1) + end + + it "handles no results" do + results = + described_class.perform_search( + search_query: "NONEXISTENTTERMNOONEWOULDSEARCH", + current_user: admin, + ) + + expect(results[:rows]).to eq([]) + expect(results[:instruction]).to eq("nothing was found, expand your search") + end + + it "returns private results when user has access" do + private_post = Fabricate(:post, topic: Fabricate(:topic, category: private_category)) + + # Regular user without access + results = described_class.perform_search(search_query: private_post.raw, current_user: user) + expect(results[:rows].length).to eq(0) + + # Add user to group with access + GroupUser.create!(group: group, user: user) + + # Now should find the private post + results = described_class.perform_search(search_query: private_post.raw, current_user: user) + expect(results[:rows].length).to eq(1) + end + + it "properly handles subfolder URLs" do + Discourse.stubs(:base_path).returns("/subfolder") + + post = Fabricate(:post, topic: topic_with_tags) + + results = described_class.perform_search(search_query: post.raw, current_user: admin) + + url_index = results[:column_names].index("url") + expect(results[:rows][0][url_index]).to include("/subfolder") + end + + it "returns rich topic information" do + post = Fabricate(:post, like_count: 1, topic: topic_with_tags) + post.topic.update!(views: 100, posts_count: 2, like_count: 10) + + results = described_class.perform_search(search_query: post.raw, current_user: admin) + + row = results[:rows].first + + category_index = results[:column_names].index("category") + expect(row[category_index]).to eq("animals > amazing-cat") + + tags_index = results[:column_names].index("tags") + expect(row[tags_index]).to eq("funny, sad") + + likes_index = results[:column_names].index("likes") + expect(row[likes_index]).to eq(1) + + topic_likes_index = results[:column_names].index("topic_likes") + expect(row[topic_likes_index]).to eq(10) + + topic_views_index = results[:column_names].index("topic_views") + expect(row[topic_views_index]).to eq(100) + + topic_replies_index = results[:column_names].index("topic_replies") + expect(row[topic_replies_index]).to eq(1) + end + + context "when using semantic search" do + let(:query) { "this is an expanded search" } + after do + if defined?(DiscourseAi::Embeddings::SemanticSearch) + DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) + end + end + + it "includes semantic search results when enabled" do + assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) + vector_def = Fabricate(:embedding_definition) + SiteSetting.ai_embeddings_selected_model = vector_def.id + SiteSetting.ai_embeddings_semantic_search_enabled = true + + hyde_embedding = [0.049382] * vector_def.dimensions + EmbeddingsGenerationStubs.hugging_face_service(query, hyde_embedding) + + post = Fabricate(:post, topic: topic_with_tags) + DiscourseAi::Embeddings::Schema.for(Topic).store(post.topic, hyde_embedding, "digest") + + # Using a completely different search query, should still find via semantic search + results = + DiscourseAi::Completions::Llm.with_prepared_responses(["#{query}"]) do + described_class.perform_search( + search_query: "totally different query", + current_user: admin, + ) + end + + expect(results[:rows].length).to eq(1) + end + + it "can disable semantic search with hyde parameter" do + assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) + vector_def = Fabricate(:embedding_definition) + SiteSetting.ai_embeddings_selected_model = vector_def.id + SiteSetting.ai_embeddings_semantic_search_enabled = true + + embedding = [0.049382] * vector_def.dimensions + EmbeddingsGenerationStubs.hugging_face_service(query, embedding) + + post = Fabricate(:post, topic: topic_with_tags) + DiscourseAi::Embeddings::Schema.for(Topic).store(post.topic, embedding, "digest") + + WebMock + .stub_request(:post, "https://test.com/embeddings") + .with(body: "{\"inputs\":\"totally different query\",\"truncate\":true}") + .to_return(status: 200, body: embedding.to_json) + + results = + described_class.perform_search( + search_query: "totally different query", + hyde: false, + current_user: admin, + ) + + expect(results[:rows].length).to eq(0) + end + end + + it "passes all search parameters to the results args" do + post = Fabricate(:post, topic: topic_with_tags) + + search_params = { + search_query: post.raw, + category: category.name, + user: post.user.username, + order: "latest", + max_posts: 10, + tags: tag_funny.name, + before: "2030-01-01", + after: "2000-01-01", + status: "public", + max_results: 15, + } + + results = described_class.perform_search(**search_params, current_user: admin) + + expect(results[:args]).to include(search_params) + end + end +end diff --git a/spec/models/ai_tool_spec.rb b/spec/models/ai_tool_spec.rb index 66a826a9..282a57b5 100644 --- a/spec/models/ai_tool_spec.rb +++ b/spec/models/ai_tool_spec.rb @@ -328,4 +328,37 @@ RSpec.describe AiTool do expect(result).to eq(expected) end end + + context "when using the search API" do + before { SearchIndexer.enable } + after { SearchIndexer.disable } + + it "can perform a discourse search" do + # Create a new topic + topic = Fabricate(:topic, title: "Test Search Topic", category: Fabricate(:category)) + post = Fabricate(:post, topic: topic, raw: "This is a test post content, banana") + + # Ensure the topic is indexed + SearchIndexer.index(topic, force: true) + SearchIndexer.index(post, force: true) + + # Define the tool script + script = <<~JS + function invoke(params) { + return discourse.search({ search_query: params.query }); + } + JS + + # Create the tool and runner + tool = create_tool(script: script) + runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {}) + + # Invoke the tool and get the results + result = runner.invoke + + # Verify the topic is found + expect(result["rows"].length).to be > 0 + expect(result["rows"].first["title"]).to eq("Test Search Topic") + end + end end