FEATURE: Semantic assymetric full-page search (#34)

Depends on discourse/discourse#20915

Hooks to the full-page-search component using an experimental API and performs an assymetric similarity search using our embeddings database.
This commit is contained in:
Roman Rizzi 2023-03-31 15:29:56 -03:00 committed by GitHub
parent 99886fb64d
commit 4e05763a99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 462 additions and 148 deletions

View File

@ -0,0 +1,35 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
class EmbeddingsController < ::ApplicationController
requires_plugin ::DiscourseAi::PLUGIN_NAME
SEMANTIC_SEARCH_TYPE = "semantic_search"
def search
query = params[:q]
page = (params[:page] || 1).to_i
grouped_results =
Search::GroupedSearchResults.new(
type_filter: SEMANTIC_SEARCH_TYPE,
term: query,
search_context: guardian,
)
model =
DiscourseAi::Embeddings::Model.instantiate(
SiteSetting.ai_embeddings_semantic_search_model,
)
DiscourseAi::Embeddings::SemanticSearch
.new(guardian, model)
.search_for_topics(query, page)
.each { |topic_post| grouped_results.add(topic_post) }
render_serialized(grouped_results, GroupedSearchResultSerializer, result: grouped_results)
end
end
end
end

View File

@ -0,0 +1,64 @@
import { withPluginApi } from "discourse/lib/plugin-api";
import { translateResults, updateRecentSearches } from "discourse/lib/search";
import { setTransient } from "discourse/lib/page-tracker";
import { ajax } from "discourse/lib/ajax";
const SEMANTIC_SEARCH = "semantic_search";
function initializeSemanticSearch(api) {
api.addFullPageSearchType(
"discourse_ai.embeddings.semantic_search",
SEMANTIC_SEARCH,
(searchController, args, searchKey) => {
if (searchController.currentUser) {
updateRecentSearches(searchController.currentUser, args.searchTerm);
}
ajax("/discourse-ai/embeddings/semantic-search", { data: args })
.then(async (results) => {
const model = (await translateResults(results)) || {};
if (results.grouped_search_result) {
searchController.set("q", results.grouped_search_result.term);
}
if (args.page > 1) {
if (model) {
searchController.model.posts.pushObjects(model.posts);
searchController.model.topics.pushObjects(model.topics);
searchController.model.set(
"grouped_search_result",
results.grouped_search_result
);
}
} else {
setTransient("lastSearch", { searchKey, model }, 5);
model.grouped_search_result = results.grouped_search_result;
searchController.set("model", model);
}
searchController.set("error", null);
})
.catch((e) => {
searchController.set("error", e.jqXHR.responseJSON?.message);
})
.finally(() => {
searchController.setProperties({
searching: false,
loading: false,
});
});
}
);
}
export default {
name: "discourse_ai-full-page-semantic-search",
initialize(container) {
const settings = container.lookup("site-settings:main");
if (settings.ai_embeddings_enabled) {
withPluginApi("1.6.0", initializeSemanticSearch);
}
},
};

View File

@ -10,6 +10,9 @@ en:
reviewables:
model_used: "Model used:"
accuracy: "Accuracy:"
embeddings:
semantic_search: "Topics (Semantic)"
review:
types:
reviewable_ai_post:

View File

@ -47,6 +47,7 @@ en:
ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics."
ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section."
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
ai_embeddings_semantic_search_model: "Model to use for semantic search."
reviewables:
reasons:

View File

@ -6,6 +6,11 @@ DiscourseAi::Engine.routes.draw do
get "prompts" => "assistant#prompts"
post "suggest" => "assistant#suggest"
end
# Embedding routes
scope module: :embeddings, path: "/embeddings", defaults: { format: :json } do
get "semantic-search" => "embeddings#search"
end
end
Discourse::Application.routes.append { mount ::DiscourseAi::Engine, at: "discourse-ai" }

View File

@ -108,7 +108,9 @@ plugins:
- gpt-3.5-turbo
- gpt-4
ai_embeddings_enabled: false
ai_embeddings_enabled:
default: false
client: true
ai_embeddings_discourse_service_api_endpoint: ""
ai_embeddings_discourse_service_api_key: ""
ai_embeddings_models:
@ -133,6 +135,13 @@ plugins:
- all-distilroberta-v1
- multi-qa-mpnet-base-dot-v1
- paraphrase-multilingual-mpnet-base-v2
ai_embeddings_semantic_search_model:
type: enum
default: msmarco-distilbert-base-v4
choices:
- msmarco-distilbert-base-v4
- msmarco-distilbert-base-tas-b
- text-embedding-ada-002
ai_embeddings_generate_for_pms: false
ai_embeddings_semantic_related_topics_enabled: false
ai_embeddings_semantic_related_topics: 5

View File

@ -4,10 +4,11 @@ module DiscourseAi
module Embeddings
class EntryPoint
def load_files
require_relative "models"
require_relative "model"
require_relative "topic"
require_relative "jobs/regular/generate_embeddings"
require_relative "semantic_related"
require_relative "semantic_search"
end
def inject_into(plugin)

View File

@ -0,0 +1,83 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
class Model
AVAILABLE_MODELS_TEMPLATES = {
"all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"],
"all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"],
"multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"],
"paraphrase-multilingual-mpnet-base-v2" => [
768,
128,
%i[cosine],
%i[symmetric],
"discourse",
],
"msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
"msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
"text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
}
SEARCH_FUNCTION_TO_PG_INDEX = {
dot: "vector_ip_ops",
cosine: "vector_cosine_ops",
euclidean: "vector_l2_ops",
}
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
class << self
def instantiate(model_name)
new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name])
end
def enabled_models
SiteSetting
.ai_embeddings_models
.split("|")
.map { |model_name| instantiate(model_name.strip) }
end
end
def initialize(name, dimensions, max_sequence_lenght, functions, type, provider)
@name = name
@dimensions = dimensions
@max_sequence_lenght = max_sequence_lenght
@functions = functions
@type = type
@provider = provider
end
def generate_embedding(input)
send("#{provider}_embeddings", input)
end
def pg_function
SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first]
end
def pg_index
SEARCH_FUNCTION_TO_PG_INDEX[functions.first]
end
attr_reader :name, :dimensions, :max_sequence_lenght, :functions, :type, :provider
private
def discourse_embeddings(input)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name.to_s,
input,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
end
def openai_embeddings(input)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(input)
response[:data].first[:embedding]
end
end
end
end

