FEATURE: allow embedding based search without hyde (#777)
This allows callers of embedding based search to bypass Hyde. Hyde will expand the search term using an LLM, but if an LLM is performing the search we can skip this expansion. It also introduces some tests for the controller, which we did not have before.
This commit is contained in:
parent
72607c3560
commit
0687ec75c3
|
@ -9,6 +9,7 @@ module DiscourseAi
|
|||
|
||||
def search
|
||||
query = params[:q].to_s
|
||||
skip_hyde = params[:hyde].downcase.to_s == "false" || params[:hyde].to_s == "0"
|
||||
|
||||
if query.length < SiteSetting.min_search_term_length
|
||||
raise Discourse::InvalidParameters.new(:q)
|
||||
|
@ -31,7 +32,7 @@ module DiscourseAi
|
|||
|
||||
hijack do
|
||||
semantic_search
|
||||
.search_for_topics(query)
|
||||
.search_for_topics(query, _page = 1, hyde: !skip_hyde)
|
||||
.each { |topic_post| grouped_results.add(topic_post) }
|
||||
|
||||
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
|
||||
|
@ -39,6 +40,9 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def quick_search
|
||||
# this search function searches posts (vs: topics)
|
||||
# it requires post embeddings and a reranker
|
||||
# it will not perform a hyde expansion
|
||||
query = params[:q].to_s
|
||||
|
||||
if query.length < SiteSetting.min_search_term_length
|
||||
|
|
|
@ -11,6 +11,7 @@ module DiscourseAi
|
|||
|
||||
Discourse.cache.delete(hyde_key)
|
||||
Discourse.cache.delete("#{hyde_key}-#{SiteSetting.ai_embeddings_model}")
|
||||
Discourse.cache.delete("-#{SiteSetting.ai_embeddings_model}")
|
||||
end
|
||||
|
||||
def initialize(guardian)
|
||||
|
@ -29,19 +30,14 @@ module DiscourseAi
|
|||
Discourse.cache.read(embedding_key).present?
|
||||
end
|
||||
|
||||
def search_for_topics(query, page = 1)
|
||||
max_results_per_page = 100
|
||||
limit = [Search.per_filter, max_results_per_page].min + 1
|
||||
offset = (page - 1) * limit
|
||||
search = Search.new(query, { guardian: guardian })
|
||||
search_term = search.term
|
||||
|
||||
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
# Lazily builds (and memoizes) the current vector representation,
# configured with the truncation strategy.
def vector_rep
  return @vector_rep if @vector_rep

  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
  @vector_rep =
    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
end
|
||||
|
||||
def hyde_embedding(search_term)
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
|
||||
hyde_key = build_hyde_key(digest, SiteSetting.ai_embeddings_semantic_search_hyde_model)
|
||||
|
||||
|
@ -57,14 +53,34 @@ module DiscourseAi
|
|||
.cache
|
||||
.fetch(hyde_key, expires_in: 1.week) { hypothetical_post_from(search_term) }
|
||||
|
||||
hypothetical_post_embedding =
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
|
||||
end
|
||||
|
||||
# Returns the embedding vector for the raw search term (no Hyde expansion),
# cached for a week keyed on a SHA1 digest of the term and the current
# embeddings model.
def embedding(search_term)
  cache_key =
    build_embedding_key(
      OpenSSL::Digest::SHA1.hexdigest(search_term),
      "",
      SiteSetting.ai_embeddings_model,
    )

  Discourse.cache.fetch(cache_key, expires_in: 1.week) do
    vector_rep.vector_from(search_term)
  end
end
|
||||
|
||||
def search_for_topics(query, page = 1, hyde: true)
|
||||
max_results_per_page = 100
|
||||
limit = [Search.per_filter, max_results_per_page].min + 1
|
||||
offset = (page - 1) * limit
|
||||
search = Search.new(query, { guardian: guardian })
|
||||
search_term = search.term
|
||||
|
||||
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
|
||||
|
||||
search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term)
|
||||
|
||||
candidate_topic_ids =
|
||||
vector_rep.asymmetric_topics_similarity_search(
|
||||
hypothetical_post_embedding,
|
||||
search_embedding,
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
# frozen_string_literal: true

# Request specs for the embeddings semantic-search endpoint, covering the
# non-Hyde path (?hyde=false) added so LLM-driven callers can skip term
# expansion.
describe DiscourseAi::Embeddings::EmbeddingsController do
  context "when performing a topic search" do
    before do
      SiteSetting.min_search_term_length = 3
      SiteSetting.ai_embeddings_model = "text-embedding-3-small"
      # Clear any embedding cached for the "test" query by a previous run,
      # so the stubbed embedding below is actually requested.
      DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test")
      SearchIndexer.enable
    end

    # Category tree used to verify that category filters include subcategories.
    fab!(:category)
    fab!(:subcategory) { Fabricate(:category, parent_category_id: category.id) }

    # Topic/post outside the category filter.
    fab!(:topic)
    fab!(:post) { Fabricate(:post, topic: topic) }

    # Topic/post inside the subcategory, expected to match the category filter.
    fab!(:topic_in_subcategory) { Fabricate(:topic, category: subcategory) }
    fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }

    # Generates and stores an embedding for the given topic, stubbing the
    # OpenAI embeddings endpoint with a fixed 1536-dim vector.
    def index(topic)
      strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
      vector_rep =
        DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)

      stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
        status: 200,
        body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
      )

      vector_rep.generate_representation_from(topic)
    end

    # Stubs the embedding generated for the incoming search query itself.
    def stub_embedding(query)
      embedding = [0.049382] * 1536
      EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding)
    end

    it "returns results correctly when performing a non Hyde search" do
      index(topic)
      index(topic_in_subcategory)

      query = "test"
      stub_embedding(query)

      # hyde=false exercises the new bypass path: the raw query embedding is
      # used directly, so no LLM expansion request is made.
      get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"

      expect(response.status).to eq(200)
      expect(response.parsed_body["topics"].map { |t| t["id"] }).to contain_exactly(
        topic.id,
        topic_in_subcategory.id,
      )
    end

    it "is able to filter to a specific category (including sub categories)" do
      index(topic)
      index(topic_in_subcategory)

      # Only the bare term is embedded; the category: qualifier is consumed
      # by Search's term parsing, so the stub is for "test" alone.
      query = "test category:#{category.slug}"
      stub_embedding("test")

      get "/discourse-ai/embeddings/semantic-search.json?q=#{query}&hyde=false"

      expect(response.status).to eq(200)
      expect(response.parsed_body["topics"].map { |t| t["id"] }).to eq([topic_in_subcategory.id])
    end
  end
end
|
Loading…
Reference in New Issue