mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-07-10 08:03:28 +00:00
FIX: AI triage support and refactor search functionality (#1175)
* FIX: do not add bot user to PM when using responders * Allow AI tool to call search directly * remove stray p
This commit is contained in:
parent
3da4f5eac3
commit
168d9d8eb9
@ -162,7 +162,7 @@ module DiscourseAi
|
||||
end
|
||||
end
|
||||
|
||||
def self.reply_to_post(post:, user: nil, persona_id: nil, whisper: nil)
|
||||
def self.reply_to_post(post:, user: nil, persona_id: nil, whisper: nil, add_user_to_pm: false)
|
||||
ai_persona = AiPersona.find_by(id: persona_id)
|
||||
raise Discourse::InvalidParameters.new(:persona_id) if !ai_persona
|
||||
persona_class = ai_persona.class_instance
|
||||
@ -173,7 +173,12 @@ module DiscourseAi
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona)
|
||||
playground = DiscourseAi::AiBot::Playground.new(bot)
|
||||
|
||||
playground.reply_to(post, whisper: whisper, context_style: :topic)
|
||||
playground.reply_to(
|
||||
post,
|
||||
whisper: whisper,
|
||||
context_style: :topic,
|
||||
add_user_to_pm: add_user_to_pm,
|
||||
)
|
||||
end
|
||||
|
||||
def initialize(bot)
|
||||
@ -433,7 +438,14 @@ module DiscourseAi
|
||||
result
|
||||
end
|
||||
|
||||
def reply_to(post, custom_instructions: nil, whisper: nil, context_style: nil, &blk)
|
||||
def reply_to(
|
||||
post,
|
||||
custom_instructions: nil,
|
||||
whisper: nil,
|
||||
context_style: nil,
|
||||
add_user_to_pm: true,
|
||||
&blk
|
||||
)
|
||||
# this is a multithreading issue
|
||||
# post custom prompt is needed and it may not
|
||||
# be properly loaded, ensure it is loaded
|
||||
@ -470,7 +482,7 @@ module DiscourseAi
|
||||
stream_reply = post.topic.private_message?
|
||||
|
||||
# we need to ensure persona user is allowed to reply to the pm
|
||||
if post.topic.private_message?
|
||||
if post.topic.private_message? && add_user_to_pm
|
||||
if !post.topic.topic_allowed_users.exists?(user_id: reply_user.id)
|
||||
post.topic.topic_allowed_users.create!(user_id: reply_user.id)
|
||||
end
|
||||
@ -485,6 +497,7 @@ module DiscourseAi
|
||||
skip_validations: true,
|
||||
skip_jobs: true,
|
||||
post_type: post_type,
|
||||
skip_guardian: true,
|
||||
)
|
||||
|
||||
publish_update(reply_post, { raw: reply_post.cooked })
|
||||
@ -560,6 +573,7 @@ module DiscourseAi
|
||||
raw: reply,
|
||||
skip_validations: true,
|
||||
post_type: post_type,
|
||||
skip_guardian: true,
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -73,6 +73,9 @@ module DiscourseAi
|
||||
};
|
||||
|
||||
const discourse = {
|
||||
search: function(params) {
|
||||
return _discourse_search(params);
|
||||
},
|
||||
getPost: _discourse_get_post,
|
||||
getUser: _discourse_get_user,
|
||||
getPersona: function(name) {
|
||||
@ -341,6 +344,21 @@ module DiscourseAi
|
||||
end
|
||||
end,
|
||||
)
|
||||
|
||||
mini_racer_context.attach(
|
||||
"_discourse_search",
|
||||
->(params) do
|
||||
in_attached_function do
|
||||
search_params = params.symbolize_keys
|
||||
if search_params.delete(:with_private)
|
||||
search_params[:current_user] = Discourse.system_user
|
||||
end
|
||||
search_params[:result_style] = :detailed
|
||||
results = DiscourseAi::Utils::Search.perform_search(**search_params)
|
||||
recursive_as_json(results)
|
||||
end
|
||||
end,
|
||||
)
|
||||
end
|
||||
|
||||
def attach_upload(mini_racer_context)
|
||||
|
@ -34,7 +34,7 @@ module DiscourseAi
|
||||
enum: %w[latest latest_topic oldest views likes],
|
||||
},
|
||||
{
|
||||
name: "limit",
|
||||
name: "max_results",
|
||||
description:
|
||||
"limit number of results returned (generally prefer to just keep to default)",
|
||||
type: "integer",
|
||||
@ -103,102 +103,38 @@ module DiscourseAi
|
||||
|
||||
def invoke
|
||||
search_terms = []
|
||||
|
||||
search_terms << options[:base_query] if options[:base_query].present?
|
||||
search_terms << search_query.strip if search_query.present?
|
||||
search_terms << search_query if search_query.present?
|
||||
search_args.each { |key, value| search_terms << "#{key}:#{value}" if value.present? }
|
||||
|
||||
guardian = nil
|
||||
if options[:search_private] && context[:user]
|
||||
guardian = Guardian.new(context[:user])
|
||||
else
|
||||
guardian = Guardian.new
|
||||
search_terms << "status:public"
|
||||
end
|
||||
@last_query = search_terms.join(" ").to_s
|
||||
|
||||
search_string = search_terms.join(" ").to_s
|
||||
@last_query = search_string
|
||||
|
||||
yield(I18n.t("discourse_ai.ai_bot.searching", query: search_string))
|
||||
|
||||
results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian)
|
||||
yield(I18n.t("discourse_ai.ai_bot.searching", query: @last_query))
|
||||
|
||||
max_results = calculate_max_results(llm)
|
||||
results_limit = parameters[:limit] || max_results
|
||||
results_limit = max_results if parameters[:limit].to_i > max_results
|
||||
|
||||
should_try_semantic_search =
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present?
|
||||
|
||||
max_semantic_results = max_results / 4
|
||||
results_limit = results_limit - max_semantic_results if should_try_semantic_search
|
||||
|
||||
posts = results&.posts || []
|
||||
posts = posts[0..results_limit.to_i - 1]
|
||||
|
||||
if should_try_semantic_search
|
||||
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
|
||||
topic_ids = Set.new(posts.map(&:topic_id))
|
||||
|
||||
search = ::Search.new(search_string, guardian: guardian)
|
||||
|
||||
results = nil
|
||||
begin
|
||||
results = semantic_search.search_for_topics(search.term)
|
||||
rescue => e
|
||||
Discourse.warn_exception(e, message: "Semantic search failed")
|
||||
if parameters[:max_results].to_i > 0
|
||||
max_results = [parameters[:max_results].to_i, max_results].min
|
||||
end
|
||||
|
||||
if results
|
||||
results = search.apply_filters(results)
|
||||
search_query_with_base = [options[:base_query], search_query].compact.join(" ").strip
|
||||
|
||||
results.each do |post|
|
||||
next if topic_ids.include?(post.topic_id)
|
||||
results =
|
||||
DiscourseAi::Utils::Search.perform_search(
|
||||
search_query: search_query_with_base,
|
||||
category: parameters[:category],
|
||||
user: parameters[:user],
|
||||
order: parameters[:order],
|
||||
max_posts: parameters[:max_posts],
|
||||
tags: parameters[:tags],
|
||||
before: parameters[:before],
|
||||
after: parameters[:after],
|
||||
status: parameters[:status],
|
||||
max_results: max_results,
|
||||
current_user: options[:search_private] ? context[:user] : nil,
|
||||
)
|
||||
|
||||
topic_ids << post.topic_id
|
||||
posts << post
|
||||
|
||||
break if posts.length >= max_results
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@last_num_results = posts.length
|
||||
# this is the general pattern from core
|
||||
# if there are millions of hidden tags it may fail
|
||||
hidden_tags = nil
|
||||
|
||||
if posts.blank?
|
||||
{ args: parameters, rows: [], instruction: "nothing was found, expand your search" }
|
||||
else
|
||||
format_results(posts, args: parameters) do |post|
|
||||
category_names = [
|
||||
post.topic.category&.parent_category&.name,
|
||||
post.topic.category&.name,
|
||||
].compact.join(" > ")
|
||||
row = {
|
||||
title: post.topic.title,
|
||||
url: Discourse.base_path + post.url,
|
||||
username: post.user&.username,
|
||||
excerpt: post.excerpt,
|
||||
created: post.created_at,
|
||||
category: category_names,
|
||||
likes: post.like_count,
|
||||
topic_views: post.topic.views,
|
||||
topic_likes: post.topic.like_count,
|
||||
topic_replies: post.topic.posts_count - 1,
|
||||
}
|
||||
|
||||
if SiteSetting.tagging_enabled
|
||||
hidden_tags ||= DiscourseTagging.hidden_tag_names
|
||||
# using map over pluck to avoid n+1 (assuming caller preloading)
|
||||
tags = post.topic.tags.map(&:name) - hidden_tags
|
||||
row[:tags] = tags.join(", ") if tags.present?
|
||||
end
|
||||
|
||||
row
|
||||
end
|
||||
end
|
||||
@last_num_results = results[:rows]&.length || 0
|
||||
results
|
||||
end
|
||||
|
||||
protected
|
||||
|
151
lib/utils/search.rb
Normal file
151
lib/utils/search.rb
Normal file
@ -0,0 +1,151 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Utils
|
||||
class Search
|
||||
def self.perform_search(
|
||||
search_query: nil,
|
||||
category: nil,
|
||||
user: nil,
|
||||
order: nil,
|
||||
max_posts: nil,
|
||||
tags: nil,
|
||||
before: nil,
|
||||
after: nil,
|
||||
status: nil,
|
||||
hyde: true,
|
||||
max_results: 20,
|
||||
current_user: nil,
|
||||
result_style: :compact
|
||||
)
|
||||
search_terms = []
|
||||
|
||||
search_terms << search_query.strip if search_query.present?
|
||||
search_terms << "category:#{category}" if category.present?
|
||||
search_terms << "user:#{user}" if user.present?
|
||||
search_terms << "order:#{order}" if order.present?
|
||||
search_terms << "max_posts:#{max_posts}" if max_posts.present?
|
||||
search_terms << "tags:#{tags}" if tags.present?
|
||||
search_terms << "before:#{before}" if before.present?
|
||||
search_terms << "after:#{after}" if after.present?
|
||||
search_terms << "status:#{status}" if status.present?
|
||||
|
||||
guardian = Guardian.new(current_user)
|
||||
|
||||
search_string = search_terms.join(" ").to_s
|
||||
|
||||
results = ::Search.execute(search_string, search_type: :full_page, guardian: guardian)
|
||||
results_limit = max_results
|
||||
|
||||
should_try_semantic_search =
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled && search_query.present?
|
||||
|
||||
max_semantic_results = max_results / 4
|
||||
results_limit = results_limit - max_semantic_results if should_try_semantic_search
|
||||
|
||||
posts = results&.posts || []
|
||||
posts = posts[0..results_limit.to_i - 1]
|
||||
|
||||
if should_try_semantic_search
|
||||
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
|
||||
topic_ids = Set.new(posts.map(&:topic_id))
|
||||
|
||||
search = ::Search.new(search_string, guardian: guardian)
|
||||
|
||||
semantic_results = nil
|
||||
begin
|
||||
semantic_results = semantic_search.search_for_topics(search.term, hyde: hyde)
|
||||
rescue => e
|
||||
Discourse.warn_exception(e, message: "Semantic search failed")
|
||||
end
|
||||
|
||||
if semantic_results
|
||||
semantic_results = search.apply_filters(semantic_results)
|
||||
|
||||
semantic_results.each do |post|
|
||||
next if topic_ids.include?(post.topic_id)
|
||||
|
||||
topic_ids << post.topic_id
|
||||
posts << post
|
||||
|
||||
break if posts.length >= max_results
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
hidden_tags = nil
|
||||
|
||||
# Construct search_args hash for consistent return format
|
||||
search_args = {
|
||||
search_query: search_query,
|
||||
category: category,
|
||||
user: user,
|
||||
order: order,
|
||||
max_posts: max_posts,
|
||||
tags: tags,
|
||||
before: before,
|
||||
after: after,
|
||||
status: status,
|
||||
max_results: max_results,
|
||||
}.compact
|
||||
|
||||
if posts.blank?
|
||||
{ args: search_args, rows: [], instruction: "nothing was found, expand your search" }
|
||||
else
|
||||
format_results(posts, args: search_args, result_style: result_style) do |post|
|
||||
category_names = [
|
||||
post.topic.category&.parent_category&.name,
|
||||
post.topic.category&.name,
|
||||
].compact.join(" > ")
|
||||
row = {
|
||||
title: post.topic.title,
|
||||
url: Discourse.base_path + post.url,
|
||||
username: post.user&.username,
|
||||
excerpt: post.excerpt,
|
||||
created: post.created_at,
|
||||
category: category_names,
|
||||
likes: post.like_count,
|
||||
topic_views: post.topic.views,
|
||||
topic_likes: post.topic.like_count,
|
||||
topic_replies: post.topic.posts_count - 1,
|
||||
}
|
||||
|
||||
if SiteSetting.tagging_enabled
|
||||
hidden_tags ||= DiscourseTagging.hidden_tag_names
|
||||
tags = post.topic.tags.map(&:name) - hidden_tags
|
||||
row[:tags] = tags.join(", ") if tags.present?
|
||||
end
|
||||
|
||||
row
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def self.format_results(rows, args: nil, result_style:)
|
||||
rows = rows&.map { |row| yield row } if block_given?
|
||||
|
||||
if result_style == :compact
|
||||
index = -1
|
||||
column_indexes = {}
|
||||
|
||||
rows =
|
||||
rows&.map do |data|
|
||||
new_row = []
|
||||
data.each do |key, value|
|
||||
found_index = column_indexes[key.to_s] ||= (index += 1)
|
||||
new_row[found_index] = value
|
||||
end
|
||||
new_row
|
||||
end
|
||||
column_names = column_indexes.keys
|
||||
end
|
||||
|
||||
result = { column_names: column_names, rows: rows }
|
||||
result[:args] = args if args
|
||||
result
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -165,7 +165,7 @@ describe DiscourseAi::Automation::LlmPersonaTriage do
|
||||
expect(context).to include("support")
|
||||
end
|
||||
|
||||
it "passes private message metadata in context when responding to PM" do
|
||||
it "interacts correctly with PMs" do
|
||||
# Create a private message topic
|
||||
pm_topic = Fabricate(:private_message_topic, user: user, title: "Important PM")
|
||||
|
||||
@ -190,6 +190,8 @@ describe DiscourseAi::Automation::LlmPersonaTriage do
|
||||
# Capture the prompt sent to the LLM
|
||||
prompt = nil
|
||||
|
||||
original_user_ids = pm_topic.topic_allowed_users.pluck(:user_id)
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||
["I've received your private message"],
|
||||
) do |_, _, _prompts|
|
||||
@ -204,5 +206,13 @@ describe DiscourseAi::Automation::LlmPersonaTriage do
|
||||
expect(context).to include("Important PM")
|
||||
expect(context).to include(pm_post.raw)
|
||||
expect(context).to include(pm_post2.raw)
|
||||
|
||||
reply = pm_topic.posts.order(:post_number).last
|
||||
expect(reply.raw).to eq("I've received your private message")
|
||||
|
||||
topic = reply.topic
|
||||
|
||||
# should not inject persona into allowed users
|
||||
expect(topic.topic_allowed_users.pluck(:user_id).sort).to eq(original_user_ids.sort)
|
||||
end
|
||||
end
|
||||
|
@ -98,7 +98,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
||||
|
||||
results = search.invoke(&progress_blk)
|
||||
|
||||
expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake" })
|
||||
expect(results[:args]).to eq({ search_query: "ABDDCDCEDGDG", order: "fake", max_results: 60 })
|
||||
expect(results[:rows]).to eq([])
|
||||
end
|
||||
|
||||
@ -131,7 +131,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
||||
search.invoke(&progress_blk)
|
||||
end
|
||||
|
||||
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
|
||||
expect(results[:args]).to eq(
|
||||
{ max_results: 60, search_query: "hello world, sam", status: "public" },
|
||||
)
|
||||
expect(results[:rows].length).to eq(1)
|
||||
|
||||
# it also works with no query
|
||||
@ -174,6 +176,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
||||
[param[:name], "test"]
|
||||
end
|
||||
end
|
||||
.compact
|
||||
.to_h
|
||||
.symbolize_keys
|
||||
|
||||
|
198
spec/lib/utils/search_spec.rb
Normal file
198
spec/lib/utils/search_spec.rb
Normal file
@ -0,0 +1,198 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Utils::Search do
|
||||
before { SearchIndexer.enable }
|
||||
after { SearchIndexer.disable }
|
||||
|
||||
fab!(:admin)
|
||||
fab!(:user)
|
||||
fab!(:group)
|
||||
fab!(:parent_category) { Fabricate(:category, name: "animals") }
|
||||
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
|
||||
fab!(:tag_funny) { Fabricate(:tag, name: "funny") }
|
||||
fab!(:tag_sad) { Fabricate(:tag, name: "sad") }
|
||||
fab!(:tag_hidden) { Fabricate(:tag, name: "hidden") }
|
||||
fab!(:staff_tag_group) do
|
||||
tag_group = Fabricate.build(:tag_group, name: "Staff only", tag_names: ["hidden"])
|
||||
|
||||
tag_group.permissions = [
|
||||
[Group::AUTO_GROUPS[:staff], TagGroupPermission.permission_types[:full]],
|
||||
]
|
||||
tag_group.save!
|
||||
tag_group
|
||||
end
|
||||
|
||||
fab!(:topic_with_tags) do
|
||||
Fabricate(:topic, category: category, tags: [tag_funny, tag_sad, tag_hidden])
|
||||
end
|
||||
|
||||
fab!(:private_category) do
|
||||
c = Fabricate(:category_with_definition)
|
||||
c.set_permissions(group => :readonly)
|
||||
c.save
|
||||
c
|
||||
end
|
||||
|
||||
describe ".perform_search" do
|
||||
it "returns search results with correct format" do
|
||||
post = Fabricate(:post, topic: topic_with_tags)
|
||||
|
||||
results =
|
||||
described_class.perform_search(
|
||||
search_query: post.raw,
|
||||
user: post.user.username,
|
||||
current_user: admin,
|
||||
)
|
||||
|
||||
expect(results).to have_key(:args)
|
||||
expect(results).to have_key(:rows)
|
||||
expect(results).to have_key(:column_names)
|
||||
expect(results[:rows].length).to eq(1)
|
||||
end
|
||||
|
||||
it "handles no results" do
|
||||
results =
|
||||
described_class.perform_search(
|
||||
search_query: "NONEXISTENTTERMNOONEWOULDSEARCH",
|
||||
current_user: admin,
|
||||
)
|
||||
|
||||
expect(results[:rows]).to eq([])
|
||||
expect(results[:instruction]).to eq("nothing was found, expand your search")
|
||||
end
|
||||
|
||||
it "returns private results when user has access" do
|
||||
private_post = Fabricate(:post, topic: Fabricate(:topic, category: private_category))
|
||||
|
||||
# Regular user without access
|
||||
results = described_class.perform_search(search_query: private_post.raw, current_user: user)
|
||||
expect(results[:rows].length).to eq(0)
|
||||
|
||||
# Add user to group with access
|
||||
GroupUser.create!(group: group, user: user)
|
||||
|
||||
# Now should find the private post
|
||||
results = described_class.perform_search(search_query: private_post.raw, current_user: user)
|
||||
expect(results[:rows].length).to eq(1)
|
||||
end
|
||||
|
||||
it "properly handles subfolder URLs" do
|
||||
Discourse.stubs(:base_path).returns("/subfolder")
|
||||
|
||||
post = Fabricate(:post, topic: topic_with_tags)
|
||||
|
||||
results = described_class.perform_search(search_query: post.raw, current_user: admin)
|
||||
|
||||
url_index = results[:column_names].index("url")
|
||||
expect(results[:rows][0][url_index]).to include("/subfolder")
|
||||
end
|
||||
|
||||
it "returns rich topic information" do
|
||||
post = Fabricate(:post, like_count: 1, topic: topic_with_tags)
|
||||
post.topic.update!(views: 100, posts_count: 2, like_count: 10)
|
||||
|
||||
results = described_class.perform_search(search_query: post.raw, current_user: admin)
|
||||
|
||||
row = results[:rows].first
|
||||
|
||||
category_index = results[:column_names].index("category")
|
||||
expect(row[category_index]).to eq("animals > amazing-cat")
|
||||
|
||||
tags_index = results[:column_names].index("tags")
|
||||
expect(row[tags_index]).to eq("funny, sad")
|
||||
|
||||
likes_index = results[:column_names].index("likes")
|
||||
expect(row[likes_index]).to eq(1)
|
||||
|
||||
topic_likes_index = results[:column_names].index("topic_likes")
|
||||
expect(row[topic_likes_index]).to eq(10)
|
||||
|
||||
topic_views_index = results[:column_names].index("topic_views")
|
||||
expect(row[topic_views_index]).to eq(100)
|
||||
|
||||
topic_replies_index = results[:column_names].index("topic_replies")
|
||||
expect(row[topic_replies_index]).to eq(1)
|
||||
end
|
||||
|
||||
context "when using semantic search" do
|
||||
let(:query) { "this is an expanded search" }
|
||||
after do
|
||||
if defined?(DiscourseAi::Embeddings::SemanticSearch)
|
||||
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query)
|
||||
end
|
||||
end
|
||||
|
||||
it "includes semantic search results when enabled" do
|
||||
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
||||
vector_def = Fabricate(:embedding_definition)
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||
|
||||
hyde_embedding = [0.049382] * vector_def.dimensions
|
||||
EmbeddingsGenerationStubs.hugging_face_service(query, hyde_embedding)
|
||||
|
||||
post = Fabricate(:post, topic: topic_with_tags)
|
||||
DiscourseAi::Embeddings::Schema.for(Topic).store(post.topic, hyde_embedding, "digest")
|
||||
|
||||
# Using a completely different search query, should still find via semantic search
|
||||
results =
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
|
||||
described_class.perform_search(
|
||||
search_query: "totally different query",
|
||||
current_user: admin,
|
||||
)
|
||||
end
|
||||
|
||||
expect(results[:rows].length).to eq(1)
|
||||
end
|
||||
|
||||
it "can disable semantic search with hyde parameter" do
|
||||
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
||||
vector_def = Fabricate(:embedding_definition)
|
||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||
|
||||
embedding = [0.049382] * vector_def.dimensions
|
||||
EmbeddingsGenerationStubs.hugging_face_service(query, embedding)
|
||||
|
||||
post = Fabricate(:post, topic: topic_with_tags)
|
||||
DiscourseAi::Embeddings::Schema.for(Topic).store(post.topic, embedding, "digest")
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, "https://test.com/embeddings")
|
||||
.with(body: "{\"inputs\":\"totally different query\",\"truncate\":true}")
|
||||
.to_return(status: 200, body: embedding.to_json)
|
||||
|
||||
results =
|
||||
described_class.perform_search(
|
||||
search_query: "totally different query",
|
||||
hyde: false,
|
||||
current_user: admin,
|
||||
)
|
||||
|
||||
expect(results[:rows].length).to eq(0)
|
||||
end
|
||||
end
|
||||
|
||||
it "passes all search parameters to the results args" do
|
||||
post = Fabricate(:post, topic: topic_with_tags)
|
||||
|
||||
search_params = {
|
||||
search_query: post.raw,
|
||||
category: category.name,
|
||||
user: post.user.username,
|
||||
order: "latest",
|
||||
max_posts: 10,
|
||||
tags: tag_funny.name,
|
||||
before: "2030-01-01",
|
||||
after: "2000-01-01",
|
||||
status: "public",
|
||||
max_results: 15,
|
||||
}
|
||||
|
||||
results = described_class.perform_search(**search_params, current_user: admin)
|
||||
|
||||
expect(results[:args]).to include(search_params)
|
||||
end
|
||||
end
|
||||
end
|
@ -328,4 +328,37 @@ RSpec.describe AiTool do
|
||||
expect(result).to eq(expected)
|
||||
end
|
||||
end
|
||||
|
||||
context "when using the search API" do
|
||||
before { SearchIndexer.enable }
|
||||
after { SearchIndexer.disable }
|
||||
|
||||
it "can perform a discourse search" do
|
||||
# Create a new topic
|
||||
topic = Fabricate(:topic, title: "Test Search Topic", category: Fabricate(:category))
|
||||
post = Fabricate(:post, topic: topic, raw: "This is a test post content, banana")
|
||||
|
||||
# Ensure the topic is indexed
|
||||
SearchIndexer.index(topic, force: true)
|
||||
SearchIndexer.index(post, force: true)
|
||||
|
||||
# Define the tool script
|
||||
script = <<~JS
|
||||
function invoke(params) {
|
||||
return discourse.search({ search_query: params.query });
|
||||
}
|
||||
JS
|
||||
|
||||
# Create the tool and runner
|
||||
tool = create_tool(script: script)
|
||||
runner = tool.runner({ "query" => "banana" }, llm: nil, bot_user: nil, context: {})
|
||||
|
||||
# Invoke the tool and get the results
|
||||
result = runner.invoke
|
||||
|
||||
# Verify the topic is found
|
||||
expect(result["rows"].length).to be > 0
|
||||
expect(result["rows"].first["title"]).to eq("Test Search Topic")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user