Store and use
This commit is contained in:
parent
9e4a007a4a
commit
70fc240fb1
|
@ -102,4 +102,7 @@ plugins:
|
|||
- all-mpnet-base-v2
|
||||
- msmarco-distilbert-base-v4
|
||||
- text-embedding-ada-002
|
||||
ai_embeddings_semantic_suggested_topics_anons_enabled: false
|
||||
ai_embeddings_pg_connection_string:
|
||||
default: "postgresql://localhost/embeddings"
|
||||
|
|
@ -6,6 +6,7 @@ module DiscourseAI
|
|||
def load_files
|
||||
require_relative "topic"
|
||||
require_relative "jobs/regular/generate_embeddings"
|
||||
require_relative "semantic_suggested"
|
||||
end
|
||||
|
||||
def inject_into(plugin)
|
||||
|
@ -18,6 +19,8 @@ module DiscourseAI
|
|||
|
||||
plugin.on(:topic_created, &callback)
|
||||
plugin.on(:topic_edited, &callback)
|
||||
|
||||
DiscoursePluginRegistry.register_list_suggested_for_provider(SemanticSuggested.method(:build_suggested_topics), plugin)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -6,7 +6,7 @@ module Jobs
|
|||
return unless SiteSetting.ai_embeddings_enabled
|
||||
return if (topic_id = args[:topic_id]).blank?
|
||||
|
||||
post = Topic.find_by_id(post_id).first_post
|
||||
post = Topic.find_by_id(topic_id).first_post
|
||||
return if post.nil? || post.raw.blank?
|
||||
|
||||
DiscourseAI::Embeddings::Topic.new(post.topic).perform!
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAI
|
||||
module Embeddings
|
||||
class SemanticSuggested
|
||||
def self.build_suggested_topics(topic, pm_params, topic_query)
|
||||
return unless SiteSetting.ai_embeddings_semantic_suggested_topics_anons_enabled
|
||||
return if topic_query.user
|
||||
return if topic.private_message?
|
||||
|
||||
candidate_ids = DiscourseAI::Database::Connection.db.query(<<~SQL, topic_id: topic.id)
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
topic_embeddings_symetric_discourse
|
||||
WHERE
|
||||
topic_id != :topic_id
|
||||
ORDER BY
|
||||
embeddings <#> (
|
||||
SELECT
|
||||
embeddings
|
||||
FROM
|
||||
topic_embeddings_symetric_discourse
|
||||
WHERE
|
||||
topic_id = :topic_id
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 10
|
||||
SQL
|
||||
|
||||
candidates = ::Topic.where(id: candidate_ids.map(&:topic_id))
|
||||
{ result: candidates, params: {} }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,8 +3,8 @@
|
|||
module DiscourseAI
|
||||
module Embeddings
|
||||
class Topic
|
||||
DISCOURSE_MODELS = %i[all-mpnet-base-v2 msmarco-distilbert-base-v4]
|
||||
OPENAI_MODELS = %i[text-embedding-ada-002]
|
||||
DISCOURSE_MODELS = %w[all-mpnet-base-v2 msmarco-distilbert-base-v4]
|
||||
OPENAI_MODELS = %w[text-embedding-ada-002]
|
||||
|
||||
def initialize(topic)
|
||||
@topic = topic
|
||||
|
@ -33,8 +33,26 @@ module DiscourseAI
|
|||
end
|
||||
|
||||
def persist_embeddings!
|
||||
pp @embeddings
|
||||
#TODO: persist embeddings
|
||||
return if @embeddings["all-mpnet-base-v2"].blank?
|
||||
@embeddings.each do |model, model_embeddings|
|
||||
case model
|
||||
when "all-mpnet-base-v2"
|
||||
DiscourseAI::Database::Connection.db.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO topic_embeddings_symetric_discourse (topic_id, embeddings)
|
||||
VALUES (:topic_id, '[:embeddings]')
|
||||
ON CONFLICT (topic_id)
|
||||
DO UPDATE SET embeddings = '[:embeddings]'
|
||||
SQL
|
||||
topic_id: @topic.id,
|
||||
embeddings: model_embeddings,
|
||||
)
|
||||
when "msmarco-distilbert-base-v4"
|
||||
#todo
|
||||
when "text-embedding-ada-002"
|
||||
#todo
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def discourse_embeddings(model)
|
||||
|
@ -53,7 +71,7 @@ module DiscourseAI
|
|||
private
|
||||
|
||||
def enabled_models
|
||||
SiteSetting.ai_embeddings_models.split("|").map(&:to_sym)
|
||||
SiteSetting.ai_embeddings_models.split("|")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAI
|
||||
module Database
|
||||
class Connection
|
||||
def self.connect!
|
||||
pg_conn = PG.connect(SiteSetting.ai_embeddings_pg_connection_string)
|
||||
@@db = MiniSql::Connection.get(pg_conn)
|
||||
end
|
||||
|
||||
def self.db
|
||||
@@db ||= connect!
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -25,6 +25,8 @@ after_initialize do
|
|||
require_relative "lib/shared/post_classificator"
|
||||
require_relative "lib/shared/chat_message_classificator"
|
||||
|
||||
require_relative "lib/shared/database/connection"
|
||||
|
||||
require_relative "lib/modules/nsfw/entry_point"
|
||||
require_relative "lib/modules/toxicity/entry_point"
|
||||
require_relative "lib/modules/sentiment/entry_point"
|
||||
|
|
Loading…
Reference in New Issue