discourse-ai/lib/embeddings/semantic_search.rb
Sam 4923837165
FIX: Llm selector / forced tools / search tool (#862)
* FIX: Llm selector / forced tools / search tool


This fixes a few issues:

1. When search was not finding any semantic results we would break the tool
2. Gemini / Anthropic models did not implement forced tools previously, despite it being an API option
3. Mechanics around displaying the LLM selector were not right. If you disabled the LLM selector server side, persona PMs did not work correctly.
4. Disabling native tools for Anthropic models moved out of a site setting. This deliberately does not migrate, because this feature is now rarely needed; people who had it set probably did not need it.
5. Updates anthropic model names to latest release

* linting

* fix a couple of tests I missed

* clean up conditional
2024-10-25 06:24:53 +11:00

217 lines
7.0 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Embeddings
class SemanticSearch
# Deletes every cache entry this class may have written for +query+ under the
# currently configured HyDE and embedding models:
#
# * the generated hypothetical (HyDE) post,
# * the embedding of that hypothetical post,
# * the embedding of the raw term (the non-HyDE path).
#
# @param query [String] the search term whose cached artefacts should be dropped
def self.clear_cache_for(query)
  digest = OpenSSL::Digest::SHA1.hexdigest(query)
  hyde_model = SiteSetting.ai_embeddings_semantic_search_hyde_model
  embedding_model = SiteSetting.ai_embeddings_model

  hyde_key = "semantic-search-#{digest}-#{hyde_model}"
  # The non-HyDE embedding is cached with an empty HyDE model segment, hence
  # the double dash: "semantic-search-<digest>--<embedding model>".
  plain_embedding_key = "semantic-search-#{digest}--#{embedding_model}"

  Discourse.cache.delete(hyde_key)
  Discourse.cache.delete("#{hyde_key}-#{embedding_model}")
  Discourse.cache.delete(plain_embedding_key)
end
# Builds a searcher bound to a guardian.
#
# @param guardian [Object] stored as-is; the search methods pass it to Search,
#   read the acting user from it (guardian.user) and use it to restrict
#   results to allowed categories (guardian.filter_allowed_categories).
def initialize(guardian)
@guardian = guardian
end
# Tells whether an embedding for +query+ is already cached under the key
# derived from the configured HyDE model and embedding model.
#
# @param query [String] the search term
# @return [Boolean] true when the cache holds a present value for that key
def cached_query?(query)
  query_digest = OpenSSL::Digest::SHA1.hexdigest(query)
  hyde_model = SiteSetting.ai_embeddings_semantic_search_hyde_model
  embedding_model = SiteSetting.ai_embeddings_model
  cache_key = build_embedding_key(query_digest, hyde_model, embedding_model)

  cached_value = Discourse.cache.read(cache_key)
  cached_value.present?
end
# Returns the current vector representation, built on first use with a
# truncation strategy and memoized on the instance afterwards.
#
# @return [Object] whatever Base.current_representation returns
def vector_rep
  return @vector_rep if @vector_rep

  truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
  @vector_rep =
    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
# Produces the embedding used for HyDE search: a hypothetical post is
# generated for +search_term+ and that post is embedded. Both the generated
# post and its embedding are cached for one week, each under its own key.
#
# @param search_term [String] the term to search for
# @return [Object] the value returned by vector_rep.vector_from (or its cached copy)
def hyde_embedding(search_term)
  term_digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
  hyde_model = SiteSetting.ai_embeddings_semantic_search_hyde_model
  post_cache_key = build_hyde_key(term_digest, hyde_model)
  vector_cache_key =
    build_embedding_key(term_digest, hyde_model, SiteSetting.ai_embeddings_model)
  cache = Discourse.cache

  generated_post =
    cache.fetch(post_cache_key, expires_in: 1.week) do
      hypothetical_post_from(search_term)
    end

  cache.fetch(vector_cache_key, expires_in: 1.week) do
    vector_rep.vector_from(generated_post)
  end
end
# Embeds +search_term+ directly (no HyDE step) and caches the result for one
# week.
#
# @param search_term [String] the term to embed
# @return [Object] the value returned by vector_rep.vector_from (or its cached copy)
def embedding(search_term)
  term_digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
  # The HyDE model segment of the key is intentionally left empty here.
  no_hyde_model = ""
  cache_key = build_embedding_key(term_digest, no_hyde_model, SiteSetting.ai_embeddings_model)

  Discourse.cache.fetch(cache_key, expires_in: 1.week) do
    vector_rep.vector_from(search_term)
  end
end
# Multiplier applied to the per-page limit when asking the vector index for
# candidate topics. Deliberately over-selecting candidates gives a much better
# chance of still finding topics when the user filtered the results or the
# index is a bit out of date.
OVER_SELECTION_FACTOR = 4
# Runs a semantic search and returns the first posts of the best matching
# topics, ordered by similarity rank.
#
# @param query [String] raw search query; it is parsed by Search, whose
#   filters are applied to the semantic results
# @param page [Integer] 1-based page number
# @param hyde [Boolean] when true the query is embedded via a generated
#   hypothetical post (see hyde_embedding), otherwise the term is embedded
#   directly (see embedding)
# @return [ActiveRecord::Relation] posts visible to the guardian; an empty
#   relation (Post.none) when the term is too short or nothing matched
def search_for_topics(query, page = 1, hyde: true)
  max_results_per_page = 100
  # One row more than the page size is requested — presumably so callers can
  # tell whether a further page exists (confirm against callers).
  limit = [Search.per_filter, max_results_per_page].min + 1
  offset = (page - 1) * limit

  search = Search.new(query, { guardian: guardian })
  search_term = search.term
  if search_term.blank? || search_term.length < SiteSetting.min_search_term_length
    return Post.none
  end

  search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term)

  over_selection_limit = limit * OVER_SELECTION_FACTOR
  candidate_topic_ids =
    vector_rep.asymmetric_topics_similarity_search(
      search_embedding,
      limit: over_selection_limit,
      offset: offset,
    )

  # Without candidates the ORDER BY below would interpolate to
  # "array_position(ARRAY[], ...)", an untyped empty array literal that
  # PostgreSQL rejects. There is nothing to rank, so return an empty relation.
  return Post.none if candidate_topic_ids.blank?

  semantic_results =
    ::Post
      .where(post_type: ::Topic.visible_post_types(guardian.user))
      .public_posts
      .where("topics.visible")
      .where(topic_id: candidate_topic_ids, post_number: 1)
      # Keep the similarity ranking returned by the vector search.
      .order("array_position(ARRAY#{candidate_topic_ids}, posts.topic_id)")
      .limit(limit)

  query_filter_results = search.apply_filters(semantic_results)
  guardian.filter_allowed_categories(query_filter_results)
