FEATURE: add semantic search with hyde bot (#210)

In specific scenarios (no special filters or limits) we will also
always include at least 5 semantic results with every query.

This effectively means that all very wide queries will always return
20 results, regardless of how complex they are.

Also: 

FIX: embedding backfill rake task not working
We renamed internals; this corrects the implementation.
This commit is contained in:
Sam 2023-09-07 13:25:26 +10:00 committed by GitHub
parent abe96d5533
commit 615eb8b440
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 6 deletions

View File

@ -15,7 +15,8 @@ module DiscourseAi::AiBot::Commands
[
Parameter.new(
name: "search_query",
description: "Search query (correct bad spelling, remove connector words!)",
description:
"Specific keywords to search for, space seperated (correct bad spelling, remove connector words)",
type: "string",
),
Parameter.new(
@ -93,6 +94,9 @@ module DiscourseAi::AiBot::Commands
}
end
MAX_RESULTS = 20
MIN_SEMANTIC_RESULTS = 5
def process(**search_args)
limit = nil
@ -120,12 +124,35 @@ module DiscourseAi::AiBot::Commands
)
# let's be frugal with tokens, 50 results is too much and stuff gets cut off
limit ||= 20
limit = 20 if limit > 20
limit ||= MAX_RESULTS
limit = MAX_RESULTS if limit > MAX_RESULTS
should_try_semantic_search = SiteSetting.ai_embeddings_semantic_search_enabled
should_try_semantic_search &&= (limit == MAX_RESULTS)
should_try_semantic_search &&= (search_args.keys - %i[search_query order]).length == 0
should_try_semantic_search &&= (search_args[:search_query].present?)
limit = limit - MIN_SEMANTIC_RESULTS if should_try_semantic_search
posts = results&.posts || []
posts = posts[0..limit - 1]
if should_try_semantic_search
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
topic_ids = Set.new(posts.map(&:topic_id))
semantic_search
.search_for_topics(search_args[:search_query])
.each do |post|
next if topic_ids.include?(post.topic_id)
topic_ids << post.topic_id
posts << post
break if posts.length >= MAX_RESULTS
end
end
@last_num_results = posts.length
if posts.blank?

View File

@ -3,11 +3,9 @@
desc "Backfill embeddings for all topics"
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
public_categories = Category.where(read_restricted: false).pluck(:id)
manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
table_name = vector_rep.table_name
Topic

View File

@ -1,6 +1,7 @@
#frozen_string_literal: true
require_relative "../../../../support/openai_completions_inference_stubs"
require_relative "../../../../support/embeddings_generation_stubs"
RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
@ -19,6 +20,43 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
expect(results[:rows]).to eq([])
end
# Covers the semantic-search path added to the search command: when
# ai_embeddings_semantic_search_enabled is on, results from the embeddings
# similarity search are merged into the keyword results.
describe "semantic search" do
let (:query) {
"this is an expanded search"
}
# Clear the cached query expansion so this example cannot leak state into
# other examples that reuse the same query text.
after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
it "supports semantic search when enabled" do
SiteSetting.ai_embeddings_semantic_search_enabled = true
# Point embedding generation at a stubbed Discourse embeddings service.
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
# Stub the OpenAI chat completion — presumably the HyDE step that expands
# the raw search terms into `query` before embedding (TODO confirm against
# SemanticSearch internals).
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: JSON.dump(OpenAiCompletionsInferenceStubs.response(query)),
)
# Canned embedding vector returned for the expanded query.
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
query,
hyde_embedding,
)
post1 = Fabricate(:post)
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
# Force the similarity search to surface post1's topic so the semantic
# branch has exactly one candidate to merge in.
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
.any_instance
.expects(:asymmetric_topics_similarity_search)
.returns([post1.topic_id])
results = search.process(search_query: "hello world, sam")
expect(results[:args]).to eq({ search_query: "hello world, sam" })
# NOTE(review): exactly one row expected — looks like the semantic result
# for post1's topic is deduplicated against/merged with the keyword
# results; verify against SearchCommand#process.
expect(results[:rows].length).to eq(1)
end
end
it "supports subfolder properly" do
Discourse.stubs(:base_path).returns("/subfolder")