View File

@ -1,62 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
class Models
MODEL = Data.define(:name, :dimensions, :max_sequence_lenght, :functions, :type, :provider)
SEARCH_FUNCTION_TO_PG_INDEX = {
dot: "vector_ip_ops",
cosine: "vector_cosine_ops",
euclidean: "vector_l2_ops",
}
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
def self.enabled_models
setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
list.filter { |model| setting.include?(model.name) }
end
def self.list
@@list ||= [
MODEL.new(
"all-mpnet-base-v2",
768,
384,
%i[dot cosine euclidean],
[:symmetric],
"discourse",
),
MODEL.new(
"all-distilroberta-v1",
768,
512,
%i[dot cosine euclidean],
[:symmetric],
"discourse",
),
MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"),
MODEL.new(
"paraphrase-multilingual-mpnet-base-v2",
768,
128,
[:cosine],
[:symmetric],
"discourse",
),
MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"),
MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"),
MODEL.new(
"text-embedding-ada-002",
1536,
2048,
[:cosine],
%i[:symmetric :asymmetric],
"openai",
),
]
end
end
end
end

View File

@ -16,12 +16,17 @@ module DiscourseAi
1.day
end
model =
DiscourseAi::Embeddings::Model.instantiate(
SiteSetting.ai_embeddings_semantic_related_model,
)
begin
candidate_ids =
Discourse
.cache
.fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
search_suggestions(topic)
DiscourseAi::Embeddings::Topic.new.symmetric_semantic_search(model, topic)
end
rescue StandardError => e
Rails.logger.error("SemanticRelated: #{e}")
@ -39,40 +44,6 @@ module DiscourseAi
.order("array_position(ARRAY#{candidate_ids}, id)")
.limit(SiteSetting.ai_embeddings_semantic_related_topics)
end
def self.search_suggestions(topic)
model_name = SiteSetting.ai_embeddings_semantic_related_model
model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name }
function =
DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]
candidate_ids =
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
SELECT
topic_id
FROM
topic_embeddings_#{model_name.underscore}
ORDER BY
embedding #{function} (
SELECT
embedding
FROM
topic_embeddings_#{model_name.underscore}
WHERE
topic_id = :topic_id
LIMIT 1
)
LIMIT 11
SQL
# Happens when the topic doesn't have any embeddings
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
raise StandardError, "No embeddings found for topic #{topic.id}"
end
candidate_ids
end
end
end
end

View File

@ -0,0 +1,31 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
class SemanticSearch
def initialize(guardian, model)
@guardian = guardian
@model = model
end
def search_for_topics(query, page = 1)
limit = Search.per_filter + 1
offset = (page - 1) * Search.per_filter
candidate_ids =
DiscourseAi::Embeddings::Topic.new.asymmetric_semantic_search(model, query, limit, offset)
::Post
.where(post_type: ::Topic.visible_post_types(guardian.user))
.public_posts
.where("topics.visible")
.where(topic_id: candidate_ids, post_number: 1)
.order("array_position(ARRAY#{candidate_ids}, topic_id)")
end
private
attr_reader :model, :guardian
end
end
end

