FIX: AI triage support and refactor search functionality (#1175)

* FIX: do not add bot user to PM when using responders

* Allow AI tool to call search directly

* remove stray p
This commit is contained in:
Sam 2025-03-11 14:26:07 +11:00 committed by GitHub
parent 3da4f5eac3
commit 168d9d8eb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 457 additions and 94 deletions

View File

@ -162,7 +162,7 @@ module DiscourseAi
end end
end end
def self.reply_to_post(post:, user: nil, persona_id: nil, whisper: nil) def self.reply_to_post(post:, user: nil, persona_id: nil, whisper: nil, add_user_to_pm: false)
ai_persona = AiPersona.find_by(id: persona_id) ai_persona = AiPersona.find_by(id: persona_id)
raise Discourse::InvalidParameters.new(:persona_id) if !ai_persona raise Discourse::InvalidParameters.new(:persona_id) if !ai_persona
persona_class = ai_persona.class_instance persona_class = ai_persona.class_instance
@ -173,7 +173,12 @@ module DiscourseAi
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona) bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona)
playground = DiscourseAi::AiBot::Playground.new(bot) playground = DiscourseAi::AiBot::Playground.new(bot)
playground.reply_to(post, whisper: whisper, context_style: :topic) playground.reply_to(
post,
whisper: whisper,
context_style: :topic,
add_user_to_pm: add_user_to_pm,
)
end end
def initialize(bot) def initialize(bot)
@ -433,7 +438,14 @@ module DiscourseAi
result result
end end
def reply_to(post, custom_instructions: nil, whisper: nil, context_style: nil, &blk) def reply_to(
post,
custom_instructions: nil,
whisper: nil,
context_style: nil,
add_user_to_pm: true,
&blk
)
# this is a multithreading issue # this is a multithreading issue
# post custom prompt is needed and it may not # post custom prompt is needed and it may not
# be properly loaded, ensure it is loaded # be properly loaded, ensure it is loaded
@ -470,7 +482,7 @@ module DiscourseAi
stream_reply = post.topic.private_message? stream_reply = post.topic.private_message?
# we need to ensure persona user is allowed to reply to the pm # we need to ensure persona user is allowed to reply to the pm
if post.topic.private_message? if post.topic.private_message? && add_user_to_pm
if !post.topic.topic_allowed_users.exists?(user_id: reply_user.id) if !post.topic.topic_allowed_users.exists?(user_id: reply_user.id)
post.topic.topic_allowed_users.create!(user_id: reply_user.id) post.topic.topic_allowed_users.create!(user_id: reply_user.id)
end end
@ -485,6 +497,7 @@ module DiscourseAi
skip_validations: true, skip_validations: true,
skip_jobs: true, skip_jobs: true,
post_type: post_type, post_type: post_type,
skip_guardian: true,
) )
publish_update(reply_post, { raw: reply_post.cooked }) publish_update(reply_post, { raw: reply_post.cooked })
@ -560,6 +573,7 @@ module DiscourseAi
raw: reply, raw: reply,
skip_validations: true, skip_validations: true,
post_type: post_type, post_type: post_type,
skip_guardian: true,
) )
end end

View File

@ -73,6 +73,9 @@ module DiscourseAi
}; };
const discourse = { const discourse = {
search: function(params) {
return _discourse_search(params);
},
getPost: _discourse_get_post, getPost: _discourse_get_post,
getUser: _discourse_get_user, getUser: _discourse_get_user,
getPersona: function(name) { getPersona: function(name) {
@ -341,6 +344,21 @@ module DiscourseAi
end end
end, end,
) )
mini_racer_context.attach(
"_discourse_search",
->(params) do
in_attached_function do
search_params = params.symbolize_keys
if search_params.delete(:with_private)
search_params[:current_user] = Discourse.system_user
end
search_params[:result_style] = :detailed
results = DiscourseAi::Utils::Search.perform_search(**search_params)
recursive_as_json(results)
end
end,
)
end end
def attach_upload(mini_racer_context) def attach_upload(mini_racer_context)

View File

@ -34,7 +34,7 @@ module DiscourseAi
enum: %w[latest latest_topic oldest views likes], enum: %w[latest latest_topic oldest views likes],
}, },
{ {
name: "limit", name: "max_results",
description: description:
"limit number of results returned (generally prefer to just keep to default)", "limit number of results returned (generally prefer to just keep to default)",
type: "integer", type: "integer",
@ -103,102 +103,38 @@ module DiscourseAi
def invoke def invoke
search_terms = [] search_terms = []
search_terms << options[:base_query] if options[:base_query].present? search_terms << options[:base_query] if options[:base_query].present?
search_terms << search_query.strip if search_query.present? search_terms << search_query if search_query.present?
search_args.each { |key, value| search_terms << "#{key}:#{value}" if value.present? } search_args.each { |key, value| search_terms << "#{key}:#{value}" if value.present? }
guardian = nil @last_query = search_terms.join(" ").to_s
if options[:search_private] && context[:user]
guardian = Guardian.new(context[:user])
else
guardian = Guardian.new
search_terms << "status:public"
end
search_string = search_terms.join(" ").to_s yield(I18n.t("discourse_ai.ai_bot.searching", query: @last_query))
@last_query = search_string
yield(I18n.t("discourse_ai.ai_bot.searching", query: search_string))
results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian)
max_results = calculate_max_results(llm) max_results = calculate_max_results(llm)
results_limit = parameters[:limit] || max_results if parameters[:max_results].to_i > 0
results_limit = max_results if parameters[:limit].to_i > max_results max_results = [parameters[:max_results].to_i, max_results].min
should_try_semantic_search =
SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present?
max_semantic_results = max_results / 4
results_limit = results_limit - max_semantic_results if should_try_semantic_search
posts = results&.posts || []
posts = posts[0..results_limit.to_i - 1]
if should_try_semantic_search
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
topic_ids = Set.new(posts.map(&:topic_id))
search = ::Search.new(search_string, guardian: guardian)
results = nil
begin
results = semantic_search.search_for_topics(search.term)
rescue => e
Discourse.warn_exception(e, message: "Semantic search failed")
end
if results
results = search.apply_filters(results)
results.each do |post|
next if topic_ids.include?(post.topic_id)
topic_ids << post.topic_id
posts << post
break if posts.length >= max_results
end
end
end end
@last_num_results = posts.length search_query_with_base = [options[:base_query], search_query].compact.join(" ").strip
# this is the general pattern from core
# if there are millions of hidden tags it may fail
hidden_tags = nil
if posts.blank? results =
{ args: parameters, rows: [], instruction: "nothing was found, expand your search" } DiscourseAi::Utils::Search.perform_search(
else search_query: search_query_with_base,
format_results(posts, args: parameters) do |post| category: parameters[:category],
category_names = [ user: parameters[:user],
post.topic.category&.parent_category&.name, order: parameters[:order],
post.topic.category&.name, max_posts: parameters[:max_posts],
].compact.join(" > ") tags: parameters[:tags],
row = { before: parameters[:before],
title: post.topic.title, after: parameters[:after],
url: Discourse.base_path + post.url, status: parameters[:status],
username: post.user&.username, max_results: max_results,
excerpt: post.excerpt, current_user: options[:search_private] ? context[:user] : nil,
created: post.created_at, )
category: category_names,
likes: post.like_count,
topic_views: post.topic.views,
topic_likes: post.topic.like_count,
topic_replies: post.topic.posts_count - 1,
}
if SiteSetting.tagging_enabled @last_num_results = results[:rows]&.length || 0
hidden_tags ||= DiscourseTagging.hidden_tag_names results
# using map over pluck to avoid n+1 (assuming caller preloading)
tags = post.topic.tags.map(&:name) - hidden_tags
row[:tags] = tags.join(", ") if tags.present?
end
row
end
end
end end
protected protected

151
lib/utils/search.rb Normal file
View File

@ -0,0 +1,151 @@
# frozen_string_literal: true
module DiscourseAi
module Utils
class Search
def self.perform_search(
search_query: nil,
category: nil,
user: nil,
order: nil,
max_posts: nil,
tags: nil,
before: nil,
after: nil,
status: nil,
hyde: true,
max_results: 20,
current_user: nil,
result_style: :compact
)
search_terms = []
search_terms << search_query.strip if search_query.present?
search_terms << "category:#{category}" if category.present?
search_terms << "user:#{user}" if user.present?
search_terms << "order:#{order}" if order.present?
search_terms << "max_posts:#{max_posts}" if max_posts.present?
search_terms << "tags:#{tags}" if tags.present?
search_terms << "before:#{before}" if before.present?
search_terms << "after:#{after}" if after.present?
search_terms << "status:#{status}" if status.present?
guardian = Guardian.new(current_user)
search_string = search_terms.join(" ").to_s
results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian)
results_limit = max_results
should_try_semantic_search =
SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present?
max_semantic_results = max_results / 4
results_limit = results_limit - max_semantic_results if should_try_semantic_search
posts = results&.posts || []
posts = posts[0..results_limit.to_i - 1]
if should_try_semantic_search
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
topic_ids = Set.new(posts.map(&:topic_id))
search = ::Search.new(search_string, guardian: guardian)
semantic_results = nil
begin
semantic_results = semantic_search.search_for_topics(search.term, hyde: hyde)
rescue => e
Discourse.warn_exception(e, message: "Semantic search failed")
end
if semantic_results
semantic_results = search.apply_filters(semantic_results)
semantic_results.each do |post|
next if topic_ids.include?(post.topic_id)
topic_ids << post.topic_id
posts << post
break if posts.length >= max_results
end
end
end
hidden_tags = nil
# Construct search_args hash for consistent return format
search_args = {
search_query: search_query,
category: category,
user: user,
order: order,
max_posts: max_posts,
tags: tags,
before: before,
after: after,
status: status,
max_results: max_results,
}.compact
if posts.blank?
{ args: search_args, rows: [], instruction: "nothing was found, expand your search" }
else
format_results(posts, args: search_args, result_style: result_style) do |post|
category_names = [
post.topic.category&.parent_category&.name,
post.topic.category&.name,
].compact.join(" > ")
row = {
title: post.topic.title,
url: Discourse.base_path + post.url,
username: post.user&.username,
excerpt: post.excerpt,
created: post.created_at,
category: category_names,
likes: post.like_count,
topic_views: post.topic.views,
topic_likes: post.topic.like_count,
topic_replies: post.topic.posts_count - 1,
}
if SiteSetting.tagging_enabled
hidden_tags ||= DiscourseTagging.hidden_tag_names
tags = post.topic.tags.map(&:name) - hidden_tags
row[:tags] = tags.join(", ") if tags.present?
end
row
end
end
end
private
def self.format_results(rows, args: nil, result_style:)
rows = rows&.map { |row| yield row } if block_given?
if result_style == :compact
index = -1
column_indexes = {}
rows =
rows&.map do |data|
new_row = []
data.each do |key, value|
found_index = column_indexes[key.to_s] ||= (index += 1)
new_row[found_index] = value
end
new_row
end
column_names = column_indexes.keys
end
result = { column_names: column_names, rows: rows }
result[:args] = args if args
result
end
end
end
end