end
# Runs a post-level semantic search and reranks the filtered candidates with
# the HuggingFace reranker, keeping the five best posts.
#
# @param query [String] raw search query; it is parsed by Search, whose
#   filters are applied to the semantic candidates
# @return [Array, ActiveRecord::Relation] an empty Array when the term is
#   missing or too short; otherwise a relation of posts in reranked order
#   (an empty relation when nothing matched)
def quick_search(query)
  max_semantic_results_per_page = 100
  reranked_results_to_keep = 5
  rerank_text_limit = 2000

  search = Search.new(query, { guardian: guardian })
  search_term = search.term
  return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length

  digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
  # NOTE(review): this is the same cache key hyde_embedding uses for the
  # embedding of the hypothetical post, while the value stored here is the
  # embedding of the raw term — confirm the sharing is intended.
  embedding_key =
    build_embedding_key(
      digest,
      SiteSetting.ai_embeddings_semantic_search_hyde_model,
      SiteSetting.ai_embeddings_model,
    )

  search_term_embedding =
    Discourse
      .cache
      .fetch(embedding_key, expires_in: 1.week) do
        # "asymetric" (sic) is the keyword name expected by vector_from.
        vector_rep.vector_from(search_term, asymetric: true)
      end

  candidate_post_ids =
    vector_rep.asymmetric_posts_similarity_search(
      search_term_embedding,
      limit: max_semantic_results_per_page,
      offset: 0,
    )

  # "ARRAY[]" (what an empty id list interpolates to below) is an untyped
  # empty array literal that PostgreSQL rejects, and there is nothing to
  # rerank anyway.
  return ::Post.none if candidate_post_ids.blank?

  # Base scope shared by the candidate query and the final reranked query.
  visible_posts =
    ::Post
      .where(post_type: ::Topic.visible_post_types(guardian.user))
      .public_posts
      .where("topics.visible")

  semantic_results =
    visible_posts.where(id: candidate_post_ids).order(
      "array_position(ARRAY#{candidate_post_ids}, posts.id)",
    )

  # Load once: the records are read for the payload and then indexed by the
  # positions the reranker returns.
  filtered_results = search.apply_filters(semantic_results).to_a
  # Do not call the reranker with an empty payload.
  return ::Post.none if filtered_results.empty?

  rerank_posts_payload =
    filtered_results.map do |post|
      Nokogiri::HTML5.fragment(post.cooked).text.truncate(rerank_text_limit, omission: "")
    end

  reranked_results =
    DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(search_term, rerank_posts_payload)

  # Each reranked entry carries the :index of the payload item it refers to.
  reordered_ids =
    reranked_results
      .first(reranked_results_to_keep)
      .map { |entry| filtered_results[entry[:index]].id }
  return ::Post.none if reordered_ids.empty?

  reranked_semantic_results =
    visible_posts.where(id: reordered_ids).order(
      "array_position(ARRAY#{reordered_ids}, posts.id)",
    )

  guardian.filter_allowed_categories(reranked_semantic_results)
