Store and use

This commit is contained in:
Rafael dos Santos Silva 2023-03-13 19:05:25 -03:00
parent 9e4a007a4a
commit 70fc240fb1
No known key found for this signature in database
GPG Key ID: 5E50360227B34938
7 changed files with 84 additions and 6 deletions

View File

@ -102,4 +102,7 @@ plugins:
- all-mpnet-base-v2
- msmarco-distilbert-base-v4
- text-embedding-ada-002
ai_embeddings_semantic_suggested_topics_anons_enabled: false
ai_embeddings_pg_connection_string:
default: "postgresql://localhost/embeddings"

View File

@ -6,6 +6,7 @@ module DiscourseAI
def load_files
require_relative "topic"
require_relative "jobs/regular/generate_embeddings"
require_relative "semantic_suggested"
end
def inject_into(plugin)
@ -18,6 +19,8 @@ module DiscourseAI
plugin.on(:topic_created, &callback)
plugin.on(:topic_edited, &callback)
DiscoursePluginRegistry.register_list_suggested_for_provider(SemanticSuggested.method(:build_suggested_topics), plugin)
end
end
end

View File

@ -6,7 +6,7 @@ module Jobs
return unless SiteSetting.ai_embeddings_enabled
return if (topic_id = args[:topic_id]).blank?
post = Topic.find_by_id(post_id).first_post
post = Topic.find_by_id(topic_id).first_post
return if post.nil? || post.raw.blank?
DiscourseAI::Embeddings::Topic.new(post.topic).perform!

View File

@ -0,0 +1,36 @@
# frozen_string_literal: true
module DiscourseAI
module Embeddings
class SemanticSuggested
def self.build_suggested_topics(topic, pm_params, topic_query)
return unless SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled
return if topic_query.user
return if topic.private_message?
candidate_ids = DiscourseAI::Database::Connection.db.query(<<~SQL, topic_id: topic.id)
SELECT
topic_id
FROM
topic_embeddings_symetric_discourse
WHERE
topic_id != :topic_id
ORDER BY
embeddings <#> (
SELECT
embeddings
FROM
topic_embeddings_symetric_discourse
WHERE
topic_id = :topic_id
LIMIT 1
)
LIMIT 10
SQL
candidates = ::Topic.where(id: candidate_ids.map(&:topic_id))
{ result: candidates, params: {} }
end
end
end
end

View File

@ -3,8 +3,8 @@
module DiscourseAI
module Embeddings
class Topic
DISCOURSE_MODELS = %i[all-mpnet-base-v2 msmarco-distilbert-base-v4]
OPENAI_MODELS = %i[text-embedding-ada-002]
DISCOURSE_MODELS = %w[all-mpnet-base-v2 msmarco-distilbert-base-v4]
OPENAI_MODELS = %w[text-embedding-ada-002]
def initialize(topic)
@topic = topic
@ -33,8 +33,26 @@ module DiscourseAI
end
def persist_embeddings!
pp @embeddings
#TODO: persist embeddings
return if @embeddings["all-mpnet-base-v2"].blank?
@embeddings.each do |model, model_embeddings|
case model
when "all-mpnet-base-v2"
DiscourseAI::Database::Connection.db.exec(
<<~SQL,
INSERT INTO topic_embeddings_symetric_discourse (topic_id, embeddings)
VALUES (:topic_id, '[:embeddings]')
ON CONFLICT (topic_id)
DO UPDATE SET embeddings = '[:embeddings]'
SQL
topic_id: @topic.id,
embeddings: model_embeddings,
)
when "msmarco-distilbert-base-v4"
#todo
when "text-embedding-ada-002"
#todo
end
end
end
def discourse_embeddings(model)
@ -53,7 +71,7 @@ module DiscourseAI
private
def enabled_models
SiteSetting.ai_embeddings_models.split("|").map(&:to_sym)
SiteSetting.ai_embeddings_models.split("|")
end
end
end

View File

@ -0,0 +1,16 @@
# frozen_string_literal: true
module ::DiscourseAI
module Database
class Connection
def self.connect!
pg_conn = PG.connect(SiteSetting.ai_embeddings_pg_connection_string)
@@db = MiniSql::Connection.get(pg_conn)
end
def self.db
@@db ||= connect!
end
end
end
end

View File

@ -25,6 +25,8 @@ after_initialize do
require_relative "lib/shared/post_classificator"
require_relative "lib/shared/chat_message_classificator"
require_relative "lib/shared/database/connection"
require_relative "lib/modules/nsfw/entry_point"
require_relative "lib/modules/toxicity/entry_point"
require_relative "lib/modules/sentiment/entry_point"