FEATURE: Semantic Suggested Topics (#10)
parent f99fe7e1ed
commit 80d662e9e8

@@ -31,11 +31,21 @@ en:
    ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
    ai_nsfw_models: "Models to use for NSFW inference."

    ai_openai_api_key: "API key for OpenAI API"

    composer_ai_helper_enabled: "Enable the Composer's AI helper."
    ai_openai_api_key: "API key for the AI helper"
    ai_helper_allowed_groups: "Users in these groups will see the AI helper button in the composer."
    ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."

    ai_embeddings_enabled: "Enable the embeddings module."
    ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
    ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
    ai_embeddings_models: "Discourse will generate embeddings for each of the models enabled here"
    ai_embeddings_semantic_suggested_model: "Model to use for suggested topics."
    ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
    ai_embeddings_semantic_suggested_topics_anons_enabled: "Use Semantic Search for suggested topics for anonymous users."
    ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs the pgvector extension enabled and a series of tables created. See docs for more info."

  reviewables:
    reasons:
      flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.

@@ -101,3 +101,32 @@ plugins:
  ai_helper_allowed_in_pm:
    default: false
    client: true

  ai_embeddings_enabled: false
  ai_embeddings_discourse_service_api_endpoint: ""
  ai_embeddings_discourse_service_api_key: ""
  ai_embeddings_models:
    type: list
    list_type: compact
    default: ""
    allow_any: false
    choices:
      - all-mpnet-base-v2
      - all-distilroberta-v1
      - multi-qa-mpnet-base-dot-v1
      - paraphrase-multilingual-mpnet-base-v2
      - msmarco-distilbert-base-v4
      - msmarco-distilbert-base-tas-b
      - text-embedding-ada-002
  ai_embeddings_semantic_suggested_model:
    type: enum
    default: all-mpnet-base-v2
    choices:
      - all-mpnet-base-v2
      - text-embedding-ada-002
      - all-distilroberta-v1
      - multi-qa-mpnet-base-dot-v1
      - paraphrase-multilingual-mpnet-base-v2
  ai_embeddings_generate_for_pms: false
  ai_embeddings_semantic_suggested_topics_anons_enabled: false
  ai_embeddings_pg_connection_string: ""

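Note: ai_embeddings_models is a pipe-delimited list setting. A minimal sketch of how the module reads it back (this mirrors Models.enabled_models below; the example value is hypothetical):

  # e.g. SiteSetting.ai_embeddings_models == "all-mpnet-base-v2|text-embedding-ada-002"
  enabled = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
  enabled.include?("all-mpnet-base-v2") # => true
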
@@ -0,0 +1,31 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class EntryPoint
      def load_files
        require_relative "models"
        require_relative "topic"
        require_relative "jobs/regular/generate_embeddings"
        require_relative "semantic_suggested"
      end

      def inject_into(plugin)
        callback =
          Proc.new do |topic|
            if SiteSetting.ai_embeddings_enabled
              Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
            end
          end

        plugin.on(:topic_created, &callback)
        plugin.on(:topic_edited, &callback)

        DiscoursePluginRegistry.register_list_suggested_for_provider(
          SemanticSuggested.method(:build_suggested_topics),
          plugin,
        )
      end
    end
  end
end

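The entry point wires two things: embedding generation on topic_created/topic_edited, and a suggested-topics provider. A provider passed to register_list_suggested_for_provider is a callable with the same shape as SemanticSuggested.build_suggested_topics: it receives (topic, pm_params, topic_query) and returns a hash with :result and :params, presumably falling back to the default suggested list when it returns nil. A minimal, hypothetical provider for illustration (plugin is the instance passed to inject_into):

  noop_provider = ->(topic, pm_params, topic_query) { { result: ::Topic.none, params: {} } }
  DiscoursePluginRegistry.register_list_suggested_for_provider(noop_provider, plugin)
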
@@ -0,0 +1,17 @@
# frozen_string_literal: true

module Jobs
  class GenerateEmbeddings < ::Jobs::Base
    def execute(args)
      return unless SiteSetting.ai_embeddings_enabled
      return if (topic_id = args[:topic_id]).blank?

      topic = Topic.find_by_id(topic_id)
      return if topic.nil? || (topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms)
      post = topic.first_post
      return if post.nil? || post.raw.blank?

      DiscourseAi::Embeddings::Topic.new(post.topic).perform!
    end
  end
end

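The job bails out unless embeddings are enabled, the topic exists, and, for personal messages, ai_embeddings_generate_for_pms is on. A console sketch for exercising it against an existing public topic (the topic id 123 is hypothetical):

  Jobs::GenerateEmbeddings.new.execute(topic_id: 123)   # run inline
  Jobs.enqueue(:generate_embeddings, topic_id: 123)     # deferred form used by the entry point
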
@@ -0,0 +1,62 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class Models
      MODEL = Data.define(:name, :dimensions, :max_sequence_length, :functions, :type, :provider)

      SEARCH_FUNCTION_TO_PG_INDEX = {
        dot: "vector_ip_ops",
        cosine: "vector_cosine_ops",
        euclidean: "vector_l2_ops",
      }

      SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }

      def self.enabled_models
        setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
        list.filter { |model| setting.include?(model.name) }
      end

      def self.list
        @@list ||= [
          MODEL.new(
            "all-mpnet-base-v2",
            768,
            384,
            %i[dot cosine euclidean],
            [:symmetric],
            "discourse",
          ),
          MODEL.new(
            "all-distilroberta-v1",
            768,
            512,
            %i[dot cosine euclidean],
            [:symmetric],
            "discourse",
          ),
          MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"),
          MODEL.new(
            "paraphrase-multilingual-mpnet-base-v2",
            768,
            128,
            [:cosine],
            [:symmetric],
            "discourse",
          ),
          MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"),
          MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"),
          MODEL.new(
            "text-embedding-ada-002",
            1536,
            2048,
            [:cosine],
            %i[symmetric asymmetric],
            "openai",
          ),
        ]
      end
    end
  end
end

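Each MODEL entry records the embedding dimensions, the maximum input sequence length, the distance functions the model supports, whether it targets symmetric and/or asymmetric search, and which backend serves it. A console sketch of resolving the pgvector operator for a model (all names come from the list above):

  model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == "all-mpnet-base-v2" }
  model.dimensions # => 768
  DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]
  # => "<#>", pgvector's inner-product operator, since :dot is listed first for this model
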
@@ -0,0 +1,72 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class SemanticSuggested
      def self.build_suggested_topics(topic, pm_params, topic_query)
        return unless SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled
        return if topic_query.user
        return if topic.private_message?

        cache_for =
          case topic.created_at
          when 6.hours.ago..Time.now
            15.minutes
          when 1.day.ago..6.hours.ago
            1.hour
          else
            1.day
          end

        begin
          candidate_ids =
            Discourse
              .cache
              .fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
                suggested = search_suggestions(topic)

                # Happens when the topic doesn't have any embeddings
                if suggested.empty? || !suggested.include?(topic.id)
                  return { result: [], params: {} }
                end

                suggested
              end
        rescue StandardError => e
          Rails.logger.error("SemanticSuggested: #{e}")
          return { result: [], params: {} }
        end

        # array_position forces the order of the topics to be preserved
        candidates =
          ::Topic.where(id: candidate_ids).order("array_position(ARRAY#{candidate_ids}, id)")

        { result: candidates, params: {} }
      end

      def self.search_suggestions(topic)
        model_name = SiteSetting.ai_embeddings_semantic_suggested_model
        model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name }
        function =
          DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]

        DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
          SELECT
            topic_id
          FROM
            topic_embeddings_#{model_name.underscore}
          ORDER BY
            embedding #{function} (
              SELECT
                embedding
              FROM
                topic_embeddings_#{model_name.underscore}
              WHERE
                topic_id = :topic_id
              LIMIT 1
            )
          LIMIT 11
        SQL
      end
    end
  end
end

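search_suggestions orders every row in the model's embeddings table by its pgvector distance to the target topic's embedding, so the first id returned is the topic itself and the next ten are its nearest neighbours; build_suggested_topics treats a result that is empty or missing the topic id as "no embeddings yet". A console sketch (assumes embeddings already exist for a hypothetical topic 123):

  topic = ::Topic.find(123)
  DiscourseAi::Embeddings::SemanticSuggested.search_suggestions(topic)
  # => [123, 456, 789, ...]  # the topic itself first, then its nearest neighbours
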
@@ -0,0 +1,57 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class Topic
      def initialize(topic)
        @topic = topic
        @embeddings = {}
      end

      def perform!
        return unless SiteSetting.ai_embeddings_enabled
        return if DiscourseAi::Embeddings::Models.enabled_models.empty?

        calculate_embeddings!
        persist_embeddings! unless @embeddings.empty?
      end

      def calculate_embeddings!
        return if @topic.blank? || @topic.first_post.blank?

        DiscourseAi::Embeddings::Models.enabled_models.each do |model|
          @embeddings[model.name] = send("#{model.provider}_embeddings", model.name)
        end
      end

      def persist_embeddings!
        @embeddings.each do |model, model_embedding|
          DiscourseAi::Database::Connection.db.exec(
            <<~SQL,
              INSERT INTO topic_embeddings_#{model.underscore} (topic_id, embedding)
              VALUES (:topic_id, '[:embedding]')
              ON CONFLICT (topic_id)
              DO UPDATE SET embedding = '[:embedding]'
            SQL
            topic_id: @topic.id,
            embedding: model_embedding,
          )
        end
      end

      def discourse_embeddings(model)
        DiscourseAi::Inference::DiscourseClassifier.perform!(
          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
          model.to_s,
          @topic.first_post.raw,
          SiteSetting.ai_embeddings_discourse_service_api_key,
        )
      end

      def openai_embeddings(model)
        response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
        response[:data].first[:embedding]
      end
    end
  end
end

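persist_embeddings! upserts one row per topic into each enabled model's table. Assuming MiniSql expands the :embedding array into a comma-separated list inside the quoted pgvector literal, the statement sent for a hypothetical 3-dimensional vector would look roughly like:

  INSERT INTO topic_embeddings_all_mpnet_base_v2 (topic_id, embedding)
  VALUES (123, '[0.12,0.34,0.56]')
  ON CONFLICT (topic_id)
  DO UPDATE SET embedding = '[0.12,0.34,0.56]'
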
@@ -0,0 +1,16 @@
# frozen_string_literal: true

module ::DiscourseAi
  module Database
    class Connection
      def self.connect!
        pg_conn = PG.connect(SiteSetting.ai_embeddings_pg_connection_string)
        @@db = MiniSql::Connection.get(pg_conn)
      end

      def self.db
        @@db ||= connect!
      end
    end
  end
end

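Embeddings queries go through this dedicated MiniSql connection, pointed at a pgvector-enabled database, rather than through ActiveRecord's pool. A quick console sketch to verify the configured connection string (assumes pgvector is installed on that database):

  DiscourseAi::Database::Connection.db.query_single("SELECT 1") # => [1]
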
@@ -0,0 +1,46 @@
# frozen_string_literal: true

desc "Creates tables to store embeddings"
task "ai:embeddings:create_table" => [:environment] do
  DiscourseAi::Embeddings::Models.enabled_models.each do |model|
    DiscourseAi::Database::Connection.db.exec(<<~SQL)
      CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
        topic_id bigint PRIMARY KEY,
        embedding vector(#{model.dimensions})
      );
    SQL
  end
end

desc "Backfill embeddings for all topics"
task "ai:embeddings:backfill" => [:environment] do
  public_categories = Category.where(read_restricted: false).pluck(:id)
  Topic
    .where("category_id IN (?)", public_categories)
    .where(deleted_at: nil)
    .find_each do |t|
      print "."
      DiscourseAi::Embeddings::Topic.new(t).perform!
    end
end

desc "Creates indexes for embeddings"
task "ai:embeddings:index" => [:environment] do
  # Using 4 * sqrt(number of topics) as a rule of thumb for the number of lists for now.
  # Results are not as good as exact (unindexed) search, but it's much faster.
  # Disk usage is ~1x the size of the table, so this doubles the total table size.
  lists = 4 * Math.sqrt(Topic.count).to_i

  DiscourseAi::Embeddings::Models.enabled_models.each do |model|
    DiscourseAi::Database::Connection.db.exec(<<~SQL)
      CREATE INDEX IF NOT EXISTS
        topic_embeddings_#{model.name.underscore}_search
      ON
        topic_embeddings_#{model.name.underscore}
      USING
        ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]})
      WITH
        (lists = #{lists});
    SQL
  end
end

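A typical rollout, assuming ai_embeddings_models and ai_embeddings_pg_connection_string are already configured (run from the Discourse root; the task names are the ones defined above):

  bin/rails ai:embeddings:create_table   # one pgvector table per enabled model
  bin/rails ai:embeddings:backfill       # embed the first post of every public topic
  bin/rails ai:embeddings:index          # optional approximate (ivfflat) indexes
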
@@ -27,12 +27,16 @@ after_initialize do
  require_relative "lib/shared/post_classificator"
  require_relative "lib/shared/chat_message_classificator"

  require_relative "lib/shared/database/connection"

  require_relative "lib/modules/nsfw/entry_point"
  require_relative "lib/modules/toxicity/entry_point"
  require_relative "lib/modules/sentiment/entry_point"
  require_relative "lib/modules/ai_helper/entry_point"
  require_relative "lib/modules/embeddings/entry_point"

  [
    DiscourseAi::Embeddings::EntryPoint.new,
    DiscourseAi::NSFW::EntryPoint.new,
    DiscourseAi::Toxicity::EntryPoint.new,
    DiscourseAi::Sentiment::EntryPoint.new,