FEATURE: Semantic Suggested Topics (#10)

Rafael dos Santos Silva 2023-03-15 17:21:45 -03:00 committed by GitHub
parent f99fe7e1ed
commit 80d662e9e8
10 changed files with 345 additions and 1 deletion


@@ -31,11 +31,21 @@ en:
    ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
    ai_nsfw_models: "Models to use for NSFW inference."
    ai_openai_api_key: "API key for OpenAI API"
    composer_ai_helper_enabled: "Enable the Composer's AI helper."
    ai_openai_api_key: "API key for the AI helper"
    ai_helper_allowed_groups: "Users in these groups will see the AI helper button in the composer."
    ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."
    ai_embeddings_enabled: "Enable the embeddings module."
    ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
    ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
    ai_embeddings_models: "Discourse will generate embeddings for each of the models enabled here"
    ai_embeddings_semantic_suggested_model: "Model to use for suggested topics."
    ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
    ai_embeddings_semantic_suggested_topics_anons_enabled: "Use Semantic Search for suggested topics for anonymous users."
    ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs the pgvector extension enabled and a series of tables created. See docs for more info."
  reviewables:
    reasons:
      flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.


@@ -101,3 +101,32 @@ plugins:
  ai_helper_allowed_in_pm:
    default: false
    client: true
  ai_embeddings_enabled: false
  ai_embeddings_discourse_service_api_endpoint: ""
  ai_embeddings_discourse_service_api_key: ""
  ai_embeddings_models:
    type: list
    list_type: compact
    default: ""
    allow_any: false
    choices:
      - all-mpnet-base-v2
      - all-distilroberta-v1
      - multi-qa-mpnet-base-dot-v1
      - paraphrase-multilingual-mpnet-base-v2
      - msmarco-distilbert-base-v4
      - msmarco-distilbert-base-tas-b
      - text-embedding-ada-002
  ai_embeddings_semantic_suggested_model:
    type: enum
    default: all-mpnet-base-v2
    choices:
      - all-mpnet-base-v2
      - text-embedding-ada-002
      - all-distilroberta-v1
      - multi-qa-mpnet-base-dot-v1
      - paraphrase-multilingual-mpnet-base-v2
  ai_embeddings_generate_for_pms: false
  ai_embeddings_semantic_suggested_topics_anons_enabled: false
  ai_embeddings_pg_connection_string: ""
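
For reference, a minimal sketch of turning the module on from a Rails console; the setting names come straight from the hunk above, while the connection string value is a placeholder for a pgvector-enabled database:

  SiteSetting.ai_embeddings_enabled = true
  SiteSetting.ai_embeddings_models = "all-mpnet-base-v2" # pipe-separate to enable several models
  SiteSetting.ai_embeddings_semantic_suggested_model = "all-mpnet-base-v2"
  SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled = true
  SiteSetting.ai_embeddings_pg_connection_string = "postgres://user:pass@localhost:5432/embeddings" # placeholder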


@@ -0,0 +1,31 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class EntryPoint
      def load_files
        require_relative "models"
        require_relative "topic"
        require_relative "jobs/regular/generate_embeddings"
        require_relative "semantic_suggested"
      end

      def inject_into(plugin)
        callback =
          Proc.new do |topic|
            if SiteSetting.ai_embeddings_enabled
              Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
            end
          end

        plugin.on(:topic_created, &callback)
        plugin.on(:topic_edited, &callback)

        DiscoursePluginRegistry.register_list_suggested_for_provider(
          SemanticSuggested.method(:build_suggested_topics),
          plugin,
        )
      end
    end
  end
end


@@ -0,0 +1,17 @@
# frozen_string_literal: true

module Jobs
  class GenerateEmbeddings < ::Jobs::Base
    def execute(args)
      return unless SiteSetting.ai_embeddings_enabled
      return if (topic_id = args[:topic_id]).blank?

      topic = Topic.find_by_id(topic_id)
      return if topic.nil? || (topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms)

      post = topic.first_post
      return if post.nil? || post.raw.blank?
      DiscourseAi::Embeddings::Topic.new(topic).perform!
    end
  end
end
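
The job can also be kicked off by hand for a single topic, mirroring what the topic_created/topic_edited callback in EntryPoint#inject_into does (illustrative sketch):

  # Enqueue embedding generation for the most recent topic
  Jobs.enqueue(:generate_embeddings, topic_id: Topic.last.id)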


@@ -0,0 +1,62 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class Models
      MODEL = Data.define(:name, :dimensions, :max_sequence_length, :functions, :type, :provider)

      SEARCH_FUNCTION_TO_PG_INDEX = {
        dot: "vector_ip_ops",
        cosine: "vector_cosine_ops",
        euclidean: "vector_l2_ops",
      }

      SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }

      def self.enabled_models
        setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
        list.filter { |model| setting.include?(model.name) }
      end

      def self.list
        @@list ||= [
          MODEL.new(
            "all-mpnet-base-v2",
            768,
            384,
            %i[dot cosine euclidean],
            [:symmetric],
            "discourse",
          ),
          MODEL.new(
            "all-distilroberta-v1",
            768,
            512,
            %i[dot cosine euclidean],
            [:symmetric],
            "discourse",
          ),
          MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"),
          MODEL.new(
            "paraphrase-multilingual-mpnet-base-v2",
            768,
            128,
            [:cosine],
            [:symmetric],
            "discourse",
          ),
          MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"),
          MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"),
          MODEL.new(
            "text-embedding-ada-002",
            1536,
            2048,
            [:cosine],
            %i[symmetric asymmetric],
            "openai",
          ),
        ]
      end
    end
  end
end
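
A quick sketch of how the rest of the module consumes this registry, using only the constants defined above:

  model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == "all-mpnet-base-v2" }
  model.dimensions                                                      # => 768
  model.functions.first                                                 # => :dot
  DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[:dot] # => "<#>"
  DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[:dot]    # => "vector_ip_ops"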


@@ -0,0 +1,72 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class SemanticSuggested
      def self.build_suggested_topics(topic, pm_params, topic_query)
        return unless SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled
        return if topic_query.user
        return if topic.private_message?

        cache_for =
          case topic.created_at
          when 6.hours.ago..Time.now
            15.minutes
          when 1.day.ago..6.hours.ago
            1.hour
          else
            1.day
          end

        begin
          candidate_ids =
            Discourse
              .cache
              .fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
                suggested = search_suggestions(topic)

                # Happens when the topic doesn't have any embeddings yet
                if suggested.empty? || !suggested.include?(topic.id)
                  return { result: [], params: {} }
                end

                suggested
              end
        rescue StandardError => e
          Rails.logger.error("SemanticSuggested: #{e}")
          return { result: [], params: {} }
        end

        # array_position forces the order of the topics to be preserved
        candidates =
          ::Topic.where(id: candidate_ids).order("array_position(ARRAY#{candidate_ids}, id)")

        { result: candidates, params: {} }
      end

      def self.search_suggestions(topic)
        model_name = SiteSetting.ai_embeddings_semantic_suggested_model
        model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name }
        function =
          DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]

        DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
          SELECT
            topic_id
          FROM
            topic_embeddings_#{model_name.underscore}
          ORDER BY
            embedding #{function} (
              SELECT
                embedding
              FROM
                topic_embeddings_#{model_name.underscore}
              WHERE
                topic_id = :topic_id
              LIMIT 1
            )
          LIMIT 11
        SQL
      end
    end
  end
end
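
For a topic embedded with all-mpnet-base-v2, the query above resolves to the topic_embeddings_all_mpnet_base_v2 table and the <#> inner-product operator, and returns the source topic plus up to ten nearest neighbours. A rough sketch of calling it directly (the topic id is hypothetical):

  topic = Topic.find(123)
  DiscourseAi::Embeddings::SemanticSuggested.search_suggestions(topic)
  # => [123, ...] up to 11 ids ordered by distance, including the source topic itself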


@@ -0,0 +1,57 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class Topic
      def initialize(topic)
        @topic = topic
        @embeddings = {}
      end

      def perform!
        return unless SiteSetting.ai_embeddings_enabled
        return if DiscourseAi::Embeddings::Models.enabled_models.empty?

        calculate_embeddings!
        persist_embeddings! unless @embeddings.empty?
      end

      def calculate_embeddings!
        return if @topic.blank? || @topic.first_post.blank?

        DiscourseAi::Embeddings::Models.enabled_models.each do |model|
          @embeddings[model.name] = send("#{model.provider}_embeddings", model.name)
        end
      end

      def persist_embeddings!
        @embeddings.each do |model_name, model_embedding|
          DiscourseAi::Database::Connection.db.exec(
            <<~SQL,
              INSERT INTO topic_embeddings_#{model_name.underscore} (topic_id, embedding)
              VALUES (:topic_id, '[:embedding]')
              ON CONFLICT (topic_id)
              DO UPDATE SET embedding = '[:embedding]'
            SQL
            topic_id: @topic.id,
            embedding: model_embedding,
          )
        end
      end

      def discourse_embeddings(model)
        DiscourseAi::Inference::DiscourseClassifier.perform!(
          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
          model.to_s,
          @topic.first_post.raw,
          SiteSetting.ai_embeddings_discourse_service_api_key,
        )
      end

      def openai_embeddings(model)
        response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
        response[:data].first[:embedding]
      end
    end
  end
end
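
Calling convention, as a sketch: the job above instantiates this class with a Topic record, and provider dispatch happens via send, so a "discourse" model routes to discourse_embeddings while "openai" routes to openai_embeddings.

  # Hypothetical topic id; one embeddings row is upserted per enabled model.
  DiscourseAi::Embeddings::Topic.new(Topic.find(123)).perform!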


@@ -0,0 +1,16 @@
# frozen_string_literal: true

module ::DiscourseAi
  module Database
    class Connection
      def self.connect!
        pg_conn = PG.connect(SiteSetting.ai_embeddings_pg_connection_string)
        @@db = MiniSql::Connection.get(pg_conn)
      end

      def self.db
        @@db ||= connect!
      end
    end
  end
end
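
A minimal connectivity check, assuming ai_embeddings_pg_connection_string is already set; query_single is a standard MiniSql helper used here purely for illustration:

  DiscourseAi::Database::Connection.db.query_single("SELECT 1").first # => 1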


@@ -0,0 +1,46 @@
# frozen_string_literal: true

desc "Creates tables to store embeddings"
task "ai:embeddings:create_table" => [:environment] do
  DiscourseAi::Embeddings::Models.enabled_models.each do |model|
    DiscourseAi::Database::Connection.db.exec(<<~SQL)
      CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
        topic_id bigint PRIMARY KEY,
        embedding vector(#{model.dimensions})
      );
    SQL
  end
end

desc "Backfill embeddings for all topics"
task "ai:embeddings:backfill" => [:environment] do
  public_categories = Category.where(read_restricted: false).pluck(:id)
  Topic
    .where("category_id IN (?)", public_categories)
    .where(deleted_at: nil)
    .find_each do |t|
      print "."
      DiscourseAi::Embeddings::Topic.new(t).perform!
    end
end

desc "Creates indexes for embeddings"
task "ai:embeddings:index" => [:environment] do
  # Using 4 * sqrt(number of topics) as a rule of thumb for the number of lists for now.
  # Results are not as good as an exact (unindexed) scan, but search is much faster.
  # Disk usage is roughly the size of the table, so this doubles the total size.
  lists = 4 * Math.sqrt(Topic.count).to_i

  DiscourseAi::Embeddings::Models.enabled_models.each do |model|
    DiscourseAi::Database::Connection.db.exec(<<~SQL)
      CREATE INDEX IF NOT EXISTS
        topic_embeddings_#{model.name.underscore}_search
      ON
        topic_embeddings_#{model.name.underscore}
      USING
        ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]})
      WITH
        (lists = #{lists});
    SQL
  end
end
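
Typical first-time setup order for the three tasks, sketched here as direct Rake invocations from a Rails runner (the equivalent bin/rake calls from the app root work just as well):

  Rake::Task["ai:embeddings:create_table"].invoke # one pgvector table per enabled model
  Rake::Task["ai:embeddings:backfill"].invoke     # embed existing public topics
  Rake::Task["ai:embeddings:index"].invoke        # build ivfflat indexes once the tables have data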


@@ -27,12 +27,16 @@ after_initialize do
require_relative "lib/shared/post_classificator"
require_relative "lib/shared/chat_message_classificator"
require_relative "lib/shared/database/connection"
require_relative "lib/modules/nsfw/entry_point"
require_relative "lib/modules/toxicity/entry_point"
require_relative "lib/modules/sentiment/entry_point"
require_relative "lib/modules/ai_helper/entry_point"
require_relative "lib/modules/embeddings/entry_point"
[
DiscourseAi::Embeddings::EntryPoint.new,
DiscourseAi::NSFW::EntryPoint.new,
DiscourseAi::Toxicity::EntryPoint.new,
DiscourseAi::Sentiment::EntryPoint.new,