View File

@ -165,7 +165,7 @@ describe DiscourseAi::Automation::LlmPersonaTriage do
expect(context).to include("support") expect(context).to include("support")
end end
it "passes private message metadata in context when responding to PM" do it "interacts correctly with PMs" do
# Create a private message topic # Create a private message topic
pm_topic = Fabricate(:private_message_topic, user: user, title: "Important PM") pm_topic = Fabricate(:private_message_topic, user: user, title: "Important PM")
@ -190,6 +190,8 @@ describe DiscourseAi::Automation::LlmPersonaTriage do
# Capture the prompt sent to the LLM # Capture the prompt sent to the LLM
prompt = nil prompt = nil
original_user_ids = pm_topic.topic_allowed_users.pluck(:user_id)
DiscourseAi::Completions::Llm.with_prepared_responses( DiscourseAi::Completions::Llm.with_prepared_responses(
["I've received your private message"], ["I've received your private message"],
) do |_, _, _prompts| ) do |_, _, _prompts|
@ -204,5 +206,13 @@ describe DiscourseAi::Automation::LlmPersonaTriage do
expect(context).to include("Important PM") expect(context).to include("Important PM")
expect(context).to include(pm_post.raw) expect(context).to include(pm_post.raw)
expect(context).to include(pm_post2.raw) expect(context).to include(pm_post2.raw)
reply = pm_topic.posts.order(:post_number).last
expect(reply.raw).to eq("I've received your private message")
topic = reply.topic
# should not inject persona into allowed users
expect(topic.topic_allowed_users.pluck(:user_id).sort).to eq(original_user_ids.sort)
end end
end end