View File

@ -3,54 +3,80 @@
module DiscourseAi
module Embeddings
class Topic
def initialize(topic)
@topic = topic
@embeddings = {}
end
def perform!
def generate_and_store_embeddings_for(topic)
return unless SiteSetting.ai_embeddings_enabled
return if DiscourseAi::Embeddings::Models.enabled_models.empty?
return if topic.blank? || topic.first_post.blank?
calculate_embeddings!
persist_embeddings! unless @embeddings.empty?
end
enabled_models = DiscourseAi::Embeddings::Model.enabled_models
return if enabled_models.empty?
def calculate_embeddings!
return if @topic.blank? || @topic.first_post.blank?
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
@embeddings[model.name] = send("#{model.provider}_embeddings", model.name)
enabled_models.each do |model|
embedding = model.generate_embedding(topic.first_post.raw)
persist_embedding(topic, model, embedding) if embedding
end
end
def persist_embeddings!
@embeddings.each do |model, model_embedding|
DiscourseAi::Database::Connection.db.exec(
<<~SQL,
INSERT INTO topic_embeddings_#{model.underscore} (topic_id, embedding)
VALUES (:topic_id, '[:embedding]')
ON CONFLICT (topic_id)
DO UPDATE SET embedding = '[:embedding]'
SQL
topic_id: @topic.id,
embedding: model_embedding,
)
def symmetric_semantic_search(model, topic)
candidate_ids =
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
SELECT
topic_id
FROM
topic_embeddings_#{model.name.underscore}
ORDER BY
embedding #{model.pg_function} (
SELECT
embedding
FROM
topic_embeddings_#{model.name.underscore}
WHERE
topic_id = :topic_id
LIMIT 1
)
LIMIT 11
SQL
# Happens when the topic doesn't have any embeddings
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
raise StandardError, "No embeddings found for topic #{topic.id}"
end
candidate_ids
end
def discourse_embeddings(model)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
model.to_s,
@topic.first_post.raw,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
def asymmetric_semantic_search(model, query, limit, offset)
query_embedding = model.generate_embedding(query)
candidate_ids =
DiscourseAi::Database::Connection
.db
.query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset)
SELECT
topic_id
FROM
topic_embeddings_#{model.name.underscore}
ORDER BY
embedding #{model.pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
.map(&:topic_id)
raise StandardError, "No embeddings found for topic #{topic.id}" if candidate_ids.empty?
candidate_ids
end
def openai_embeddings(model)
response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
response[:data].first[:embedding]
private
def persist_embedding(topic, model, embedding)
DiscourseAi::Database::Connection.db.exec(<<~SQL, topic_id: topic.id, embedding: embedding)
INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding)
VALUES (:topic_id, '[:embedding]')
ON CONFLICT (topic_id)
DO UPDATE SET embedding = '[:embedding]'
SQL
end
end
end

View File

