From 0687ec75c3ff135b960af236ef1d37a3779a843d Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 28 Aug 2024 14:17:34 +1000 Subject: [PATCH] FEATURE: allow embedding based search without hyde (#777) This allows callers of embedding based search to bypass hyde. Hyde will expand the search term using an LLM, but if an LLM is performing the search we can skip this expansion. It also introduces some tests for the controller which we did not have --- .../embeddings/embeddings_controller.rb | 6 +- lib/embeddings/semantic_search.rb | 50 +++++++++----- .../embeddings/embeddings_controller_spec.rb | 68 +++++++++++++++++++ 3 files changed, 106 insertions(+), 18 deletions(-) create mode 100644 spec/requests/embeddings/embeddings_controller_spec.rb diff --git a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb index b5ee5fad..753ec18f 100644 --- a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb +++ b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb @@ -9,6 +9,7 @@ module DiscourseAi def search query = params[:q].to_s + skip_hyde = params[:hyde].to_s.downcase == "false" || params[:hyde].to_s == "0" if query.length < SiteSetting.min_search_term_length raise Discourse::InvalidParameters.new(:q) @@ -31,7 +32,7 @@ module DiscourseAi hijack do semantic_search - .search_for_topics(query) + .search_for_topics(query, _page = 1, hyde: !skip_hyde) .each { |topic_post| grouped_results.add(topic_post) } render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results) @@ -39,6 +40,9 @@ module DiscourseAi end def quick_search + # this search function searches posts (vs: topics) + # it requires post embeddings and a reranker + # it will not perform a hyde expansion query = params[:q].to_s if query.length < SiteSetting.min_search_term_length diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb index 967edf95..779f644a 100644 ---
a/lib/embeddings/semantic_search.rb +++ b/lib/embeddings/semantic_search.rb @@ -11,6 +11,7 @@ module DiscourseAi Discourse.cache.delete(hyde_key) Discourse.cache.delete("#{hyde_key}-#{SiteSetting.ai_embeddings_model}") + Discourse.cache.delete(build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)) end def initialize(guardian) @@ -29,19 +30,14 @@ module DiscourseAi Discourse.cache.read(embedding_key).present? end - def search_for_topics(query, page = 1) - max_results_per_page = 100 - limit = [Search.per_filter, max_results_per_page].min + 1 - offset = (page - 1) * limit - search = Search.new(query, { guardian: guardian }) - search_term = search.term - - return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length - - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + def vector_rep + @vector_rep ||= + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation( + DiscourseAi::Embeddings::Strategies::Truncation.new, + ) + end + def hyde_embedding(search_term) digest = OpenSSL::Digest::SHA1.hexdigest(search_term) hyde_key = build_hyde_key(digest, SiteSetting.ai_embeddings_semantic_search_hyde_model) @@ -57,14 +53,34 @@ module DiscourseAi .cache .fetch(hyde_key, expires_in: 1.week) { hypothetical_post_from(search_term) } - hypothetical_post_embedding = - Discourse - .cache - .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) } + Discourse + .cache + .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) } + end + + def embedding(search_term) + digest = OpenSSL::Digest::SHA1.hexdigest(search_term) + embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model) + + Discourse + .cache + .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) } + end + + def search_for_topics(query, page = 1, hyde: true) + 
max_results_per_page = 100 + limit = [Search.per_filter, max_results_per_page].min + 1 + offset = (page - 1) * limit + search = Search.new(query, { guardian: guardian }) + search_term = search.term + + return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length + + search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term) candidate_topic_ids = vector_rep.asymmetric_topics_similarity_search( - hypothetical_post_embedding, + search_embedding, limit: limit, offset: offset, ) diff --git a/spec/requests/embeddings/embeddings_controller_spec.rb b/spec/requests/embeddings/embeddings_controller_spec.rb new file mode 100644 index 00000000..a62b9e01 --- /dev/null +++ b/spec/requests/embeddings/embeddings_controller_spec.rb @@ -0,0 +1,68 @@ +# frozen_string_literal: true + +describe DiscourseAi::Embeddings::EmbeddingsController do + context "when performing a topic search" do + before do + SiteSetting.min_search_term_length = 3 + SiteSetting.ai_embeddings_model = "text-embedding-3-small" + DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test") + SearchIndexer.enable + end + + fab!(:category) + fab!(:subcategory) { Fabricate(:category, parent_category_id: category.id) } + + fab!(:topic) + fab!(:post) { Fabricate(:post, topic: topic) } + + fab!(:topic_in_subcategory) { Fabricate(:topic, category: subcategory) } + fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) } + + def index(topic) + strategy = DiscourseAi::Embeddings::Strategies::Truncation.new + vector_rep = + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + + stub_request(:post, "https://api.openai.com/v1/embeddings").to_return( + status: 200, + body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }), + ) + + vector_rep.generate_representation_from(topic) + end + + def stub_embedding(query) + embedding = [0.049382] * 1536 + EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, 
query, embedding) + end + + it "returns results correctly when performing a non Hyde search" do + index(topic) + index(topic_in_subcategory) + + query = "test" + stub_embedding(query) + + get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false" + + expect(response.status).to eq(200) + expect(response.parsed_body["topics"].map { |t| t["id"] }).to contain_exactly( + topic.id, + topic_in_subcategory.id, + ) + end + + it "is able to filter to a specific category (including sub categories)" do + index(topic) + index(topic_in_subcategory) + + query = "test category:#{category.slug}" + stub_embedding("test") + + get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false" + + expect(response.status).to eq(200) + expect(response.parsed_body["topics"].map { |t| t["id"] }).to eq([topic_in_subcategory.id]) + end + end +end