View File

@ -98,7 +98,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
results = search.invoke(&progress_blk) results = search.invoke(&progress_blk)
expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" }) expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake", max_results: 60 })
expect(results[:rows]).to eq([]) expect(results[:rows]).to eq([])
end end
@ -131,7 +131,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
search.invoke(&progress_blk) search.invoke(&progress_blk)
end end
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" }) expect(results[:args]).to eq(
{ max_results: 60, search_query: "hello world, sam", status: "public" },
)
expect(results[:rows].length).to eq(1) expect(results[:rows].length).to eq(1)
# it also works with no query # it also works with no query
@ -174,6 +176,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
[param[:name], "test"] [param[:name], "test"]
end end
end end
.compact
.to_h .to_h
.symbolize_keys .symbolize_keys

View File

@ -0,0 +1,198 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Utils::Search do
before { SearchIndexer.enable }
after { SearchIndexer.disable }
fab!(:admin)
fab!(:user)
fab!(:group)
fab!(:parent_category) { Fabricate(:category, name: "animals") }
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
fab!(:tag_funny) { Fabricate(:tag, name: "funny") }
fab!(:tag_sad) { Fabricate(:tag, name: "sad") }
fab!(:tag_hidden) { Fabricate(:tag, name: "hidden") }
fab!(:staff_tag_group) do
tag_group = Fabricate.build(:tag_group, name: "Staff only", tag_names: ["hidden"])
tag_group.permissions = [
[Group::AUTO_GROUPS[:staff], TagGroupPermission.permission_types[:full]],
]
tag_group.save!
tag_group
end
fab!(:topic_with_tags) do
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
end
fab!(:private_category) do
c = Fabricate(:category_with_definition)
c.set_permissions(group => :readonly)
c.save
c
end
describe ".perform_search" do
it "returns search results with correct format" do
post = Fabricate(:post, topic: topic_with_tags)
results =
described_class.perform_search(
search_query: post.raw,
user: post.user.username,
current_user: admin,
)
expect(results).to have_key(:args)
expect(results).to have_key(:rows)
expect(results).to have_key(:column_names)
expect(results[:rows].length).to eq(1)
end
it "handles no results" do
results =
described_class.perform_search(
search_query: "NONEXISTENTTERMNOONEWOULDSEARCH",
current_user: admin,
)
expect(results[:rows]).to eq([])
expect(results[:instruction]).to eq("nothing was found, expand your search")
end
it "returns private results when user has access" do
private_post = Fabricate(:post, topic: Fabricate(:topic, category: private_category))
# Regular user without access
results = described_class.perform_search(search_query: private_post.raw, current_user: user)
expect(results[:rows].length).to eq(0)
# Add user to group with access
GroupUser.create!(group: group, user: user)
# Now should find the private post
results = described_class.perform_search(search_query: private_post.raw, current_user: user)
expect(results[:rows].length).to eq(1)
end
it "properly handles subfolder URLs" do
Discourse.stubs(:base_path).returns("/subfolder")
post = Fabricate(:post, topic: topic_with_tags)
results = described_class.perform_search(search_query: post.raw, current_user: admin)
url_index = results[:column_names].index("url")
expect(results[:rows][0][url_index]).to include("/subfolder")
end
it "returns rich topic information" do
post = Fabricate(:post, like_count: 1, topic: topic_with_tags)
post.topic.update!(views: 100, posts_count: 2, like_count: 10)
results = described_class.perform_search(search_query: post.raw, current_user: admin)
row = results[:rows].first
category_index = results[:column_names].index("category")
expect(row[category_index]).to eq("animals > amazing-cat")
tags_index = results[:column_names].index("tags")
expect(row[tags_index]).to eq("funny, sad")
likes_index = results[:column_names].index("likes")
expect(row[likes_index]).to eq(1)
topic_likes_index = results[:column_names].index("topic_likes")
expect(row[topic_likes_index]).to eq(10)
topic_views_index = results[:column_names].index("topic_views")
expect(row[topic_views_index]).to eq(100)
topic_replies_index = results[:column_names].index("topic_replies")
expect(row[topic_replies_index]).to eq(1)
end
context "when using semantic search" do
let(:query) { "this is an expanded search" }
after do
if defined?(DiscourseAi::Embeddings::SemanticSearch)
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query)
end
end
it "includes semantic search results when enabled" do
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
vector_def = Fabricate(:embedding_definition)
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_semantic_search_enabled = true
hyde_embedding = [0.049382] * vector_def.dimensions
EmbeddingsGenerationStubs.hugging_face_service(query, hyde_embedding)
post = Fabricate(:post, topic: topic_with_tags)
DiscourseAi::Embeddings::Schema.for(Topic).store(post.topic, hyde_embedding, "digest")
# Using a completely different search query, should still find via semantic search
results =
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
described_class.perform_search(
search_query: "totally different query",
current_user: admin,
)
end
expect(results[:rows].length).to eq(1)
end
it "can disable semantic search with hyde parameter" do
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
vector_def = Fabricate(:embedding_definition)
SiteSetting.ai_embeddings_selected_model = vector_def.id
SiteSetting.ai_embeddings_semantic_search_enabled = true
embedding = [0.049382] * vector_def.dimensions
EmbeddingsGenerationStubs.hugging_face_service(query, embedding)
post = Fabricate(:post, topic: topic_with_tags)
DiscourseAi::Embeddings::Schema.for(Topic).store(post.topic, embedding, "digest")
WebMock
.stub_request(:post, "https://test.com/embeddings")
.with(body: "{\"inputs\":\"totally different query\",\"truncate\":true}")
.to_return(status: 200, body: embedding.to_json)
results =
described_class.perform_search(
search_query: "totally different query",
hyde: false,
current_user: admin,
)
expect(results[:rows].length).to eq(0)
end
end
it "passes all search parameters to the results args" do
post = Fabricate(:post, topic: topic_with_tags)
search_params = {
search_query: post.raw,
category: category.name,
user: post.user.username,
order: "latest",
max_posts: 10,
tags: tag_funny.name,
before: "2030-01-01",
after: "2000-01-01",
status: "public",
max_results: 15,
}
results = described_class.perform_search(**search_params, current_user: admin)
expect(results[:args]).to include(search_params)
end
end
end

View File

@ -328,4 +328,37 @@ RSpec.describe AiTool do
expect(result).to eq(expected) expect(result).to eq(expected)
end end
end end
context "when using the search API" do
before { SearchIndexer.enable }
after { SearchIndexer.disable }
it "can perform a discourse search" do
# Create a new topic
topic = Fabricate(:topic, title: "Test Search Topic", category: Fabricate(:category))
post = Fabricate(:post, topic: topic, raw: "This is a test post content, banana")
# Ensure the topic is indexed
SearchIndexer.index(topic, force: true)
SearchIndexer.index(post, force: true)
# Define the tool script
script = <<~JS
function invoke(params) {
return discourse.search({ search_query: params.query });
}
JS
# Create the tool and runner
tool = create_tool(script: script)
runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {})
# Invoke the tool and get the results
result = runner.invoke
# Verify the topic is found
expect(result["rows"].length).to be > 0
expect(result["rows"].first["title"]).to eq("Test Search Topic")
end
end
end end