From 70fc240fb1a27c74ec8fece1dcf2f8c4de58f09e Mon Sep 17 00:00:00 2001 From: Rafael dos Santos Silva Date: Mon, 13 Mar 2023 19:05:25 -0300 Subject: [PATCH] Store and use --- config/settings.yml | 3 ++ lib/modules/embeddings/entry_point.rb | 3 ++ .../jobs/regular/generate_embeddings.rb | 2 +- lib/modules/embeddings/semantic_suggested.rb | 36 +++++++++++++++++++ lib/modules/embeddings/topic.rb | 28 ++++++++++++--- lib/shared/database/connection.rb | 16 +++++++++ plugin.rb | 2 ++ 7 files changed, 84 insertions(+), 6 deletions(-) create mode 100644 lib/modules/embeddings/semantic_suggested.rb create mode 100644 lib/shared/database/connection.rb diff --git a/config/settings.yml b/config/settings.yml index 20f0d4c2..82239c4a 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -102,4 +102,7 @@ plugins: - all-mpnet-base-v2 - msmarco-distilbert-base-v4 - text-embedding-ada-002 + ai_embeddings_semantic_suggested_topics_anons_enabled: false + ai_embeddings_pg_connection_string: + default: "postgresql://localhost/embeddings" \ No newline at end of file diff --git a/lib/modules/embeddings/entry_point.rb b/lib/modules/embeddings/entry_point.rb index f698f310..8b37b1ea 100644 --- a/lib/modules/embeddings/entry_point.rb +++ b/lib/modules/embeddings/entry_point.rb @@ -6,6 +6,7 @@ module DiscourseAI def load_files require_relative "topic" require_relative "jobs/regular/generate_embeddings" + require_relative "semantic_suggested" end def inject_into(plugin) @@ -18,6 +19,8 @@ module DiscourseAI plugin.on(:topic_created, &callback) plugin.on(:topic_edited, &callback) + + DiscoursePluginRegistry.register_list_suggested_for_provider(SemanticSuggested.method(:build_suggested_topics), plugin) end end end diff --git a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb index 19947b95..1a94b9b6 100644 --- a/lib/modules/embeddings/jobs/regular/generate_embeddings.rb +++ b/lib/modules/embeddings/jobs/regular/generate_embeddings.rb @@ -6,7 +6,7 @@ module Jobs return unless SiteSetting.ai_embeddings_enabled return if (topic_id = args[:topic_id]).blank? - post = Topic.find_by_id(post_id).first_post + post = Topic.find_by_id(topic_id).first_post return if post.nil? || post.raw.blank? DiscourseAI::Embeddings::Topic.new(post.topic).perform! diff --git a/lib/modules/embeddings/semantic_suggested.rb b/lib/modules/embeddings/semantic_suggested.rb new file mode 100644 index 00000000..742d6b79 --- /dev/null +++ b/lib/modules/embeddings/semantic_suggested.rb @@ -0,0 +1,36 @@ +# frozen_string_literal: true + +module DiscourseAI + module Embeddings + class SemanticSuggested + def self.build_suggested_topics(topic, pm_params, topic_query) + return unless SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled + return if topic_query.user + return if topic.private_message? + + candidate_ids = DiscourseAI::Database::Connection.db.query(<<~SQL, topic_id: topic.id) + SELECT + topic_id + FROM + topic_embeddings_symetric_discourse + WHERE + topic_id != :topic_id + ORDER BY + embeddings <#> ( + SELECT + embeddings + FROM + topic_embeddings_symetric_discourse + WHERE + topic_id = :topic_id + LIMIT 1 + ) + LIMIT 10 + SQL + + candidates = ::Topic.where(id: candidate_ids.map(&:topic_id)) + { result: candidates, params: {} } + end + end + end +end diff --git a/lib/modules/embeddings/topic.rb b/lib/modules/embeddings/topic.rb index 39e6ff85..2b0eadb0 100644 --- a/lib/modules/embeddings/topic.rb +++ b/lib/modules/embeddings/topic.rb @@ -3,8 +3,8 @@ module DiscourseAI module Embeddings class Topic - DISCOURSE_MODELS = %i[all-mpnet-base-v2 msmarco-distilbert-base-v4] - OPENAI_MODELS = %i[text-embedding-ada-002] + DISCOURSE_MODELS = %w[all-mpnet-base-v2 msmarco-distilbert-base-v4] + OPENAI_MODELS = %w[text-embedding-ada-002] def initialize(topic) @topic = topic @@ -33,8 +33,26 @@ module DiscourseAI end def persist_embeddings! - pp @embeddings - #TODO: persist embeddings + return if @embeddings["all-mpnet-base-v2"].blank? + @embeddings.each do |model, model_embeddings| + case model + when "all-mpnet-base-v2" + DiscourseAI::Database::Connection.db.exec( + <<~SQL, + INSERT INTO topic_embeddings_symetric_discourse (topic_id, embeddings) + VALUES (:topic_id, '[:embeddings]') + ON CONFLICT (topic_id) + DO UPDATE SET embeddings = '[:embeddings]' + SQL + topic_id: @topic.id, + embeddings: model_embeddings, + ) + when "msmarco-distilbert-base-v4" + #todo + when "text-embedding-ada-002" + #todo + end + end end def discourse_embeddings(model) @@ -53,7 +71,7 @@ module DiscourseAI private def enabled_models - SiteSetting.ai_embeddings_models.split("|").map(&:to_sym) + SiteSetting.ai_embeddings_models.split("|") end end end diff --git a/lib/shared/database/connection.rb b/lib/shared/database/connection.rb new file mode 100644 index 00000000..b48d0fc4 --- /dev/null +++ b/lib/shared/database/connection.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +module ::DiscourseAI + module Database + class Connection + def self.connect! + pg_conn = PG.connect(SiteSetting.ai_embeddings_pg_connection_string) + @@db = MiniSql::Connection.get(pg_conn) + end + + def self.db + @@db ||= connect! + end + end + end +end diff --git a/plugin.rb b/plugin.rb index 35bcef36..47c5f47c 100644 --- a/plugin.rb +++ b/plugin.rb @@ -25,6 +25,8 @@ after_initialize do require_relative "lib/shared/post_classificator" require_relative "lib/shared/chat_message_classificator" + require_relative "lib/shared/database/connection" + require_relative "lib/modules/nsfw/entry_point" require_relative "lib/modules/toxicity/entry_point" require_relative "lib/modules/sentiment/entry_point"