FEATURE: AI Quick Semantic Search (#501)

This PR adds AI semantic search to the search popup available on every page.

It depends on several new, optional settings, such as per-post embeddings and a reranker model, so this is an experimental endeavour.
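All of the settings this feature leans on appear in this diff; a hypothetical Rails console sketch of turning it on (the reranker URL is a placeholder):

    SiteSetting.ai_embeddings_enabled = true
    SiteSetting.ai_embeddings_per_post_enabled = true                         # per-post embeddings
    SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://tei.example" # reranker model
    SiteSetting.ai_embeddings_semantic_quick_search_enabled = true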


---------

Co-authored-by: Rafael Silva <xfalcox@gmail.com>
Author: Keegan George (2024-03-08 08:02:50 -08:00, committed by GitHub)
commit b515b4f66d
parent 114b96f2b4
20 changed files with 355 additions and 39 deletions

View File

@@ -36,6 +36,34 @@ module DiscourseAi
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
end
end
def quick_search
query = params[:q].to_s
if query.length < SiteSetting.min_search_term_length
raise Discourse::InvalidParameters.new(:q)
end
grouped_results =
Search::GroupedSearchResults.new(
type_filter: SEMANTIC_SEARCH_TYPE,
term: query,
search_context: guardian,
use_pg_headlines_for_excerpt: false,
)
semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
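# Only rate limit queries that miss the week-long embedding cache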
if !semantic_search.cached_query?(query)
RateLimiter.new(current_user, "semantic-search", 60, 1.minutes).performed!
end
hijack do
semantic_search.quick_search(query).each { |topic_post| grouped_results.add(topic_post) }
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
end
end
end
end
end
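The endpoint allows 60 uncached searches per user per minute; queries whose embeddings are already cached bypass the limiter. A minimal request-spec sketch of the guard clause (the spec itself is mine; the route comes from this diff, and Discourse renders Discourse::InvalidParameters as a 400):

    it "rejects terms shorter than min_search_term_length" do
      get "/discourse-ai/embeddings/quick-search", params: { q: "a" }
      expect(response.status).to eq(400)
    end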

View File

@@ -0,0 +1,15 @@
import Component from "@glimmer/component";
import { inject as service } from "@ember/service";
import loadingSpinner from "discourse/helpers/loading-spinner";
export default class AiQuickSearchLoader extends Component {
@service quickSearch;
<template>
{{#if this.quickSearch.loading}}
<div class="ai-quick-search-spinner">
{{loadingSpinner}}
</div>
{{/if}}
</template>
}

View File

@@ -0,0 +1,76 @@
import Component from "@glimmer/component";
import { action } from "@ember/object";
import { inject as service } from "@ember/service";
import AssistantItem from "discourse/components/search-menu/results/assistant-item";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { isValidSearchTerm, translateResults } from "discourse/lib/search";
import i18n from "discourse-common/helpers/i18n";
export default class AiQuickSemanticSearch extends Component {
static shouldRender(_args, { siteSettings }) {
return siteSettings.ai_embeddings_semantic_quick_search_enabled;
}
@service search;
@service quickSearch;
@service siteSettings;
@action
async searchTermChanged() {
if (!this.search.activeGlobalSearchTerm) {
this.search.noResults = false;
this.search.results = {};
this.quickSearch.loading = false;
this.quickSearch.invalidTerm = false;
} else if (
!isValidSearchTerm(this.search.activeGlobalSearchTerm, this.siteSettings)
) {
this.search.noResults = true;
this.search.results = {};
this.quickSearch.loading = false;
this.quickSearch.invalidTerm = true;
return;
} else {
await this.performSearch();
}
}
async performSearch() {
this.quickSearch.loading = true;
this.quickSearch.invalidTerm = false;
try {
const results = await ajax(`/discourse-ai/embeddings/quick-search`, {
data: {
q: this.search.activeGlobalSearchTerm,
},
});
const searchResults = await translateResults(results);
if (searchResults) {
this.search.noResults = results.resultTypes.length === 0;
this.search.results = searchResults;
}
} catch (error) {
popupAjaxError(error);
} finally {
this.quickSearch.loading = false;
}
}
<template>
{{yield}}
{{#if this.search.activeGlobalSearchTerm}}
<AssistantItem
@suffix={{i18n "discourse_ai.embeddings.quick_search.suffix"}}
@icon="discourse-sparkles"
@closeSearchMenu={{@closeSearchMenu}}
@searchTermChanged={{this.searchTermChanged}}
@suggestionKeyword={{@suggestionKeyword}}
/>
{{/if}}
</template>
}

View File

@@ -0,0 +1,31 @@
import Component from "@glimmer/component";
import { inject as service } from "@ember/service";
import { isValidSearchTerm } from "discourse/lib/search";
import i18n from "discourse-common/helpers/i18n";
export default class AiQuickSearchInfo extends Component {
@service search;
@service siteSettings;
@service quickSearch;
get termTooShort() {
// We check validity again here because the input may have changed
// since the last check, in which case the error should stop showing
const isInvalid = !isValidSearchTerm(
this.search.activeGlobalSearchTerm,
this.siteSettings
);
return (
isInvalid &&
this.quickSearch.invalidTerm &&
this.search.activeGlobalSearchTerm?.length > 0
);
}
<template>
{{#if this.termTooShort}}
<div class="no-results">{{i18n "search.too_short"}}</div>
{{/if}}
</template>
}

View File

@@ -0,0 +1,7 @@
import { tracked } from "@glimmer/tracking";
import Service from "@ember/service";
export default class QuickSearch extends Service {
@tracked loading = false;
@tracked invalidTerm = false;
}

View File

@@ -72,3 +72,12 @@
}
}
}
// Hides other buttons and only shows loader
// while AI quick search is in progress
.search-input {
.ai-quick-search-spinner ~ a.clear-search,
.ai-quick-search-spinner ~ a.show-advanced-search {
display: none;
}
}

View File

@@ -211,6 +211,8 @@ en:
none: "Sorry, our AI search found no matching topics."
new: "Press 'Search' to begin looking for new results with AI"
ai_generated_result: "Search result found using AI"
quick_search:
suffix: "in all topics and posts with AI"
ai_bot:
pm_warning: "AI chatbot messages are monitored regularly by moderators."

View File

@@ -84,6 +84,7 @@ en:
ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes."
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup."
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search"
ai_embeddings_per_post_enabled: "Generate embeddings for each post"

View File

@@ -12,6 +12,7 @@ DiscourseAi::Engine.routes.draw do
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
get "semantic-search" => "embeddings#search"
get "quick-search" => "embeddings#quick_search"
end
scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do

View File

@@ -137,6 +137,12 @@ discourse_ai:
default: ""
hidden: true
ai_hugging_face_tei_api_key: ""
ai_hugging_face_tei_reranker_endpoint:
default: ""
ai_hugging_face_tei_reranker_endpoint_srv:
default: ""
hidden: true
ai_hugging_face_tei_reranker_api_key: ""
ai_google_custom_search_api_key:
default: ""
secret: true
@@ -232,7 +238,6 @@ discourse_ai:
- "llava"
- "open_ai:gpt-4-vision-preview"
ai_embeddings_enabled:
default: false
client: true
@@ -282,6 +287,10 @@ discourse_ai:
allow_any: false
enum: "DiscourseAi::Configuration::LlmEnumerator"
validator: "DiscourseAi::Configuration::LlmValidator"
ai_embeddings_semantic_quick_search_enabled:
default: false
client: true
validator: "DiscourseAi::Configuration::LlmDependencyValidator"
ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_endpoint_srv:

View File

@@ -82,6 +82,75 @@ module DiscourseAi
guardian.filter_allowed_categories(query_filter_results)
end
def quick_search(query)
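# Over-fetch candidates so the reranker has a real pool to reorder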
max_semantic_results_per_page = 100
search = Search.new(query, { guardian: guardian })
search_term = search.term
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
embedding_key =
build_embedding_key(
digest,
SiteSetting.ai_embeddings_semantic_search_hyde_model,
SiteSetting.ai_embeddings_model,
)
search_term_embedding =
Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) do
vector_rep.vector_from(search_term, asymmetric: true)
end
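# Nearest-neighbour search over the per-post embeddings table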
candidate_post_ids =
vector_rep.asymmetric_posts_similarity_search(
search_term_embedding,
limit: max_semantic_results_per_page,
offset: 0,
)
semantic_results =
::Post
.where(post_type: ::Topic.visible_post_types(guardian.user))
.public_posts
.where("topics.visible")
.where(id: candidate_post_ids)
.order("array_position(ARRAY#{candidate_post_ids}, posts.id)")
filtered_results = search.apply_filters(semantic_results)
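# Rerank on plain text: strip HTML from the cooked posts and cap each at 2000 characters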
rerank_posts_payload =
filtered_results
.map(&:cooked)
.map { Nokogiri::HTML5.fragment(_1).text }
.map { _1.truncate(2000, omission: "") }
reranked_results =
DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
search_term,
rerank_posts_payload,
)
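# The reranker returns indexes into the candidate list; keep its top five posts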
reordered_ids = reranked_results.map { _1[:index] }.map { filtered_results[_1].id }.take(5)
reranked_semantic_results =
::Post
.where(post_type: ::Topic.visible_post_types(guardian.user))
.public_posts
.where("topics.visible")
.where(id: reordered_ids)
.order("array_position(ARRAY#{reordered_ids}, posts.id)")
guardian.filter_allowed_categories(reranked_semantic_results)
end
private
attr_reader :guardian
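A console sketch of the retrieve-then-rerank flow above, assuming per-post embeddings are backfilled and a TEI reranker endpoint is configured:

    guardian = Guardian.new(Discourse.system_user)
    search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
    search.quick_search("how do I rotate my api keys").map(&:id)
    # => up to five post ids, best rerank score first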

View File

@@ -23,7 +23,7 @@ module DiscourseAi
end
end
def vector_from(text)
def vector_from(text, asymmetric: false)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
self.class.name,

View File

@@ -54,6 +54,7 @@ module DiscourseAi
count = DB.query_single("SELECT count(*) FROM #{table_name};").first
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
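# Cache the probes so probes_sql below can apply the same value at query time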
Discourse.cache.write("#{table_name}-probes", probes)
existing_index = DB.query_single(<<~SQL, index_name: index_name).first
SELECT
@@ -144,18 +145,9 @@ module DiscourseAi
DB.exec("COMMENT ON INDEX #{index_name} IS '#{Time.now.to_i}';")
DB.exec("RESET work_mem;")
DB.exec("RESET maintenance_work_mem;")
database = DB.query_single("SELECT current_database();").first
# This is a global setting, if we set it based on post count
# we will be unable to use the index for topics
# Hopefully https://github.com/pgvector/pgvector/issues/235 will make this better
if table_name == topic_table_name
DB.exec("ALTER DATABASE #{database} SET ivfflat.probes = #{probes};")
end
end
def vector_from(text)
def vector_from(text, asymmetric: false)
raise NotImplementedError
end
@@ -206,6 +198,7 @@ module DiscourseAi
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
#{probes_sql(topic_table_name)}
SELECT
topic_id,
embeddings #{pg_function} '[:query_embedding]' AS distance
@@ -227,8 +220,37 @@ module DiscourseAi
raise MissingEmbeddingError
end
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
#{probes_sql(post_table_name)}
SELECT
post_id,
embeddings #{pg_function} '[:query_embedding]' AS distance
FROM
#{post_table_name}
INNER JOIN
posts AS p ON p.id = post_id
INNER JOIN
topics AS t ON t.id = p.topic_id AND t.archetype = 'regular'
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.post_id, r.distance] }
else
results.map(&:post_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
#{probes_sql(topic_table_name)}
SELECT
topic_id
FROM
@@ -275,6 +297,11 @@ module DiscourseAi
"#{table_name}_search"
end
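# SET LOCAL scopes ivfflat.probes to the current transaction, replacing the removed ALTER DATABASE approach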
def probes_sql(table_name)
probes = Discourse.cache.read("#{table_name}-probes")
probes.present? ? "SET LOCAL ivfflat.probes TO #{probes};" : ""
end
def name
raise NotImplementedError
end
@@ -303,6 +330,10 @@ module DiscourseAi
raise NotImplementedError
end
def asymmetric_query_prefix
raise NotImplementedError
end
protected
def save_to_db(target, vector, digest)
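
For concreteness, here is how the lists/probes heuristic near the top of this file works out; pure arithmetic on a hypothetical row count:

    count = 500_000
    lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max # => 500
    probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max   # => 50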

View File

@@ -30,7 +30,9 @@ module DiscourseAi
end
end
def vector_from(text)
def vector_from(text, asymmetric: false)
text = "#{asymmetric_query_prefix} #{text}" if asymmetric
if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi
.perform!(inference_model_name, { text: text })
@@ -82,6 +84,10 @@ module DiscourseAi
def tokenizer
DiscourseAi::Tokenizer::BgeLargeEnTokenizer
end
def asymmetric_query_prefix
"Represent this sentence for searching relevant passages:"
end
end
end
end
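In practice this means a bge-large-en query embedding is computed over the prefixed string while the stored post and topic embeddings stay unprefixed; a sketch, with vector_rep standing for an instance of this class as in quick_search:

    vector_rep.vector_from("rotate api keys", asymmetric: true)
    # embeds: "Represent this sentence for searching relevant passages: rotate api keys"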

View File

@@ -42,7 +42,7 @@ module DiscourseAi
"vector_cosine_ops"
end
def vector_from(text)
def vector_from(text, asymmetric: false)
response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
response[:embedding][:values]
end

View File

@@ -28,7 +28,7 @@ module DiscourseAi
end
end
def vector_from(text)
def vector_from(text, asymmetric: false)
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first

View File

@@ -44,7 +44,7 @@ module DiscourseAi
"vector_cosine_ops"
end
def vector_from(text)
def vector_from(text, asymmetric: false)
response =
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
text,

View File

@@ -42,7 +42,7 @@ module DiscourseAi
"vector_cosine_ops"
end
def vector_from(text)
def vector_from(text, asymmetric: false)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
end

View File

@@ -42,7 +42,7 @@ module DiscourseAi
"vector_cosine_ops"
end
def vector_from(text)
def vector_from(text, asymmetric: false)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
end

View File

@@ -3,32 +3,63 @@
module ::DiscourseAi
  module Inference
    class HuggingFaceTextEmbeddings
      class << self
        def perform!(content)
          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
          body = { inputs: content, truncate: true }.to_json

          if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
            service =
              DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
            api_endpoint = "https://#{service.target}:#{service.port}"
          else
            api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
          end

          if SiteSetting.ai_hugging_face_tei_api_key.present?
            headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
          end

          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
          response = conn.post(api_endpoint, body, headers)

          raise Net::HTTPBadResponse if ![200].include?(response.status)

          JSON.parse(response.body, symbolize_names: true)
        end

        def rerank(content, candidates)
          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
          body = { query: content, texts: candidates, truncate: true }.to_json

          if SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
            service =
              DiscourseAi::Utils::DnsSrv.lookup(
                SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv,
              )
            api_endpoint = "https://#{service.target}:#{service.port}"
          else
            api_endpoint = SiteSetting.ai_hugging_face_tei_reranker_endpoint
          end

          if SiteSetting.ai_hugging_face_tei_reranker_api_key.present?
            headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_reranker_api_key
          end

          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
          response = conn.post("#{api_endpoint}/rerank", body, headers)

          if response.status != 200
            raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}")
          end

          JSON.parse(response.body, symbolize_names: true)
        end

        def configured?
          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
        end
      end
    end
  end
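
A sketch of the rerank exchange: the request body is built above, and the response shape is inferred from how quick_search consumes it (it reads _1[:index]), so the score key is an assumption:

    DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
      "rotate api keys",
      ["first candidate post text", "second candidate post text"],
    )
    # => [{ index: 1, score: 0.91 }, { index: 0, score: 0.07 }] (assumed shape)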