FEATURE: AI Quick Semantic Search (#501)
This PR adds AI semantic search to the search popup available on every page. It depends on several new and optional settings, such as per-post embeddings and a reranker model, so this is an experimental endeavour.

Co-authored-by: Rafael Silva <xfalcox@gmail.com>
Commit b515b4f66d (parent 114b96f2b4)
@@ -36,6 +36,34 @@ module DiscourseAi
           render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
         end
       end
+
+      def quick_search
+        query = params[:q].to_s
+
+        if query.length < SiteSetting.min_search_term_length
+          raise Discourse::InvalidParameters.new(:q)
+        end
+
+        grouped_results =
+          Search::GroupedSearchResults.new(
+            type_filter: SEMANTIC_SEARCH_TYPE,
+            term: query,
+            search_context: guardian,
+            use_pg_headlines_for_excerpt: false,
+          )
+
+        semantic_search = DiscourseAi::Embeddings::SemanticSearch.new(guardian)
+
+        if !semantic_search.cached_query?(query)
+          RateLimiter.new(current_user, "semantic-search", 60, 1.minutes).performed!
+        end
+
+        hijack do
+          semantic_search.quick_search(query).each { |topic_post| grouped_results.add(topic_post) }
+
+          render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
+        end
+      end
     end
   end
 end
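A note on behavior: the action validates term length up front and rate limits each user to one uncached semantic search per minute; queries whose embeddings are already cached skip the limiter. A minimal request sketch of the validation path (the spec wiring here is an assumption, not part of this diff):

    # Hypothetical request spec for the new endpoint (illustration only):
    it "rejects terms shorter than min_search_term_length" do
      sign_in(Fabricate(:user))
      get "/discourse-ai/embeddings/quick-search", params: { q: "a" }
      expect(response.status).to eq(400) # Discourse::InvalidParameters renders a 400
    end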
@@ -0,0 +1,15 @@
+import Component from "@glimmer/component";
+import { inject as service } from "@ember/service";
+import loadingSpinner from "discourse/helpers/loading-spinner";
+
+export default class AiQuickSearchLoader extends Component {
+  @service quickSearch;
+
+  <template>
+    {{#if this.quickSearch.loading}}
+      <div class="ai-quick-search-spinner">
+        {{loadingSpinner}}
+      </div>
+    {{/if}}
+  </template>
+}
@@ -0,0 +1,76 @@
+import Component from "@glimmer/component";
+import { action } from "@ember/object";
+import { inject as service } from "@ember/service";
+import AssistantItem from "discourse/components/search-menu/results/assistant-item";
+import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
+import { isValidSearchTerm, translateResults } from "discourse/lib/search";
+import i18n from "discourse-common/helpers/i18n";
+
+export default class AiQuickSemanticSearch extends Component {
+  static shouldRender(_args, { siteSettings }) {
+    return siteSettings.ai_embeddings_semantic_quick_search_enabled;
+  }
+
+  @service search;
+  @service quickSearch;
+  @service siteSettings;
+
+  @action
+  async searchTermChanged() {
+    if (!this.search.activeGlobalSearchTerm) {
+      this.search.noResults = false;
+      this.search.results = {};
+      this.quickSearch.loading = false;
+      this.quickSearch.invalidTerm = false;
+    } else if (
+      !isValidSearchTerm(this.search.activeGlobalSearchTerm, this.siteSettings)
+    ) {
+      this.search.noResults = true;
+      this.search.results = {};
+      this.quickSearch.loading = false;
+      this.quickSearch.invalidTerm = true;
+      return;
+    } else {
+      await this.performSearch();
+    }
+  }
+
+  async performSearch() {
+    this.quickSearch.loading = true;
+    this.quickSearch.invalidTerm = false;
+
+    try {
+      const results = await ajax(`/discourse-ai/embeddings/quick-search`, {
+        data: {
+          q: this.search.activeGlobalSearchTerm,
+        },
+      });
+
+      const searchResults = await translateResults(results);
+
+      if (searchResults) {
+        this.search.noResults = results.resultTypes.length === 0;
+        this.search.results = searchResults;
+      }
+    } catch (error) {
+      popupAjaxError(error);
+    } finally {
+      this.quickSearch.loading = false;
+    }
+  }
+
+  <template>
+    {{yield}}
+
+    {{#if this.search.activeGlobalSearchTerm}}
+      <AssistantItem
+        @suffix={{i18n "discourse_ai.embeddings.quick_search.suffix"}}
+        @icon="discourse-sparkles"
+        @closeSearchMenu={{@closeSearchMenu}}
+        @searchTermChanged={{this.searchTermChanged}}
+        @suggestionKeyword={{@suggestionKeyword}}
+      />
+    {{/if}}
+  </template>
+}
@@ -0,0 +1,31 @@
+import Component from "@glimmer/component";
+import { inject as service } from "@ember/service";
+import { isValidSearchTerm } from "discourse/lib/search";
+import i18n from "discourse-common/helpers/i18n";
+
+export default class AiQuickSearchInfo extends Component {
+  @service search;
+  @service siteSettings;
+  @service quickSearch;
+
+  get termTooShort() {
+    // We check the validity again here because the input may have changed
+    // since the last time we checked, so we may want to stop showing the error
+    const validity = !isValidSearchTerm(
+      this.search.activeGlobalSearchTerm,
+      this.siteSettings
+    );
+
+    return (
+      validity &&
+      this.quickSearch.invalidTerm &&
+      this.search.activeGlobalSearchTerm?.length > 0
+    );
+  }
+
+  <template>
+    {{#if this.termTooShort}}
+      <div class="no-results">{{i18n "search.too_short"}}</div>
+    {{/if}}
+  </template>
+}
@@ -0,0 +1,7 @@
+import { tracked } from "@glimmer/tracking";
+import Service from "@ember/service";
+
+export default class QuickSearch extends Service {
+  @tracked loading = false;
+  @tracked invalidTerm = false;
+}
@@ -72,3 +72,12 @@
     }
   }
 }
+
+// Hides other buttons and only shows loader
+// while AI quick search is in progress
+.search-input {
+  .ai-quick-search-spinner ~ a.clear-search,
+  .ai-quick-search-spinner ~ a.show-advanced-search {
+    display: none;
+  }
+}
@@ -211,6 +211,8 @@ en:
       none: "Sorry, our AI search found no matching topics."
       new: "Press 'Search' to begin looking for new results with AI"
       ai_generated_result: "Search result found using AI"
+      quick_search:
+        suffix: "in all topics and posts with AI"

     ai_bot:
       pm_warning: "AI chatbot messages are monitored regularly by moderators."
@@ -84,6 +84,7 @@ en:
     ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes."
     ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
     ai_embeddings_semantic_search_enabled: "Enable full-page semantic search."
+    ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup."
     ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results"
     ai_embeddings_semantic_search_hyde_model: "Model used to expand keywords to get better results during a semantic search"
     ai_embeddings_per_post_enabled: Generate embeddings for each post
@@ -12,6 +12,7 @@ DiscourseAi::Engine.routes.draw do

   scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
     get "semantic-search" => "embeddings#search"
+    get "quick-search" => "embeddings#quick_search"
   end

   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
@@ -137,6 +137,12 @@ discourse_ai:
     default: ""
     hidden: true
   ai_hugging_face_tei_api_key: ""
+  ai_hugging_face_tei_reranker_endpoint:
+    default: ""
+  ai_hugging_face_tei_reranker_endpoint_srv:
+    default: ""
+    hidden: true
+  ai_hugging_face_tei_reranker_api_key: ""
   ai_google_custom_search_api_key:
     default: ""
     secret: true
@@ -232,7 +238,6 @@ discourse_ai:
       - "llava"
      - "open_ai:gpt-4-vision-preview"
-

   ai_embeddings_enabled:
     default: false
     client: true
@@ -282,6 +287,10 @@ discourse_ai:
     allow_any: false
     enum: "DiscourseAi::Configuration::LlmEnumerator"
     validator: "DiscourseAi::Configuration::LlmValidator"
+  ai_embeddings_semantic_quick_search_enabled:
+    default: false
+    client: true
+    validator: "DiscourseAi::Configuration::LlmDependencyValidator"

   ai_summarization_discourse_service_api_endpoint: ""
   ai_summarization_discourse_service_api_endpoint_srv:
@@ -82,6 +82,75 @@ module DiscourseAi
         guardian.filter_allowed_categories(query_filter_results)
       end

+      def quick_search(query)
+        max_semantic_results_per_page = 100
+        search = Search.new(query, { guardian: guardian })
+        search_term = search.term
+
+        return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
+
+        strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+        vector_rep =
+          DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+
+        digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
+
+        embedding_key =
+          build_embedding_key(
+            digest,
+            SiteSetting.ai_embeddings_semantic_search_hyde_model,
+            SiteSetting.ai_embeddings_model,
+          )
+
+        search_term_embedding =
+          Discourse
+            .cache
+            .fetch(embedding_key, expires_in: 1.week) do
+              vector_rep.vector_from(search_term, asymetric: true)
+            end
+
+        candidate_post_ids =
+          vector_rep.asymmetric_posts_similarity_search(
+            search_term_embedding,
+            limit: max_semantic_results_per_page,
+            offset: 0,
+          )
+
+        semantic_results =
+          ::Post
+            .where(post_type: ::Topic.visible_post_types(guardian.user))
+            .public_posts
+            .where("topics.visible")
+            .where(id: candidate_post_ids)
+            .order("array_position(ARRAY#{candidate_post_ids}, posts.id)")
+
+        filtered_results = search.apply_filters(semantic_results)
+
+        rerank_posts_payload =
+          filtered_results
+            .map(&:cooked)
+            .map { Nokogiri::HTML5.fragment(_1).text }
+            .map { _1.truncate(2000, omission: "") }
+
+        reranked_results =
+          DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
+            search_term,
+            rerank_posts_payload,
+          )
+
+        reordered_ids = reranked_results.map { _1[:index] }.map { filtered_results[_1].id }.take(5)
+
+        reranked_semantic_results =
+          ::Post
+            .where(post_type: ::Topic.visible_post_types(guardian.user))
+            .public_posts
+            .where("topics.visible")
+            .where(id: reordered_ids)
+            .order("array_position(ARRAY#{reordered_ids}, posts.id)")
+
+        guardian.filter_allowed_categories(reranked_semantic_results)
+      end
+
       private

       attr_reader :guardian
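Taken together, quick_search embeds the query (cached for a week under a key derived from the SHA1 digest and the configured HyDE/embeddings models), pulls the 100 nearest posts from pgvector, applies the regular search filters, reranks the stripped-down post bodies with the TEI reranker, and keeps the top 5. A rough console sketch, assuming embeddings and the reranker endpoint are configured ("password reset" is a placeholder query):

    # Roughly what the controller action does, runnable from a Rails console (sketch):
    guardian = Guardian.new(Discourse.system_user)
    posts = DiscourseAi::Embeddings::SemanticSearch.new(guardian).quick_search("password reset")
    posts.map(&:id) # => up to 5 post ids, best reranker score first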
@@ -23,7 +23,7 @@ module DiscourseAi
         end
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
         DiscourseAi::Inference::DiscourseClassifier.perform!(
           "#{discourse_embeddings_endpoint}/api/v1/classify",
           self.class.name,
@@ -54,6 +54,7 @@ module DiscourseAi
         count = DB.query_single("SELECT count(*) FROM #{table_name};").first
         lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
         probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
+        Discourse.cache.write("#{table_name}-probes", probes)

         existing_index = DB.query_single(<<~SQL, index_name: index_name).first
           SELECT
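The lists/probes heuristic above follows the usual ivfflat sizing advice (rows/1000 lists below one million rows, sqrt(rows) above), and caching probes per table lets later queries reapply the value. A quick worked example for a table of 100,000 embeddings:

    count  = 100_000
    lists  = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max # => 100
    probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max    # => 10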
@@ -144,18 +145,9 @@ module DiscourseAi
         DB.exec("COMMENT ON INDEX #{index_name} IS '#{Time.now.to_i}';")
         DB.exec("RESET work_mem;")
         DB.exec("RESET maintenance_work_mem;")
-
-        database = DB.query_single("SELECT current_database();").first
-
-        # This is a global setting, if we set it based on post count
-        # we will be unable to use the index for topics
-        # Hopefully https://github.com/pgvector/pgvector/issues/235 will make this better
-        if table_name == topic_table_name
-          DB.exec("ALTER DATABASE #{database} SET ivfflat.probes = #{probes};")
-        end
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
         raise NotImplementedError
       end

@@ -206,6 +198,7 @@ module DiscourseAi

       def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
         results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
+          #{probes_sql(topic_table_name)}
           SELECT
             topic_id,
             embeddings #{pg_function} '[:query_embedding]' AS distance
@@ -227,8 +220,37 @@ module DiscourseAi
         raise MissingEmbeddingError
       end

+      def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
+        results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
+          #{probes_sql(post_table_name)}
+          SELECT
+            post_id,
+            embeddings #{pg_function} '[:query_embedding]' AS distance
+          FROM
+            #{post_table_name}
+          INNER JOIN
+            posts AS p ON p.id = post_id
+          INNER JOIN
+            topics AS t ON t.id = p.topic_id AND t.archetype = 'regular'
+          ORDER BY
+            embeddings #{pg_function} '[:query_embedding]'
+          LIMIT :limit
+          OFFSET :offset
+        SQL
+
+        if return_distance
+          results.map { |r| [r.post_id, r.distance] }
+        else
+          results.map(&:post_id)
+        end
+      rescue PG::Error => e
+        Rails.logger.error("Error #{e} querying embeddings for model #{name}")
+        raise MissingEmbeddingError
+      end
+
       def symmetric_topics_similarity_search(topic)
         DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
+          #{probes_sql(topic_table_name)}
           SELECT
             topic_id
           FROM
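As with the existing topics variant, the new post-level lookup returns ids ordered nearest-first, with the distance tuple form opt-in. A small usage sketch (the embedding variable is assumed from the quick_search code above):

    ids = vector_rep.asymmetric_posts_similarity_search(embedding, limit: 100, offset: 0)
    # => [post_id, ...] nearest first
    vector_rep.asymmetric_posts_similarity_search(embedding, limit: 100, offset: 0, return_distance: true)
    # => [[post_id, distance], ...]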
|
@ -275,6 +297,11 @@ module DiscourseAi
|
||||||
"#{table_name}_search"
|
"#{table_name}_search"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def probes_sql(table_name)
|
||||||
|
probes = Discourse.cache.read("#{table_name}-probes")
|
||||||
|
probes.present? ? "SET LOCAL ivfflat.probes TO #{probes};" : ""
|
||||||
|
end
|
||||||
|
|
||||||
def name
|
def name
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
end
|
end
|
||||||
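Because probes_sql returns a plain string interpolated at the top of each similarity query, the setting only applies when a cached value exists, and SET LOCAL scopes it to the current transaction instead of the whole database (which is what the removed ALTER DATABASE call used to do). A sketch of both cases (the table names are illustrative):

    Discourse.cache.write("my_embeddings_table-probes", 10) # hypothetical cached value
    probes_sql("my_embeddings_table") # => "SET LOCAL ivfflat.probes TO 10;"
    probes_sql("uncached_table")      # => "" (no-op prefix)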
|
@ -303,6 +330,10 @@ module DiscourseAi
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def asymmetric_query_prefix
|
||||||
|
raise NotImplementedError
|
||||||
|
end
|
||||||
|
|
||||||
protected
|
protected
|
||||||
|
|
||||||
def save_to_db(target, vector, digest)
|
def save_to_db(target, vector, digest)
|
||||||
|
|
|
@@ -30,7 +30,9 @@ module DiscourseAi
         end
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
+        text = "#{asymmetric_query_prefix} #{text}" if asymetric
+
         if SiteSetting.ai_cloudflare_workers_api_token.present?
           DiscourseAi::Inference::CloudflareWorkersAi
             .perform!(inference_model_name, { text: text })
@@ -82,6 +84,10 @@ module DiscourseAi
       def tokenizer
         DiscourseAi::Tokenizer::BgeLargeEnTokenizer
       end
+
+      def asymmetric_query_prefix
+        "Represent this sentence for searching relevant passages:"
+      end
     end
   end
 end
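The prefix implements bge-large-en's asymmetric retrieval convention: queries get an instruction string while documents are embedded as-is, so a short search term is represented differently from the posts it is matched against. For example (placeholder query):

    vector_rep.vector_from("how do I reset my password", asymetric: true)
    # embeds: "Represent this sentence for searching relevant passages: how do I reset my password"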
@@ -42,7 +42,7 @@ module DiscourseAi
         "vector_cosine_ops"
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
         response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
         response[:embedding][:values]
       end
@@ -28,7 +28,7 @@ module DiscourseAi
         end
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
         if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
           truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
           DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
@@ -44,7 +44,7 @@ module DiscourseAi
         "vector_cosine_ops"
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
         response =
           DiscourseAi::Inference::OpenAiEmbeddings.perform!(
             text,
@@ -42,7 +42,7 @@ module DiscourseAi
         "vector_cosine_ops"
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
         response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
         response[:data].first[:embedding]
       end
@@ -42,7 +42,7 @@ module DiscourseAi
         "vector_cosine_ops"
       end

-      def vector_from(text)
+      def vector_from(text, asymetric: false)
         response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
         response[:data].first[:embedding]
       end
@@ -3,32 +3,63 @@
 module ::DiscourseAi
   module Inference
     class HuggingFaceTextEmbeddings
-      def self.perform!(content)
-        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
-        body = { inputs: content, truncate: true }.to_json
-
-        if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
-          service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
-          api_endpoint = "https://#{service.target}:#{service.port}"
-        else
-          api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
-        end
-
-        if SiteSetting.ai_hugging_face_tei_api_key.present?
-          headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
-        end
-
-        conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
-        response = conn.post(api_endpoint, body, headers)
-
-        raise Net::HTTPBadResponse if ![200].include?(response.status)
-
-        JSON.parse(response.body, symbolize_names: true)
-      end
-
-      def self.configured?
-        SiteSetting.ai_hugging_face_tei_endpoint.present? ||
-          SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
-      end
+      class << self
+        def perform!(content)
+          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+          body = { inputs: content, truncate: true }.to_json
+
+          if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
+            service =
+              DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
+            api_endpoint = "https://#{service.target}:#{service.port}"
+          else
+            api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
+          end
+
+          if SiteSetting.ai_hugging_face_tei_api_key.present?
+            headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
+          end
+
+          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
+          response = conn.post(api_endpoint, body, headers)
+
+          raise Net::HTTPBadResponse if ![200].include?(response.status)
+
+          JSON.parse(response.body, symbolize_names: true)
+        end
+
+        def rerank(content, candidates)
+          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+          body = { query: content, texts: candidates, truncate: true }.to_json
+
+          if SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
+            service =
+              DiscourseAi::Utils::DnsSrv.lookup(
+                SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv,
+              )
+            api_endpoint = "https://#{service.target}:#{service.port}"
+          else
+            api_endpoint = SiteSetting.ai_hugging_face_tei_reranker_endpoint
+          end
+
+          if SiteSetting.ai_hugging_face_tei_reranker_api_key.present?
+            headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_reranker_api_key
+          end
+
+          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
+          response = conn.post("#{api_endpoint}/rerank", body, headers)
+
+          if response.status != 200
+            raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}")
+          end
+
+          JSON.parse(response.body, symbolize_names: true)
+        end
+
+        def configured?
+          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
+            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
+        end
+      end
     end
   end
 end
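For reference, the TEI /rerank endpoint returns one entry per candidate carrying its original index and a relevance score, best match first, which is why quick_search can simply map over [:index] and take the top five. A hedged usage sketch (query and candidate texts are placeholders):

    DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
      "password reset",
      ["How to reset your password...", "Unrelated announcement..."],
    )
    # => e.g. [{ index: 0, score: 0.97 }, { index: 1, score: 0.02 }]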