FEATURE: allow embedding based search without hyde (#777)

This allows callers of embedding based search to bypass hyde.

Hyde will expand the search term using an LLM, but if an LLM is
performing the search we can skip this expansion.

It also introduced some tests for the controller which we did not have
This commit is contained in:
Sam 2024-08-28 14:17:34 +10:00 committed by GitHub
parent 72607c3560
commit 0687ec75c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 106 additions and 18 deletions

View File

@ -9,6 +9,7 @@ module DiscourseAi
def search def search
query = params[:q].to_s query = params[:q].to_s
skip_hyde = params[:hyde].downcase.to_s == "false" || params[:hyde].to_s == "0"
if query.length < SiteSetting.min_search_term_length if query.length < SiteSetting.min_search_term_length
raise Discourse::InvalidParameters.new(:q) raise Discourse::InvalidParameters.new(:q)
@ -31,7 +32,7 @@ module DiscourseAi
hijack do hijack do
semantic_search semantic_search
.search_for_topics(query) .search_for_topics(query, _page = 1, hyde: !skip_hyde)
.each { |topic_post| grouped_results.add(topic_post) } .each { |topic_post| grouped_results.add(topic_post) }
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results) render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
@ -39,6 +40,9 @@ module DiscourseAi
end end
def quick_search def quick_search
# this search function searches posts (vs: topics)
# it requires post embeddings and a reranker
# it will not perform a hyde expantion
query = params[:q].to_s query = params[:q].to_s
if query.length < SiteSetting.min_search_term_length if query.length < SiteSetting.min_search_term_length

View File

@ -11,6 +11,7 @@ module DiscourseAi
Discourse.cache.delete(hyde_key) Discourse.cache.delete(hyde_key)
Discourse.cache.delete("#{hyde_key}-#{SiteSetting.ai_embeddings_model}") Discourse.cache.delete("#{hyde_key}-#{SiteSetting.ai_embeddings_model}")
Discourse.cache.delete("-#{SiteSetting.ai_embeddings_model}")
end end
def initialize(guardian) def initialize(guardian)
@ -29,19 +30,14 @@ module DiscourseAi
Discourse.cache.read(embedding_key).present? Discourse.cache.read(embedding_key).present?
end end
def search_for_topics(query, page = 1) def vector_rep
max_results_per_page = 100 @vector_rep ||=
limit = [Search.per_filter, max_results_per_page].min + 1 DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
offset = (page - 1) * limit DiscourseAi::Embeddings::Strategies::Truncation.new,
search = Search.new(query, { guardian: guardian }) )
search_term = search.term end
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
def hyde_embedding(search_term)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term) digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
hyde_key = build_hyde_key(digest, SiteSetting.ai_embeddings_semantic_search_hyde_model) hyde_key = build_hyde_key(digest, SiteSetting.ai_embeddings_semantic_search_hyde_model)
@ -57,14 +53,34 @@ module DiscourseAi
.cache .cache
.fetch(hyde_key, expires_in: 1.week) { hypothetical_post_from(search_term) } .fetch(hyde_key, expires_in: 1.week) { hypothetical_post_from(search_term) }
hypothetical_post_embedding = Discourse
Discourse .cache
.cache .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) } end
def embedding(search_term)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)
Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
end
def search_for_topics(query, page = 1, hyde: true)
max_results_per_page = 100
limit = [Search.per_filter, max_results_per_page].min + 1
offset = (page - 1) * limit
search = Search.new(query, { guardian: guardian })
search_term = search.term
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term)
candidate_topic_ids = candidate_topic_ids =
vector_rep.asymmetric_topics_similarity_search( vector_rep.asymmetric_topics_similarity_search(
hypothetical_post_embedding, search_embedding,
limit: limit, limit: limit,
offset: offset, offset: offset,
) )

View File

@ -0,0 +1,68 @@
# frozen_string_literal: true
describe DiscourseAi::Embeddings::EmbeddingsController do
context "when performing a topic search" do
before do
SiteSetting.min_search_term_length = 3
SiteSetting.ai_embeddings_model = "text-embedding-3-small"
DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
SearchIndexer.enable
end
fab!(:category)
fab!(:subcategory) { Fabricate(:category, parent_category_id: category.id) }
fab!(:topic)
fab!(:post) { Fabricate(:post, topic: topic) }
fab!(:topic_in_subcategory) { Fabricate(:topic, category: subcategory) }
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
def index(topic)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
status: 200,
body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
)
vector_rep.generate_representation_from(topic)
end
def stub_embedding(query)
embedding = [0.049382] * 1536
EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding)
end
it "returns results correctly when performing a non Hyde search" do
index(topic)
index(topic_in_subcategory)
query = "test"
stub_embedding(query)
get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"
expect(response.status).to eq(200)
expect(response.parsed_body["topics"].map { |t| t["id"] }).to contain_exactly(
topic.id,
topic_in_subcategory.id,
)
end
it "is able to filter to a specific category (including sub categories)" do
index(topic)
index(topic_in_subcategory)
query = "test category:#{category.slug}"
stub_embedding("test")
get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"
expect(response.status).to eq(200)
expect(response.parsed_body["topics"].map { |t| t["id"] }).to eq([topic_in_subcategory.id])
end
end
end