From 615eb8b440f1edec6f52d69e137213cc03f5b1c2 Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 7 Sep 2023 13:25:26 +1000
Subject: [PATCH] FEATURE: add semantic search with hyde bot (#210)

In specific scenarios (no special filters or limits) we now always include
at least 5 semantic results with every query. This effectively means that
very wide queries will always return 20 results, regardless of how complex
they are.

Also:

FIX: embedding backfill rake task not working

We renamed internals; this corrects the implementation.
---
 lib/modules/ai_bot/commands/search_command.rb | 33 ++++++++++++++--
 lib/tasks/modules/embeddings/database.rake    |  4 +-
 .../ai_bot/commands/search_command_spec.rb    | 38 +++++++++++++++++++
 3 files changed, 69 insertions(+), 6 deletions(-)

diff --git a/lib/modules/ai_bot/commands/search_command.rb b/lib/modules/ai_bot/commands/search_command.rb
index 330a7d97..4df65b87 100644
--- a/lib/modules/ai_bot/commands/search_command.rb
+++ b/lib/modules/ai_bot/commands/search_command.rb
@@ -15,7 +15,8 @@ module DiscourseAi::AiBot::Commands
       [
         Parameter.new(
           name: "search_query",
-          description: "Search query (correct bad spelling, remove connector words!)",
+          description:
+            "Specific keywords to search for, space separated (correct bad spelling, remove connector words)",
           type: "string",
         ),
         Parameter.new(
@@ -93,6 +94,9 @@ module DiscourseAi::AiBot::Commands
       }
     end

+    MAX_RESULTS = 20
+    MIN_SEMANTIC_RESULTS = 5
+
     def process(**search_args)
       limit = nil

@@ -120,12 +124,35 @@ module DiscourseAi::AiBot::Commands
       )

       # let's be frugal with tokens, 50 results is too much and stuff gets cut off
-      limit ||= 20
-      limit = 20 if limit > 20
+      limit ||= MAX_RESULTS
+      limit = MAX_RESULTS if limit > MAX_RESULTS
+
+      should_try_semantic_search = SiteSetting.ai_embeddings_semantic_search_enabled
+      should_try_semantic_search &&= (limit == MAX_RESULTS)
+      should_try_semantic_search &&= (search_args.keys - %i[search_query order]).length == 0
+      should_try_semantic_search &&= (search_args[:search_query].present?)
+
+      limit = limit - MIN_SEMANTIC_RESULTS if should_try_semantic_search

       posts = results&.posts || []
       posts = posts[0..limit - 1]

+      if should_try_semantic_search
+        semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(Guardian.new())
+        topic_ids = Set.new(posts.map(&:topic_id))
+
+        semantic_search
+          .search_for_topics(search_args[:search_query])
+          .each do |post|
+            next if topic_ids.include?(post.topic_id)
+
+            topic_ids << post.topic_id
+            posts << post
+
+            break if posts.length >= MAX_RESULTS
+          end
+      end
+
       @last_num_results = posts.length

       if posts.blank?
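For illustration, the merge behaviour added to process can be reproduced
standalone: keyword results keep their priority, semantic results are
appended only for topics not already present, and the combined list is
capped at MAX_RESULTS. In the sketch below (not part of the patch),
PostStub and the sample data are hypothetical stand-ins for real Post
records:

    # Standalone sketch of the dedup-and-append merge used above.
    require "set"

    MAX_RESULTS = 20

    PostStub = Struct.new(:id, :topic_id) # hypothetical stand-in for Post

    keyword_posts = [PostStub.new(1, 100), PostStub.new(2, 101)]
    semantic_posts = [PostStub.new(3, 100), PostStub.new(4, 102)] # topic 100 duplicates a keyword hit

    posts = keyword_posts.dup
    topic_ids = Set.new(posts.map(&:topic_id))

    semantic_posts.each do |post|
      next if topic_ids.include?(post.topic_id) # keyword hits win on ties

      topic_ids << post.topic_id
      posts << post

      break if posts.length >= MAX_RESULTS
    end

    posts.map(&:topic_id) # => [100, 101, 102]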
diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake
index 4ee260cb..875682e7 100644
--- a/lib/tasks/modules/embeddings/database.rake
+++ b/lib/tasks/modules/embeddings/database.rake
@@ -3,11 +3,9 @@
 desc "Backfill embeddings for all topics"
 task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
   public_categories = Category.where(read_restricted: false).pluck(:id)

-  manager = DiscourseAi::Embeddings::Manager.new(Topic.first)
   strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
-  vector_rep =
-    DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
+  vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
   table_name = vector_rep.table_name

   Topic
diff --git a/spec/lib/modules/ai_bot/commands/search_command_spec.rb b/spec/lib/modules/ai_bot/commands/search_command_spec.rb
index 9b4333ab..4071c0b0 100644
--- a/spec/lib/modules/ai_bot/commands/search_command_spec.rb
+++ b/spec/lib/modules/ai_bot/commands/search_command_spec.rb
@@ -1,6 +1,7 @@
 #frozen_string_literal: true

 require_relative "../../../../support/openai_completions_inference_stubs"
+require_relative "../../../../support/embeddings_generation_stubs"

 RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
@@ -19,6 +20,43 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
     expect(results[:rows]).to eq([])
   end

+  describe "semantic search" do
+    let (:query) {
+      "this is an expanded search"
+    }
+    after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
+
+    it "supports semantic search when enabled" do
+      SiteSetting.ai_embeddings_semantic_search_enabled = true
+      SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+
+      WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
+        status: 200,
+        body: JSON.dump(OpenAiCompletionsInferenceStubs.response(query)),
+      )
+
+      hyde_embedding = [0.049382, 0.9999]
+      EmbeddingsGenerationStubs.discourse_service(
+        SiteSetting.ai_embeddings_model,
+        query,
+        hyde_embedding,
+      )
+
+      post1 = Fabricate(:post)
+      search = described_class.new(bot_user: bot_user, post: post1, args: nil)
+
+      DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2
+        .any_instance
+        .expects(:asymmetric_topics_similarity_search)
+        .returns([post1.topic_id])
+
+      results = search.process(search_query: "hello world, sam")
+
+      expect(results[:args]).to eq({ search_query: "hello world, sam" })
+      expect(results[:rows].length).to eq(1)
+    end
+  end
+
   it "supports subfolder properly" do
     Discourse.stubs(:base_path).returns("/subfolder")
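To verify the backfill fix locally, the task can be driven from a console.
A hedged sketch follows: the task name and the optional :start_topic
argument come from the signature in database.rake above, while the
resume-from-this-topic-id semantics of that argument are an assumption, as
is the exact way tasks get loaded in your session:

    # Sketch: invoking the corrected backfill task from a Rails console.
    # Assumes rake tasks are loaded, e.g. via Discourse::Application.load_tasks.
    Rake::Task["ai:embeddings:backfill"].invoke

    # Rake tasks run once per process; re-enable before invoking again,
    # optionally passing :start_topic (assumed: topic id to resume from).
    Rake::Task["ai:embeddings:backfill"].reenable
    Rake::Task["ai:embeddings:backfill"].invoke(123)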