@ -3,10 +3,10 @@
desc "Creates tables to store embeddings"
task "ai:embeddings:create_table" => [:environment] do
DiscourseAi::Database::Connection.db.exec(<<~SQL)
CREATE EXTENSION IF NOT EXISTS pg_vector;
CREATE EXTENSION IF NOT EXISTS vector;
SQL
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
DiscourseAi::Embeddings::Model.enabled_models.each do |model|
DiscourseAi::Database::Connection.db.exec(<<~SQL)
CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
topic_id bigint PRIMARY KEY,
@ -19,12 +19,13 @@ end
desc "Backfill embeddings for all topics"
task "ai:embeddings:backfill" => [:environment] do
public_categories = Category.where(read_restricted: false).pluck(:id)
topic_embeddings = DiscourseAi::Embeddings::Topic.new
Topic
.where("category_id IN ?", public_categories)
.where("category_id IN (?)", public_categories)
.where(deleted_at: nil)
.find_each do |t|
print "."
DiscourseAI::Embeddings::Topic.new(t).perform!
topic_embeddings.generate_and_store_embeddings_for(t)
end
end
@ -36,14 +37,14 @@ task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
lists = 4 * Math.sqrt(Topic.count).to_i
DiscourseAi::Database::Connection.db.exec("SET work_mem TO '#{args[:work_mem] || "1GB"}';")
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
DiscourseAi::Embeddings::Model.enabled_models.each do |model|
DiscourseAi::Database::Connection.db.exec(<<~SQL)
CREATE INDEX IF NOT EXISTS
topic_embeddings_#{model.name.underscore}_search
ON
topic_embeddings_#{model.name.underscore}
USING
ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]})
ivfflat (embedding #{model.pg_index})
WITH
(lists = #{lists});
SQL

View File

@ -0,0 +1,36 @@
# frozen_string_literal: true
require_relative "../../../support/embeddings_generation_stubs"
RSpec.describe DiscourseAi::Embeddings::Model do
describe "#generate_embedding" do
let(:input) { "test" }
let(:expected_embedding) { [0.0038493, 0.482001] }
context "when the model uses the discourse service to create embeddings" do
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
let(:discourse_model) { "all-mpnet-base-v2" }
it "returns an embedding for a given string" do
EmbeddingsGenerationStubs.discourse_service(discourse_model, input, expected_embedding)
embedding = described_class.instantiate(discourse_model).generate_embedding(input)
expect(embedding).to contain_exactly(*expected_embedding)
end
end
context "when the model uses OpenAI to create embeddings" do
let(:openai_model) { "text-embedding-ada-002" }
it "returns an embedding for a given string" do
EmbeddingsGenerationStubs.openai_service(openai_model, input, expected_embedding)
embedding = described_class.instantiate(openai_model).generate_embedding(input)
expect(embedding).to contain_exactly(*expected_embedding)
end
end
end
end

View File

@ -17,9 +17,10 @@ describe DiscourseAi::Embeddings::SemanticRelated do
describe "#candidates_for" do
before do
Discourse.cache.clear
described_class.stubs(:search_suggestions).returns(
Topic.unscoped.order(id: :desc).limit(10).pluck(:id),
)
DiscourseAi::Embeddings::Topic
.any_instance
.expects(:symmetric_semantic_search)
.returns(Topic.unscoped.order(id: :desc).limit(10).pluck(:id))
end
after { Discourse.cache.clear }

View File

@ -0,0 +1,86 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
fab!(:post) { Fabricate(:post) }
fab!(:user) { Fabricate(:user) }
let(:model_name) { "msmarco-distilbert-base-v4" }
let(:query) { "test_query" }
let(:model) { DiscourseAi::Embeddings::Model.instantiate(model_name) }
let(:subject) { described_class.new(Guardian.new(user), model) }
describe "#search_for_topics" do
def stub_candidate_ids(candidate_ids)
DiscourseAi::Embeddings::Topic
.any_instance
.expects(:asymmetric_semantic_search)
.returns(candidate_ids)
end
it "returns the first post of a topic included in the asymmetric search results" do
stub_candidate_ids([post.topic_id])
posts = subject.search_for_topics(query)
expect(posts).to contain_exactly(post)
end
describe "applies different scopes to the candidates" do
context "when the topic is not visible" do
it "returns an empty list" do
post.topic.update!(visible: false)
stub_candidate_ids([post.topic_id])
posts = subject.search_for_topics(query)
expect(posts).to be_empty
end
end
context "when the post is not public" do
it "returns an empty list" do
pm_post = Fabricate(:private_message_post)
stub_candidate_ids([pm_post.topic_id])
posts = subject.search_for_topics(query)
expect(posts).to be_empty
end
end
context "when the post type is not visible" do
it "returns an empty list" do
post.update!(post_type: Post.types[:whisper])
stub_candidate_ids([post.topic_id])
posts = subject.search_for_topics(query)
expect(posts).to be_empty
end
end
context "when the post is not the first post in the topic" do
it "returns an empty list" do
reply = Fabricate(:reply)
reply.topic.first_post.trash!
stub_candidate_ids([reply.topic_id])
posts = subject.search_for_topics(query)
expect(posts).to be_empty
end
end
context "when the post is not a candidate" do
it "doesn't include it in the results" do
post_2 = Fabricate(:post)
stub_candidate_ids([post.topic_id])
posts = subject.search_for_topics(query)
expect(posts).not_to include(post_2)
end
end
end
end
end

View File

@ -19,9 +19,10 @@ describe ::TopicsController do
context "when a user is logged on" do
it "includes related topics in payload when configured" do
DiscourseAi::Embeddings::SemanticRelated.stubs(:search_suggestions).returns(
[topic1.id, topic2.id, topic3.id],
)
DiscourseAi::Embeddings::Topic
.any_instance
.expects(:symmetric_semantic_search)
.returns([topic1.id, topic2.id, topic3.id])
get("#{topic.relative_url}.json")
json = response.parsed_body

View File

@ -0,0 +1,22 @@
# frozen_string_literal: true
class EmbeddingsGenerationStubs
class << self
def discourse_service(model, string, embedding)
WebMock
.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
)
.with(body: JSON.dump({ model: model, content: string }))
.to_return(status: 200, body: JSON.dump(embedding))
end
def openai_service(model, string, embedding)
WebMock
.stub_request(:post, "https://api.openai.com/v1/embeddings")
.with(body: JSON.dump({ model: model, input: string }))
.to_return(status: 200, body: JSON.dump({ data: [{ embedding: embedding }] }))
end
end
end