FEATURE: Semantic Suggested Topics (#10)
This commit is contained in:
parent
f99fe7e1ed
commit
80d662e9e8
|
@ -31,11 +31,21 @@ en:
|
||||||
ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
|
ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
|
||||||
ai_nsfw_models: "Models to use for NSFW inference."
|
ai_nsfw_models: "Models to use for NSFW inference."
|
||||||
|
|
||||||
|
ai_openai_api_key: "API key for OpenAI API"
|
||||||
|
|
||||||
composer_ai_helper_enabled: "Enable the Composer's AI helper."
|
composer_ai_helper_enabled: "Enable the Composer's AI helper."
|
||||||
ai_openai_api_key: "API key for the AI helper"
|
|
||||||
ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
||||||
ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."
|
ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."
|
||||||
|
|
||||||
|
ai_embeddings_enabled: "Enable the embeddings module."
|
||||||
|
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
||||||
|
ai_embeddings_discourse_service_api_key: "API key for the embeddings API"
|
||||||
|
ai_embeddings_models: "Discourse will generate embeddings for each of the models enabled here"
|
||||||
|
ai_embeddings_semantic_suggested_model: "Model to use for suggested topics."
|
||||||
|
ai_embeddings_generate_for_pms: "Generate embeddings for personal messages."
|
||||||
|
ai_embeddings_semantic_suggested_topics_anons_enabled: "Use Semantic Search for suggested topics for anonymous users."
|
||||||
|
ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info."
|
||||||
|
|
||||||
reviewables:
|
reviewables:
|
||||||
reasons:
|
reasons:
|
||||||
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
|
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
|
||||||
|
|
|
@ -101,3 +101,32 @@ plugins:
|
||||||
ai_helper_allowed_in_pm:
|
ai_helper_allowed_in_pm:
|
||||||
default: false
|
default: false
|
||||||
client: true
|
client: true
|
||||||
|
|
||||||
|
ai_embeddings_enabled: false
|
||||||
|
ai_embeddings_discourse_service_api_endpoint: ""
|
||||||
|
ai_embeddings_discourse_service_api_key: ""
|
||||||
|
ai_embeddings_models:
|
||||||
|
type: list
|
||||||
|
list_type: compact
|
||||||
|
default: ""
|
||||||
|
allow_any: false
|
||||||
|
choices:
|
||||||
|
- all-mpnet-base-v2
|
||||||
|
- all-distilroberta-v1
|
||||||
|
- multi-qa-mpnet-base-dot-v1
|
||||||
|
- paraphrase-multilingual-mpnet-base-v2
|
||||||
|
- msmarco-distilbert-base-v4
|
||||||
|
- msmarco-distilbert-base-tas-b
|
||||||
|
- text-embedding-ada-002
|
||||||
|
ai_embeddings_semantic_suggested_model:
|
||||||
|
type: enum
|
||||||
|
default: all-mpnet-base-v2
|
||||||
|
choices:
|
||||||
|
- all-mpnet-base-v2
|
||||||
|
- text-embedding-ada-002
|
||||||
|
- all-distilroberta-v1
|
||||||
|
- multi-qa-mpnet-base-dot-v1
|
||||||
|
- paraphrase-multilingual-mpnet-base-v2
|
||||||
|
ai_embeddings_generate_for_pms: false
|
||||||
|
ai_embeddings_semantic_suggested_topics_anons_enabled: false
|
||||||
|
ai_embeddings_pg_connection_string: ""
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class EntryPoint
|
||||||
|
def load_files
|
||||||
|
require_relative "models"
|
||||||
|
require_relative "topic"
|
||||||
|
require_relative "jobs/regular/generate_embeddings"
|
||||||
|
require_relative "semantic_suggested"
|
||||||
|
end
|
||||||
|
|
||||||
|
def inject_into(plugin)
|
||||||
|
callback =
|
||||||
|
Proc.new do |topic|
|
||||||
|
if SiteSetting.ai_embeddings_enabled
|
||||||
|
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
plugin.on(:topic_created, &callback)
|
||||||
|
plugin.on(:topic_edited, &callback)
|
||||||
|
|
||||||
|
DiscoursePluginRegistry.register_list_suggested_for_provider(
|
||||||
|
SemanticSuggested.method(:build_suggested_topics),
|
||||||
|
plugin,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,17 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module Jobs
|
||||||
|
class GenerateEmbeddings < ::Jobs::Base
|
||||||
|
def execute(args)
|
||||||
|
return unless SiteSetting.ai_embeddings_enabled
|
||||||
|
return if (topic_id = args[:topic_id]).blank?
|
||||||
|
|
||||||
|
topic = Topic.find_by_id(topic_id)
|
||||||
|
return if topic.nil? || topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
|
||||||
|
post = Topic.find_by_id(topic_id).first_post
|
||||||
|
return if post.nil? || post.raw.blank?
|
||||||
|
|
||||||
|
DiscourseAi::Embeddings::Topic.new(post.topic).perform!
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,62 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class Models
|
||||||
|
MODEL = Data.define(:name, :dimensions, :max_sequence_lenght, :functions, :type, :provider)
|
||||||
|
|
||||||
|
SEARCH_FUNCTION_TO_PG_INDEX = {
|
||||||
|
dot: "vector_ip_ops",
|
||||||
|
cosine: "vector_cosine_ops",
|
||||||
|
euclidean: "vector_l2_ops",
|
||||||
|
}
|
||||||
|
|
||||||
|
SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
|
||||||
|
|
||||||
|
def self.enabled_models
|
||||||
|
setting = SiteSetting.ai_embeddings_models.split("|").map(&:strip)
|
||||||
|
list.filter { |model| setting.include?(model.name) }
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.list
|
||||||
|
@@list ||= [
|
||||||
|
MODEL.new(
|
||||||
|
"all-mpnet-base-v2",
|
||||||
|
768,
|
||||||
|
384,
|
||||||
|
%i[dot cosine euclidean],
|
||||||
|
[:symmetric],
|
||||||
|
"discourse",
|
||||||
|
),
|
||||||
|
MODEL.new(
|
||||||
|
"all-distilroberta-v1",
|
||||||
|
768,
|
||||||
|
512,
|
||||||
|
%i[dot cosine euclidean],
|
||||||
|
[:symmetric],
|
||||||
|
"discourse",
|
||||||
|
),
|
||||||
|
MODEL.new("multi-qa-mpnet-base-dot-v1", 768, 512, [:dot], [:symmetric], "discourse"),
|
||||||
|
MODEL.new(
|
||||||
|
"paraphrase-multilingual-mpnet-base-v2",
|
||||||
|
768,
|
||||||
|
128,
|
||||||
|
[:cosine],
|
||||||
|
[:symmetric],
|
||||||
|
"discourse",
|
||||||
|
),
|
||||||
|
MODEL.new("msmarco-distilbert-base-v4", 768, 512, [:cosine], [:asymmetric], "discourse"),
|
||||||
|
MODEL.new("msmarco-distilbert-base-tas-b", 768, 512, [:dot], [:asymmetric], "discourse"),
|
||||||
|
MODEL.new(
|
||||||
|
"text-embedding-ada-002",
|
||||||
|
1536,
|
||||||
|
2048,
|
||||||
|
[:cosine],
|
||||||
|
%i[:symmetric :asymmetric],
|
||||||
|
"openai",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,72 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class SemanticSuggested
|
||||||
|
def self.build_suggested_topics(topic, pm_params, topic_query)
|
||||||
|
return unless SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled
|
||||||
|
return if topic_query.user
|
||||||
|
return if topic.private_message?
|
||||||
|
|
||||||
|
cache_for =
|
||||||
|
case topic.created_at
|
||||||
|
when 6.hour.ago..Time.now
|
||||||
|
15.minutes
|
||||||
|
when 1.day.ago..6.hour.ago
|
||||||
|
1.hour
|
||||||
|
else
|
||||||
|
1.day
|
||||||
|
end
|
||||||
|
|
||||||
|
begin
|
||||||
|
candidate_ids =
|
||||||
|
Discourse
|
||||||
|
.cache
|
||||||
|
.fetch("semantic-suggested-topic-#{topic.id}", expires_in: cache_for) do
|
||||||
|
suggested = search_suggestions(topic)
|
||||||
|
|
||||||
|
# Happens when the topic doesn't have any embeddings
|
||||||
|
if suggested.empty? || !suggested.include?(topic.id)
|
||||||
|
return { result: [], params: {} }
|
||||||
|
end
|
||||||
|
|
||||||
|
suggested
|
||||||
|
end
|
||||||
|
rescue StandardError => e
|
||||||
|
Rails.logger.error("SemanticSuggested: #{e}")
|
||||||
|
end
|
||||||
|
|
||||||
|
# array_position forces the order of the topics to be preserved
|
||||||
|
candidates =
|
||||||
|
::Topic.where(id: candidate_ids).order("array_position(ARRAY#{candidate_ids}, id)")
|
||||||
|
|
||||||
|
{ result: candidates, params: {} }
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.search_suggestions(topic)
|
||||||
|
model_name = SiteSetting.ai_embeddings_semantic_suggested_model
|
||||||
|
model = DiscourseAi::Embeddings::Models.list.find { |m| m.name == model_name }
|
||||||
|
function =
|
||||||
|
DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_FUNCTION[model.functions.first]
|
||||||
|
|
||||||
|
DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||||
|
SELECT
|
||||||
|
topic_id
|
||||||
|
FROM
|
||||||
|
topic_embeddings_#{model_name.underscore}
|
||||||
|
ORDER BY
|
||||||
|
embedding #{function} (
|
||||||
|
SELECT
|
||||||
|
embedding
|
||||||
|
FROM
|
||||||
|
topic_embeddings_#{model_name.underscore}
|
||||||
|
WHERE
|
||||||
|
topic_id = :topic_id
|
||||||
|
LIMIT 1
|
||||||
|
)
|
||||||
|
LIMIT 11
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,57 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
class Topic
|
||||||
|
def initialize(topic)
|
||||||
|
@topic = topic
|
||||||
|
@embeddings = {}
|
||||||
|
end
|
||||||
|
|
||||||
|
def perform!
|
||||||
|
return unless SiteSetting.ai_embeddings_enabled
|
||||||
|
return if DiscourseAi::Embeddings::Models.enabled_models.empty?
|
||||||
|
|
||||||
|
calculate_embeddings!
|
||||||
|
persist_embeddings! unless @embeddings.empty?
|
||||||
|
end
|
||||||
|
|
||||||
|
def calculate_embeddings!
|
||||||
|
return if @topic.blank? || @topic.first_post.blank?
|
||||||
|
|
||||||
|
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
||||||
|
@embeddings[model.name] = send("#{model.provider}_embeddings", model.name)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def persist_embeddings!
|
||||||
|
@embeddings.each do |model, model_embedding|
|
||||||
|
DiscourseAi::Database::Connection.db.exec(
|
||||||
|
<<~SQL,
|
||||||
|
INSERT INTO topic_embeddings_#{model.underscore} (topic_id, embedding)
|
||||||
|
VALUES (:topic_id, '[:embedding]')
|
||||||
|
ON CONFLICT (topic_id)
|
||||||
|
DO UPDATE SET embedding = '[:embedding]'
|
||||||
|
SQL
|
||||||
|
topic_id: @topic.id,
|
||||||
|
embedding: model_embedding,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def discourse_embeddings(model)
|
||||||
|
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||||
|
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||||
|
model.to_s,
|
||||||
|
@topic.first_post.raw,
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def openai_embeddings(model)
|
||||||
|
response = DiscourseAi::Inference::OpenAIEmbeddings.perform!(@topic.first_post.raw)
|
||||||
|
response[:data].first[:embedding]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,16 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module ::DiscourseAi
|
||||||
|
module Database
|
||||||
|
class Connection
|
||||||
|
def self.connect!
|
||||||
|
pg_conn = PG.connect(SiteSetting.ai_embeddings_pg_connection_string)
|
||||||
|
@@db = MiniSql::Connection.get(pg_conn)
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.db
|
||||||
|
@@db ||= connect!
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,46 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
desc "Creates tables to store embeddings"
|
||||||
|
task "ai:embeddings:create_table" => [:environment] do
|
||||||
|
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
||||||
|
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||||
|
CREATE TABLE IF NOT EXISTS topic_embeddings_#{model.name.underscore} (
|
||||||
|
topic_id bigint PRIMARY KEY,
|
||||||
|
embedding vector(#{model.dimensions})
|
||||||
|
);
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
desc "Backfill embeddings for all topics"
|
||||||
|
task "ai:embeddings:backfill" => [:environment] do
|
||||||
|
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||||
|
Topic
|
||||||
|
.where("category_id IN ?", public_categories)
|
||||||
|
.where(deleted_at: nil)
|
||||||
|
.find_each do |t|
|
||||||
|
print "."
|
||||||
|
DiscourseAI::Embeddings::Topic.new(t).perform!
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
desc "Creates indexes for embeddings"
|
||||||
|
task "ai:embeddings:index" => [:environment] do
|
||||||
|
# Using 4 * sqrt(number of topics) as a rule of thumb for now
|
||||||
|
# Results are not as good as without indexes, but it's much faster
|
||||||
|
# Disk usage is ~1x the size of the table, so this double table total size
|
||||||
|
lists = 4 * Math.sqrt(Topic.count).to_i
|
||||||
|
|
||||||
|
DiscourseAi::Embeddings::Models.enabled_models.each do |model|
|
||||||
|
DiscourseAi::Database::Connection.db.exec(<<~SQL)
|
||||||
|
CREATE INDEX IF NOT EXISTS
|
||||||
|
topic_embeddings_#{model.name.underscore}_search
|
||||||
|
ON
|
||||||
|
topic_embeddings_#{model.name.underscore}
|
||||||
|
USING
|
||||||
|
ivfflat (embedding #{DiscourseAi::Embeddings::Models::SEARCH_FUNCTION_TO_PG_INDEX[model.functions.first]})
|
||||||
|
WITH
|
||||||
|
(lists = #{lists});
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
end
|
|
@ -27,12 +27,16 @@ after_initialize do
|
||||||
require_relative "lib/shared/post_classificator"
|
require_relative "lib/shared/post_classificator"
|
||||||
require_relative "lib/shared/chat_message_classificator"
|
require_relative "lib/shared/chat_message_classificator"
|
||||||
|
|
||||||
|
require_relative "lib/shared/database/connection"
|
||||||
|
|
||||||
require_relative "lib/modules/nsfw/entry_point"
|
require_relative "lib/modules/nsfw/entry_point"
|
||||||
require_relative "lib/modules/toxicity/entry_point"
|
require_relative "lib/modules/toxicity/entry_point"
|
||||||
require_relative "lib/modules/sentiment/entry_point"
|
require_relative "lib/modules/sentiment/entry_point"
|
||||||
require_relative "lib/modules/ai_helper/entry_point"
|
require_relative "lib/modules/ai_helper/entry_point"
|
||||||
|
require_relative "lib/modules/embeddings/entry_point"
|
||||||
|
|
||||||
[
|
[
|
||||||
|
DiscourseAi::Embeddings::EntryPoint.new,
|
||||||
DiscourseAi::NSFW::EntryPoint.new,
|
DiscourseAi::NSFW::EntryPoint.new,
|
||||||
DiscourseAi::Toxicity::EntryPoint.new,
|
DiscourseAi::Toxicity::EntryPoint.new,
|
||||||
DiscourseAi::Sentiment::EntryPoint.new,
|
DiscourseAi::Sentiment::EntryPoint.new,
|
||||||
|
|
Loading…
Reference in New Issue