FEATURE: Semantic asymmetric full-page search (#34)
Depends on discourse/discourse#20915. Hooks into the full-page-search component using an experimental API and performs an asymmetric similarity search using our embeddings database.
This commit is contained in:
parent
99886fb64d
commit
4e05763a99
|
@ -0,0 +1,35 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class EmbeddingsController < ::ApplicationController
|
||||||
|
requires_plugin ::DiscourseAi::PLUGIN_NAME
|
||||||
|
|
||||||
|
SEMANTIC_SEARCH_TYPE = "semantic_search"
|
||||||
|
|
||||||
|
def search
|
||||||
|
query = params[:q]
|
||||||
|
page = (params[:page] || 1).to_i
|
||||||
|
|
||||||
|
grouped_results =
|
||||||
|
Search::GroupedSearchResults.new(
|
||||||
|
type_filter: SEMANTIC_SEARCH_TYPE,
|
||||||
|
term: query,
|
||||||
|
search_context: guardian,
|
||||||
|
)
|
||||||
|
|
||||||
|
model =
|
||||||
|
DiscourseAi::Embeddings::Model.instantiate(
|
||||||
|
SiteSetting.ai_embeddings_semantic_search_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
DiscourseAi::Embeddings::SemanticSearch
|
||||||
|
.new(guardian, model)
|
||||||
|
.search_for_topics(query, page)
|
||||||
|
.each { |topic_post| grouped_results.add(topic_post) }
|
||||||
|
|
||||||
|
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,64 @@
|
||||||
|
import { withPluginApi } from "discourse/lib/plugin-api";
|
||||||
|
import { translateResults, updateRecentSearches } from "discourse/lib/search";
|
||||||
|
import { setTransient } from "discourse/lib/page-tracker";
|
||||||
|
import { ajax } from "discourse/lib/ajax";
|
||||||
|
|
||||||
|
const SEMANTIC_SEARCH = "semantic_search";
|
||||||
|
|
||||||
|
function initializeSemanticSearch(api) {
|
||||||
|
api.addFullPageSearchType(
|
||||||
|
"discourse_ai.embeddings.semantic_search",
|
||||||
|
SEMANTIC_SEARCH,
|
||||||
|
(searchController, args, searchKey) => {
|
||||||
|
if (searchController.currentUser) {
|
||||||
|
updateRecentSearches(searchController.currentUser, args.searchTerm);
|
||||||
|
}
|
||||||
|
|
||||||
|
ajax("/discourse-ai/embeddings/semantic-search", { data: args })
|
||||||
|
.then(async (results) => {
|
||||||
|
const model = (await translateResults(results)) || {};
|
||||||
|
|
||||||
|
if (results.grouped_search_result) {
|
||||||
|
searchController.set("q", results.grouped_search_result.term);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args.page > 1) {
|
||||||
|
if (model) {
|
||||||
|
searchController.model.posts.pushObjects(model.posts);
|
||||||
|
searchController.model.topics.pushObjects(model.topics);
|
||||||
|
searchController.model.set(
|
||||||
|
"grouped_search_result",
|
||||||
|
results.grouped_search_result
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
setTransient("lastSearch", { searchKey, model }, 5);
|
||||||
|
model.grouped_search_result = results.grouped_search_result;
|
||||||
|
searchController.set("model", model);
|
||||||
|
}
|
||||||
|
searchController.set("error", null);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
searchController.set("error", e.jqXHR.responseJSON?.message);
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
searchController.setProperties({
|
||||||
|
searching: false,
|
||||||
|
loading: false,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default {
|
||||||
|
name: "discourse_ai-full-page-semantic-search",
|
||||||
|
|
||||||
|
initialize(container) {
|
||||||
|
const settings = container.lookup("site-settings:main");
|
||||||
|
|
||||||
|
if (settings.ai_embeddings_enabled) {
|
||||||
|
withPluginApi("1.6.0", initializeSemanticSearch);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
|
@ -10,6 +10,9 @@ en:
|
||||||
reviewables:
|
reviewables:
|
||||||
model_used: "Model used:"
|
model_used: "Model used:"
|
||||||
accuracy: "Accuracy:"
|
accuracy: "Accuracy:"
|
||||||
|
|
||||||
|
embeddings:
|
||||||
|
semantic_search: "Topics (Semantic)"
|
||||||
review:
|
review:
|
||||||
types:
|
types:
|
||||||
reviewable_ai_post:
|
reviewable_ai_post:
|
||||||
|
|
|
@ -47,6 +47,7 @@ en:
|
||||||
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
|
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
|
||||||
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
|
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
|
||||||
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
|
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
|
||||||
|
ai_embeddings_semantic_search_model: "Model to use for semantic search."
|
||||||
|
|
||||||
reviewables:
|
reviewables:
|
||||||
reasons:
|
reasons:
|
||||||
|
|
|
@ -6,6 +6,11 @@ DiscourseAi::Engine.routes.draw do
|
||||||
get "prompts" => "assistant#prompts"
|
get "prompts" => "assistant#prompts"
|
||||||
post "suggest" => "assistant#suggest"
|
post "suggest" => "assistant#suggest"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Embedding routes
|
||||||
|
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
|
||||||
|
get "semantic-search" => "embeddings#search"
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Discourse::Application.routes.append { mount ::DiscourseAi::Engine, at: "discourse-ai" }
|
Discourse::Application.routes.append { mount ::DiscourseAi::Engine, at: "discourse-ai" }
|
||||||
|
|
|
@ -108,7 +108,9 @@ plugins:
|
||||||
- gpt-3.5-turbo
|
- gpt-3.5-turbo
|
||||||
- gpt-4
|
- gpt-4
|
||||||
|
|
||||||
ai_embeddings_enabled: false
|
ai_embeddings_enabled:
|
||||||
|
default: false
|
||||||
|
client: true
|
||||||
ai_embeddings_discourse_service_api_endpoint: ""
|
ai_embeddings_discourse_service_api_endpoint: ""
|
||||||
ai_embeddings_discourse_service_api_key: ""
|
ai_embeddings_discourse_service_api_key: ""
|
||||||
ai_embeddings_models:
|
ai_embeddings_models:
|
||||||
|
@ -133,6 +135,13 @@ plugins:
|
||||||
- all-distilroberta-v1
|
- all-distilroberta-v1
|
||||||
- multi-qa-mpnet-base-dot-v1
|
- multi-qa-mpnet-base-dot-v1
|
||||||
- paraphrase-multilingual-mpnet-base-v2
|
- paraphrase-multilingual-mpnet-base-v2
|
||||||
|
ai_embeddings_semantic_search_model:
|
||||||
|
type: enum
|
||||||
|
default: msmarco-distilbert-base-v4
|
||||||
|
choices:
|
||||||
|
- msmarco-distilbert-base-v4
|
||||||
|
- msmarco-distilbert-base-tas-b
|
||||||
|
- text-embedding-ada-002
|
||||||
ai_embeddings_generate_for_pms: false
|
ai_embeddings_generate_for_pms: false
|
||||||
ai_embeddings_semantic_related_topics_enabled: false
|
ai_embeddings_semantic_related_topics_enabled: false
|
||||||
ai_embeddings_semantic_related_topics: 5
|
ai_embeddings_semantic_related_topics: 5
|
||||||
|
|
|
@ -4,10 +4,11 @@ module DiscourseAi
|
||||||
module Embeddings
|
module Embeddings
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def load_files
|
def load_files
|
||||||
require_relative "models"
|
require_relative "model"
|
||||||
require_relative "topic"
|
require_relative "topic"
|
||||||
require_relative "jobs/regular/generate_embeddings"
|
require_relative "jobs/regular/generate_embeddings"
|
||||||
require_relative "semantic_related"
|
require_relative "semantic_related"
|
||||||
|
require_relative "semantic_search"
|
||||||
end
|
end
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class Model
|
||||||
|
AVAILABLE_MODELS_TEMPLATES = {
|
||||||
|
"all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"],
|
||||||
|
"all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"],
|
||||||
|
"multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"],
|
||||||
|
"paraphrase-multilingual-mpnet-base-v2" => [
|
||||||
|
768,
|
||||||
|
128,
|
||||||
|
%i[cosine],
|
||||||
|
%i[symmetric],
|
||||||
|
"discourse",
|
||||||
|
],
|
||||||
|
"msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
|
||||||
|
"msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
|
||||||
|
"text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
|
||||||
|
}
|
||||||
|
|
||||||
|
SEARCH_FUNCTION_TO_PG_INDEX = {
|
||||||
|
dot: "vector_ip_ops",
|
||||||
|
cosine: "vector_cosine_ops",
|
||||||
|
euclidean: "vector_l2_ops",
|
||||||
|
}
|
||||||
|
|
||||||
|
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
|
||||||
|
|
||||||
|
class << self
|
||||||
|
def instantiate(model_name)
|
||||||
|
new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name])
|
||||||
|
end
|
||||||
|
|
||||||
|
def enabled_models
|
||||||
|
SiteSetting
|
||||||
|
.ai_embeddings_models
|
||||||
|
.split("|")
|
||||||
|
.map { |model_name| instantiate(model_name.strip) }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def initialize(name, dimensions, max_sequence_lenght, functions, type, provider)
|
||||||
|
@name = name
|
||||||
|
@dimensions = dimensions
|
||||||
|
@max_sequence_lenght = max_sequence_lenght
|
||||||
|
@functions = functions
|
||||||
|
@type = type
|
||||||
|
@provider = provider
|
||||||
|
end
|
||||||
|
|
||||||
|
def generate_embedding(input)
|
||||||
|
send("#{provider}_embeddings", input)
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_function
|
||||||
|
SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first]
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_index
|
||||||
|
SEARCH_FUNCTION_TO_PG_INDEX[functions.first]
|
||||||
|
end
|
||||||
|
|
||||||
|
attr_reader :name, :dimensions, :max_sequence_lenght, :functions, :type, :provider
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def discourse_embeddings(input)
|
||||||
|
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||||
|
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||||
|
name.to_s,
|
||||||
|
input,
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def openai_embeddings(input)
|
||||||
|
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(input)
|
||||||
|
response[:data].first[:embedding]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,62 +0,0 @@
|
||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Embeddings
|
|
||||||
class Models
|
|
||||||
MODEL = Data.define(:name, :dimensions, :max_sequence_lenght, :functions, :type, :provider)
|
|
||||||
|
|
||||||
SEARCH_FUNCTION_TO_PG_INDEX = {
|
|
||||||
dot: "vector_ip_ops",
|
|
||||||
cosine: "vector_cosine_ops",
|
|
||||||
euclidean: "vector_l2_ops",
|
|
||||||
}
|
|
||||||
|
|
||||||
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
|
|
||||||
|
|
||||||
def self.enabled_models
|
|
||||||
setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
|
|
||||||
list.filter { |model| setting.include?(model.name) }
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.list
|
|
||||||
@@list ||= [
|
|
||||||
MODEL.new(
|
|
||||||
"all-mpnet-base-v2",
|
|
||||||
768,
|
|
||||||
384,
|
|
||||||
%i[dot cosine euclidean],
|
|
||||||
[:symmetric],
|
|
||||||
"discourse",
|
|
||||||
),
|
|
||||||
MODEL.new(
|
|
||||||
"all-distilroberta-v1",
|
|
||||||
768,
|
|
||||||
512,
|
|
||||||
%i[dot cosine euclidean],
|
|
||||||
[:symmetric],
|
|
||||||
"discourse",
|
|
||||||
),
|
|
||||||
MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"),
|
|
||||||
MODEL.new(
|
|
||||||
"paraphrase-multilingual-mpnet-base-v2",
|
|
||||||
768,
|
|
||||||
128,
|
|
||||||
[:cosine],
|
|
||||||
[:symmetric],
|
|
||||||
"discourse",
|
|
||||||
),
|
|
||||||
MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"),
|
|
||||||
MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"),
|
|
||||||
MODEL.new(
|
|
||||||
"text-embedding-ada-002",
|
|
||||||
1536,
|
|
||||||
2048,
|
|
||||||
[:cosine],
|
|
||||||
%i[:symmetric :asymmetric],
|
|
||||||
"openai",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
|
@ -16,12 +16,17 @@ module DiscourseAi
|
||||||
1.day
|
1.day
|
||||||
end
|
end
|
||||||
|
|
||||||
|
model =
|
||||||
|
DiscourseAi::Embeddings::Model.instantiate(
|
||||||
|
SiteSetting.ai_embeddings_semantic_related_model,
|
||||||
|
)
|
||||||
|
|
||||||
begin
|
begin
|
||||||
candidate_ids =
|
candidate_ids =
|
||||||
Discourse
|
Discourse
|
||||||
.cache
|
.cache
|
||||||
.fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
|
.fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
|
||||||
search_suggestions(topic)
|
DiscourseAi::Embeddings::Topic.new.symmetric_semantic_search(model, topic)
|
||||||
end
|
end
|
||||||
rescue StandardError => e
|
rescue StandardError => e
|
||||||
Rails.logger.error("SemanticRelated: #{e}")
|
Rails.logger.error("SemanticRelated: #{e}")
|
||||||
|
@ -39,40 +44,6 @@ module DiscourseAi
|
||||||
.order("array_position(ARRAY#{candidate_ids}, id)")
|
.order("array_position(ARRAY#{candidate_ids}, id)")
|
||||||
.limit(SiteSetting.ai_embeddings_semantic_related_topics)
|
.limit(SiteSetting.ai_embeddings_semantic_related_topics)
|
||||||
end
|
end
|
||||||
|
|
||||||
def self.search_suggestions(topic)
|
|
||||||
model_name = SiteSetting.ai_embeddings_semantic_related_model
|
|
||||||
model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name }
|
|
||||||
function =
|
|
||||||
DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]
|
|
||||||
|
|
||||||
candidate_ids =
|
|
||||||
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
|
||||||
SELECT
|
|
||||||
topic_id
|
|
||||||
FROM
|
|
||||||
topic_embeddings_#{model_name.underscore}
|
|
||||||
ORDER BY
|
|
||||||
embedding #{function} (
|
|
||||||
SELECT
|
|
||||||
embedding
|
|
||||||
FROM
|
|
||||||
topic_embeddings_#{model_name.underscore}
|
|
||||||
WHERE
|
|
||||||
topic_id = :topic_id
|
|
||||||
LIMIT 1
|
|
||||||
)
|
|
||||||
LIMIT 11
|
|
||||||
SQL
|
|
||||||
|
|
||||||
# Happens when the topic doesn't have any embeddings
|
|
||||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
|
||||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
|
||||||
raise StandardError, "No embeddings found for topic #{topic.id}"
|
|
||||||
end
|
|
||||||
|
|
||||||
candidate_ids
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class SemanticSearch
|
||||||
|
def initialize(guardian, model)
|
||||||
|
@guardian = guardian
|
||||||
|
@model = model
|
||||||
|
end
|
||||||
|
|
||||||
|
def search_for_topics(query, page = 1)
|
||||||
|
limit = Search.per_filter + 1
|
||||||
|
offset = (page - 1) * Search.per_filter
|
||||||
|
|
||||||
|
candidate_ids =
|
||||||
|
DiscourseAi::Embeddings::Topic.new.asymmetric_semantic_search(model, query, limit, offset)
|
||||||
|
|
||||||
|
::Post
|
||||||
|
.where(post_type: ::Topic.visible_post_types(guardian.user))
|
||||||
|
.public_posts
|
||||||
|
.where("topics.visible")
|
||||||
|
.where(topic_id: candidate_ids, post_number: 1)
|
||||||
|
.order("array_position(ARRAY#{candidate_ids}, topic_id)")
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
attr_reader :model, :guardian
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -3,54 +3,80 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Embeddings
|
module Embeddings
|
||||||
class Topic
|
class Topic
|
||||||
def initialize(topic)
|
def generate_and_store_embeddings_for(topic)
|
||||||
@topic = topic
|
|
||||||
@embeddings = {}
|
|
||||||
end
|
|
||||||
|
|
||||||
def perform!
|
|
||||||
return unless SiteSetting.ai_embeddings_enabled
|
return unless SiteSetting.ai_embeddings_enabled
|
||||||
return if DiscourseAi::Embeddings::Models.enabled_models.empty?
|
return if topic.blank? || topic.first_post.blank?
|
||||||
|
|
||||||
calculate_embeddings!
|
enabled_models = DiscourseAi::Embeddings::Model.enabled_models
|
||||||
persist_embeddings! unless @embeddings.empty?
|
return if enabled_models.empty?
|
||||||
end
|
|
||||||
|
|
||||||
def calculate_embeddings!
|
enabled_models.each do |model|
|
||||||
return if @topic.blank? || @topic.first_post.blank?
|
embedding = model.generate_embedding(topic.first_post.raw)
|
||||||
|
persist_embedding(topic, model, embedding) if embedding
|
||||||
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
|
||||||
@embeddings[model.name] = send("#{model.provider}_embeddings", model.name)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def persist_embeddings!
|
def symmetric_semantic_search(model, topic)
|
||||||
@embeddings.each do |model, model_embedding|
|
candidate_ids =
|
||||||
DiscourseAi::Database::Connection.db.exec(
|
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||||
<<~SQL,
|
SELECT
|
||||||
INSERT INTO topic_embeddings_#{model.underscore} (topic_id, embedding)
|
topic_id
|
||||||
VALUES (:topic_id, '[:embedding]')
|
FROM
|
||||||
ON CONFLICT (topic_id)
|
topic_embeddings_#{model.name.underscore}
|
||||||
DO UPDATE SET embedding = '[:embedding]'
|
ORDER BY
|
||||||
SQL
|
embedding #{model.pg_function} (
|
||||||
topic_id: @topic.id,
|
SELECT
|
||||||
embedding: model_embedding,
|
embedding
|
||||||
)
|
FROM
|
||||||
|
topic_embeddings_#{model.name.underscore}
|
||||||
|
WHERE
|
||||||
|
topic_id = :topic_id
|
||||||
|
LIMIT 1
|
||||||
|
)
|
||||||
|
LIMIT 11
|
||||||
|
SQL
|
||||||
|
|
||||||
|
# Happens when the topic doesn't have any embeddings
|
||||||
|
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||||
|
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||||
|
raise StandardError, "No embeddings found for topic #{topic.id}"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
candidate_ids
|
||||||
end
|
end
|
||||||
|
|
||||||
def discourse_embeddings(model)
|
def asymmetric_semantic_search(model, query, limit, offset)
|
||||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
query_embedding = model.generate_embedding(query)
|
||||||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
|
||||||
model.to_s,
|
candidate_ids =
|
||||||
@topic.first_post.raw,
|
DiscourseAi::Database::Connection
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
.db
|
||||||
)
|
.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
|
||||||
|
SELECT
|
||||||
|
topic_id
|
||||||
|
FROM
|
||||||
|
topic_embeddings_#{model.name.underscore}
|
||||||
|
ORDER BY
|
||||||
|
embedding #{model.pg_function} '[:query_embedding]'
|
||||||
|
LIMIT :limit
|
||||||
|
OFFSET :offset
|
||||||
|
SQL
|
||||||
|
.map(&:topic_id)
|
||||||
|
|
||||||
|
raise StandardError, "No embeddings found for topic #{topic.id}" if candidate_ids.empty?
|
||||||
|
|
||||||
|
candidate_ids
|
||||||
end
|
end
|
||||||
|
|
||||||
def openai_embeddings(model)
|
private
|
||||||
response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
|
|
||||||
response[:data].first[:embedding]
|
def persist_embedding(topic, model, embedding)
|
||||||
|
DiscourseAi::Database::Connection.db.exec(<<~SQL, topic_id: topic.id, embedding: embedding)
|
||||||
|
INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
|
||||||
|
VALUES (:topic_id, '[:embedding]')
|
||||||
|
ON CONFLICT (topic_id)
|
||||||
|
DO UPDATE SET embedding = '[:embedding]'
|
||||||
|
SQL
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,10 +3,10 @@
|
||||||
desc "Creates tables to store embeddings"
|
desc "Creates tables to store embeddings"
|
||||||
task "ai:embeddings:create_table" => [:environment] do
|
task "ai:embeddings:create_table" => [:environment] do
|
||||||
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||||
CREATE EXTENSION IF NOT EXISTS pg_vector;
|
CREATE EXTENSION IF NOT EXISTS vector;
|
||||||
SQL
|
SQL
|
||||||
|
|
||||||
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
DiscourseAi::Embeddings::Model.enabled_models.each do |model|
|
||||||
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||||
CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
|
CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
|
||||||
topic_id bigint PRIMARY KEY,
|
topic_id bigint PRIMARY KEY,
|
||||||
|
@ -19,12 +19,13 @@ end
|
||||||
desc "Backfill embeddings for all topics"
|
desc "Backfill embeddings for all topics"
|
||||||
task "ai:embeddings:backfill" => [:environment] do
|
task "ai:embeddings:backfill" => [:environment] do
|
||||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||||
|
topic_embeddings = DiscourseAi::Embeddings::Topic.new
|
||||||
Topic
|
Topic
|
||||||
.where("category_id IN ?", public_categories)
|
.where("category_id IN (?)", public_categories)
|
||||||
.where(deleted_at: nil)
|
.where(deleted_at: nil)
|
||||||
.find_each do |t|
|
.find_each do |t|
|
||||||
print "."
|
print "."
|
||||||
DiscourseAI::Embeddings::Topic.new(t).perform!
|
topic_embeddings.generate_and_store_embeddings_for(t)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -36,14 +37,14 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
|
||||||
lists = 4 * Math.sqrt(Topic.count).to_i
|
lists = 4 * Math.sqrt(Topic.count).to_i
|
||||||
|
|
||||||
DiscourseAi::Database::Connection.db.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
|
DiscourseAi::Database::Connection.db.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
|
||||||
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
DiscourseAi::Embeddings::Model.enabled_models.each do |model|
|
||||||
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||||
CREATE INDEX IF NOT EXISTS
|
CREATE INDEX IF NOT EXISTS
|
||||||
topic_embeddings_#{model.name.underscore}_search
|
topic_embeddings_#{model.name.underscore}_search
|
||||||
ON
|
ON
|
||||||
topic_embeddings_#{model.name.underscore}
|
topic_embeddings_#{model.name.underscore}
|
||||||
USING
|
USING
|
||||||
ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]})
|
ivfflat (embedding #{model.pg_index})
|
||||||
WITH
|
WITH
|
||||||
(lists = #{lists});
|
(lists = #{lists});
|
||||||
SQL
|
SQL
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "../../../support/embeddings_generation_stubs"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Embeddings::Model do
|
||||||
|
describe "#generate_embedding" do
|
||||||
|
let(:input) { "test" }
|
||||||
|
let(:expected_embedding) { [0.0038493, 0.482001] }
|
||||||
|
|
||||||
|
context "when the model uses the discourse service to create embeddings" do
|
||||||
|
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||||
|
|
||||||
|
let(:discourse_model) { "all-mpnet-base-v2" }
|
||||||
|
|
||||||
|
it "returns an embedding for a given string" do
|
||||||
|
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
|
||||||
|
|
||||||
|
embedding = described_class.instantiate(discourse_model).generate_embedding(input)
|
||||||
|
|
||||||
|
expect(embedding).to contain_exactly(*expected_embedding)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when the model uses OpenAI to create embeddings" do
|
||||||
|
let(:openai_model) { "text-embedding-ada-002" }
|
||||||
|
|
||||||
|
it "returns an embedding for a given string" do
|
||||||
|
EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
|
||||||
|
|
||||||
|
embedding = described_class.instantiate(openai_model).generate_embedding(input)
|
||||||
|
|
||||||
|
expect(embedding).to contain_exactly(*expected_embedding)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -17,9 +17,10 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
||||||
describe "#candidates_for" do
|
describe "#candidates_for" do
|
||||||
before do
|
before do
|
||||||
Discourse.cache.clear
|
Discourse.cache.clear
|
||||||
described_class.stubs(:search_suggestions).returns(
|
DiscourseAi::Embeddings::Topic
|
||||||
Topic.unscoped.order(id: :desc).limit(10).pluck(:id),
|
.any_instance
|
||||||
)
|
.expects(:symmetric_semantic_search)
|
||||||
|
.returns(Topic.unscoped.order(id: :desc).limit(10).pluck(:id))
|
||||||
end
|
end
|
||||||
|
|
||||||
after { Discourse.cache.clear }
|
after { Discourse.cache.clear }
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
||||||
|
fab!(:post) { Fabricate(:post) }
|
||||||
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
let(:model_name) { "msmarco-distilbert-base-v4" }
|
||||||
|
let(:query) { "test_query" }
|
||||||
|
|
||||||
|
let(:model) { DiscourseAi::Embeddings::Model.instantiate(model_name) }
|
||||||
|
let(:subject) { described_class.new(Guardian.new(user), model) }
|
||||||
|
|
||||||
|
describe "#search_for_topics" do
|
||||||
|
def stub_candidate_ids(candidate_ids)
|
||||||
|
DiscourseAi::Embeddings::Topic
|
||||||
|
.any_instance
|
||||||
|
.expects(:asymmetric_semantic_search)
|
||||||
|
.returns(candidate_ids)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "returns the first post of a topic included in the asymmetric search results" do
|
||||||
|
stub_candidate_ids([post.topic_id])
|
||||||
|
|
||||||
|
posts = subject.search_for_topics(query)
|
||||||
|
|
||||||
|
expect(posts).to contain_exactly(post)
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "applies different scopes to the candidates" do
|
||||||
|
context "when the topic is not visible" do
|
||||||
|
it "returns an empty list" do
|
||||||
|
post.topic.update!(visible: false)
|
||||||
|
stub_candidate_ids([post.topic_id])
|
||||||
|
|
||||||
|
posts = subject.search_for_topics(query)
|
||||||
|
|
||||||
|
expect(posts).to be_empty
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when the post is not public" do
|
||||||
|
it "returns an empty list" do
|
||||||
|
pm_post = Fabricate(:private_message_post)
|
||||||
|
stub_candidate_ids([pm_post.topic_id])
|
||||||
|
|
||||||
|
posts = subject.search_for_topics(query)
|
||||||
|
|
||||||
|
expect(posts).to be_empty
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when the post type is not visible" do
|
||||||
|
it "returns an empty list" do
|
||||||
|
post.update!(post_type: Post.types[:whisper])
|
||||||
|
stub_candidate_ids([post.topic_id])
|
||||||
|
|
||||||
|
posts = subject.search_for_topics(query)
|
||||||
|
|
||||||
|
expect(posts).to be_empty
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when the post is not the first post in the topic" do
|
||||||
|
it "returns an empty list" do
|
||||||
|
reply = Fabricate(:reply)
|
||||||
|
reply.topic.first_post.trash!
|
||||||
|
stub_candidate_ids([reply.topic_id])
|
||||||
|
|
||||||
|
posts = subject.search_for_topics(query)
|
||||||
|
|
||||||
|
expect(posts).to be_empty
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when the post is not a candidate" do
|
||||||
|
it "doesn't include it in the results" do
|
||||||
|
post_2 = Fabricate(:post)
|
||||||
|
stub_candidate_ids([post.topic_id])
|
||||||
|
|
||||||
|
posts = subject.search_for_topics(query)
|
||||||
|
|
||||||
|
expect(posts).not_to include(post_2)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -19,9 +19,10 @@ describe ::TopicsController do
|
||||||
|
|
||||||
context "when a user is logged on" do
|
context "when a user is logged on" do
|
||||||
it "includes related topics in payload when configured" do
|
it "includes related topics in payload when configured" do
|
||||||
DiscourseAi::Embeddings::SemanticRelated.stubs(:search_suggestions).returns(
|
DiscourseAi::Embeddings::Topic
|
||||||
[topic1.id, topic2.id, topic3.id],
|
.any_instance
|
||||||
)
|
.expects(:symmetric_semantic_search)
|
||||||
|
.returns([topic1.id, topic2.id, topic3.id])
|
||||||
|
|
||||||
get("#{topic.relative_url}.json")
|
get("#{topic.relative_url}.json")
|
||||||
json = response.parsed_body
|
json = response.parsed_body
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class EmbeddingsGenerationStubs
|
||||||
|
class << self
|
||||||
|
def discourse_service(model, string, embedding)
|
||||||
|
WebMock
|
||||||
|
.stub_request(
|
||||||
|
:post,
|
||||||
|
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||||
|
)
|
||||||
|
.with(body: JSON.dump({ model: model, content: string }))
|
||||||
|
.to_return(status: 200, body: JSON.dump(embedding))
|
||||||
|
end
|
||||||
|
|
||||||
|
def openai_service(model, string, embedding)
|
||||||
|
WebMock
|
||||||
|
.stub_request(:post, "https://api.openai.com/v1/embeddings")
|
||||||
|
.with(body: JSON.dump({ model: model, input: string }))
|
||||||
|
.to_return(status: 200, body: JSON.dump({ data: [{ embedding: embedding }] }))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue