FEATURE: add semantic search with hyde bot (#210)
When a query uses no special filters or limits, we now also blend in at least 5 semantic results. In effect, every wide query returns 20 results, regardless of how specific its keywords are.

Also: FIX: the embedding backfill rake task was broken. We renamed embeddings internals; this updates the rake task to match.
parent abe96d5533
commit 615eb8b440
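In effect, the result budget works like this (a sketch restating the arithmetic in the diff below, not additional behavior):

    # 20 result slots total; when a query is "wide" (no filters, no explicit
    # limit), 5 of them are reserved for semantic results.
    MAX_RESULTS = 20
    MIN_SEMANTIC_RESULTS = 5

    keyword_budget = MAX_RESULTS - MIN_SEMANTIC_RESULTS # => 15
    # Keyword search fills at most 15 slots; semantic hits from topics not
    # already present then top the list back up to 20.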
@@ -15,7 +15,8 @@ module DiscourseAi::AiBot::Commands
       [
         Parameter.new(
           name: "search_query",
-          description: "Search query (correct bad spelling, remove connector words!)",
+          description:
+            "Specific keywords to search for, space separated (correct bad spelling, remove connector words)",
           type: "string",
         ),
         Parameter.new(
@@ -93,6 +94,9 @@ module DiscourseAi::AiBot::Commands
         }
       end

+      MAX_RESULTS = 20
+      MIN_SEMANTIC_RESULTS = 5
+
       def process(**search_args)
         limit = nil
@@ -120,12 +124,35 @@ module DiscourseAi::AiBot::Commands
         )

         # let's be frugal with tokens, 50 results is too much and stuff gets cut off
-        limit ||= 20
-        limit = 20 if limit > 20
+        limit ||= MAX_RESULTS
+        limit = MAX_RESULTS if limit > MAX_RESULTS
+
+        should_try_semantic_search = SiteSetting.ai_embeddings_semantic_search_enabled
+        should_try_semantic_search &&= (limit == MAX_RESULTS)
+        should_try_semantic_search &&= (search_args.keys - %i[search_query order]).length == 0
+        should_try_semantic_search &&= (search_args[:search_query].present?)
+
+        limit = limit - MIN_SEMANTIC_RESULTS if should_try_semantic_search

         posts = results&.posts || []
         posts = posts[0..limit - 1]

+        if should_try_semantic_search
+          semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
+          topic_ids = Set.new(posts.map(&:topic_id))
+
+          semantic_search
+            .search_for_topics(search_args[:search_query])
+            .each do |post|
+              next if topic_ids.include?(post.topic_id)
+
+              topic_ids << post.topic_id
+              posts << post
+
+              break if posts.length >= MAX_RESULTS
+            end
+        end
+
         @last_num_results = posts.length

         if posts.blank?
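Reading the gate above: semantic blending only fires for wide queries. A few hypothetical calls to illustrate (any argument key other than search_query/order closes the gate; the extra keys shown here are assumptions for illustration, not from this diff):

    search.process(search_query: "deploy")                  # wide: up to 15 keyword + 5 semantic
    search.process(search_query: "deploy", order: "latest") # still wide: order is allowed
    search.process(search_query: "deploy", limit: 5)        # extra key, limit < 20: keyword only
    search.process(order: "latest")                         # no search_query: keyword only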
@@ -3,11 +3,9 @@
 desc "Backfill embeddings for all topics"
 task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
   public_categories = Category.where(read_restricted: false).pluck(:id)
-  manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
   strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
-  vector_rep =
-    DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
+  vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
   table_name = vector_rep.table_name

   Topic
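The rename is the whole fix, and the unused Manager instance is dropped. Any call site still on the old API migrates mechanically (a sketch; only this rake task needed the change in this commit):

    # before (removed): Base.find_vector_representation.new(strategy)
    # after (this commit):
    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)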
@@ -1,6 +1,7 @@
 #frozen_string_literal: true

 require_relative "../../../../support/openai_completions_inference_stubs"
+require_relative "../../../../support/embeddings_generation_stubs"

 RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
@@ -19,6 +20,43 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
       expect(results[:rows]).to eq([])
     end

+    describe "semantic search" do
+      let(:query) { "this is an expanded search" }
+      after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
+
+      it "supports semantic search when enabled" do
+        SiteSetting.ai_embeddings_semantic_search_enabled = true
+        SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+
+        WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
+          status: 200,
+          body: JSON.dump(OpenAiCompletionsInferenceStubs.response(query)),
+        )
+
+        hyde_embedding = [0.049382, 0.9999]
+        EmbeddingsGenerationStubs.discourse_service(
+          SiteSetting.ai_embeddings_model,
+          query,
+          hyde_embedding,
+        )
+
+        post1 = Fabricate(:post)
+        search = described_class.new(bot_user: bot_user, post: post1, args: nil)
+
+        DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
+          .any_instance
+          .expects(:asymmetric_topics_similarity_search)
+          .returns([post1.topic_id])
+
+        results = search.process(search_query: "hello world, sam")
+
+        expect(results[:args]).to eq({ search_query: "hello world, sam" })
+        expect(results[:rows].length).to eq(1)
+      end
+    end
+
     it "supports subfolder properly" do
       Discourse.stubs(:base_path).returns("/subfolder")
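For context, the spec stubs each stage of the HyDE (Hypothetical Document Embeddings) pipeline: the LLM expands the terse query into a fuller hypothetical post, that text is embedded, and the embedding drives the topic similarity search. In outline (assumed helper names, not the real SemanticSearch internals):

    # 1. LLM expands the query (the stubbed OpenAI completion above)
    hypothetical_post = llm_completion("hello world, sam")  # => "this is an expanded search"
    # 2. the expansion is embedded (the stubbed discourse service call)
    hyde_embedding = embed(hypothetical_post)               # => [0.049382, 0.9999]
    # 3. nearest-neighbour search over topic embeddings (mocked above)
    topic_ids = asymmetric_topics_similarity_search(hyde_embedding)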