end
# Asks the configured HyDE LLM to write a forum post about +search_term+.
#
# The model is instructed to wrap the post in <ai></ai> tags; when a
# non-empty <ai> element is found its text is returned, otherwise the raw
# LLM response is returned unchanged.
#
# @param search_term [String] the subject the hypothetical post should cover
# @return [String] the extracted post text, or the full LLM response
def hypothetical_post_from(search_term)
  system_message = <<~TEXT.strip
    You are a content creator for a forum. The forum description is as follows:
    #{SiteSetting.title}
    #{SiteSetting.site_description}
    Put the forum post between <ai></ai> tags.
  TEXT
  prompt = DiscourseAi::Completions::Prompt.new(system_message)

  user_message = <<~TEXT.strip
    Using this description, write a forum post about the subject inside the <input></input> XML tags:
    <input>#{search_term}</input>
  TEXT
  prompt.push(type: :user, content: user_message)

  hyde_llm =
    DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_embeddings_semantic_search_hyde_model)
  llm_response =
    hyde_llm.generate(prompt, user: @guardian.user, feature_name: "semantic_search_hyde")

  # Prefer the text inside the <ai> element; fall back to the whole response
  # when the element is missing or blank.
  tagged_node = Nokogiri::HTML5.fragment(llm_response).at("ai")
  tagged_text = tagged_node&.text
  tagged_text.presence || llm_response
end
private
# Reader for the guardian given to the constructor; declared after `private`,
# so it is only callable from within the instance.
attr_reader :guardian
# Builds the cache key under which the generated hypothetical post is stored.
#
# @param digest [String] digest of the search term
# @param hyde_model [String] identifier of the HyDE model (may be empty)
# @return [String] "semantic-search-<digest>-<hyde_model>"
def build_hyde_key(digest, hyde_model)
  format("semantic-search-%s-%s", digest, hyde_model)
end
# Builds the cache key for an embedding: the HyDE key with the embedding
# model appended.
#
# @param digest [String] digest of the search term
# @param hyde_model [String] identifier of the HyDE model (may be empty)
# @param embedding_model [String] identifier of the embedding model
# @return [String] "<hyde key>-<embedding_model>"
def build_embedding_key(digest, hyde_model, embedding_model)
  hyde_prefix = build_hyde_key(digest, hyde_model)
  "#{hyde_prefix}-#{embedding_model}"
end
end
end
end