diff --git a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
index b5ee5fad..753ec18f 100644
--- a/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
+++ b/app/controllers/discourse_ai/embeddings/embeddings_controller.rb
@@ -9,6 +9,7 @@ module DiscourseAi
 
       def search
         query = params[:q].to_s
+        skip_hyde = params[:hyde].to_s.downcase == "false" || params[:hyde].to_s == "0"
 
         if query.length < SiteSetting.min_search_term_length
           raise Discourse::InvalidParameters.new(:q)
@@ -31,7 +32,7 @@ module DiscourseAi
         hijack do
           semantic_search
-            .search_for_topics(query)
+            .search_for_topics(query, _page = 1, hyde: !skip_hyde)
             .each { |topic_post| grouped_results.add(topic_post) }
 
           render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
 
@@ -39,6 +40,9 @@ module DiscourseAi
       end
 
       def quick_search
+        # this search function searches posts (vs: topics)
+        # it requires post embeddings and a reranker
+        # it will not perform a hyde expansion
         query = params[:q].to_s
 
         if query.length < SiteSetting.min_search_term_length
diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb
index 967edf95..779f644a 100644
--- a/lib/embeddings/semantic_search.rb
+++ b/lib/embeddings/semantic_search.rb
@@ -11,6 +11,7 @@ module DiscourseAi
 
         Discourse.cache.delete(hyde_key)
         Discourse.cache.delete("#{hyde_key}-#{SiteSetting.ai_embeddings_model}")
+        Discourse.cache.delete(build_embedding_key(digest, "", SiteSetting.ai_embeddings_model))
       end
 
       def initialize(guardian)
@@ -29,19 +30,14 @@ module DiscourseAi
 
         Discourse.cache.read(embedding_key).present?
       end
 
-      def search_for_topics(query, page = 1)
-        max_results_per_page = 100
-        limit = [Search.per_filter, max_results_per_page].min + 1
-        offset = (page - 1) * limit
-        search = Search.new(query, { guardian: guardian })
-        search_term = search.term
-
-        return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
-
-        strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
-        vector_rep =
-          DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
-
+      def vector_rep
+        @vector_rep ||=
+          DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
+            DiscourseAi::Embeddings::Strategies::Truncation.new,
+          )
+      end
+
+      def hyde_embedding(search_term)
         digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
         hyde_key = build_hyde_key(digest, SiteSetting.ai_embeddings_semantic_search_hyde_model)
@@ -57,14 +53,34 @@ module DiscourseAi
             .cache
             .fetch(hyde_key, expires_in: 1.week) { hypothetical_post_from(search_term) }
 
-        hypothetical_post_embedding =
-          Discourse
-            .cache
-            .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
+        Discourse
+          .cache
+          .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
+      end
+
+      def embedding(search_term)
+        digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
+        embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)
+
+        Discourse
+          .cache
+          .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
+      end
+
+      def search_for_topics(query, page = 1, hyde: true)
+        max_results_per_page = 100
+        limit = [Search.per_filter, max_results_per_page].min + 1
+        offset = (page - 1) * limit
+        search = Search.new(query, { guardian: guardian })
+        search_term = search.term
+
+        return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
+
+        search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term)
 
         candidate_topic_ids =
           vector_rep.asymmetric_topics_similarity_search(
-            hypothetical_post_embedding,
+            search_embedding,
             limit: limit,
             offset: offset,
           )
diff --git a/spec/requests/embeddings/embeddings_controller_spec.rb b/spec/requests/embeddings/embeddings_controller_spec.rb
new file mode 100644
index 00000000..a62b9e01
--- /dev/null
+++ b/spec/requests/embeddings/embeddings_controller_spec.rb
@@ -0,0 +1,68 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Embeddings::EmbeddingsController do
+  context "when performing a topic search" do
+    before do
+      SiteSetting.min_search_term_length = 3
+      SiteSetting.ai_embeddings_model = "text-embedding-3-small"
+      DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
+      SearchIndexer.enable
+    end
+
+    fab!(:category)
+    fab!(:subcategory) { Fabricate(:category, parent_category_id: category.id) }
+
+    fab!(:topic)
+    fab!(:post) { Fabricate(:post, topic: topic) }
+
+    fab!(:topic_in_subcategory) { Fabricate(:topic, category: subcategory) }
+    fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
+
+    def index(topic)
+      strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+      vector_rep =
+        DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
+      stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
+        status: 200,
+        body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
+      )
+
+      vector_rep.generate_representation_from(topic)
+    end
+
+    def stub_embedding(query)
+      embedding = [0.049382] * 1536
+      EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding)
+    end
+
+    it "returns results correctly when performing a non Hyde search" do
+      index(topic)
+      index(topic_in_subcategory)
+
+      query = "test"
+      stub_embedding(query)
+
+      get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"
+
+      expect(response.status).to eq(200)
+      expect(response.parsed_body["topics"].map { |t| t["id"] }).to contain_exactly(
+        topic.id,
+        topic_in_subcategory.id,
+      )
+    end
+
+    it "is able to filter to a specific category (including sub categories)" do
+      index(topic)
+      index(topic_in_subcategory)
+
+      query = "test category:#{category.slug}"
+      stub_embedding("test")
+
+      get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"
+
+      expect(response.status).to eq(200)
+      expect(response.parsed_body["topics"].map { |t| t["id"] }).to eq([topic_in_subcategory.id])
+    end
+  end
+end