2023-03-15 16:21:45 -04:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
2023-12-29 10:28:45 -05:00
|
|
|
# Rake task that generates embeddings for topics and posts that do not yet
# have a row in their embeddings table.
#
# Task arguments:
#   model       - optional. When present, it is looked up through
#                 VectorRepresentations::Base.find_representation and the
#                 result is instantiated with a Truncation strategy. When
#                 blank, Base.current_representation is used instead.
#   concurrency - optional. Converted with #to_i and handed to Parallel as
#                 the :in_processes option.
desc "Backfill embeddings for all topics and posts"
task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
  public_category_ids = Category.where(read_restricted: false).pluck(:id)

  representations = DiscourseAi::Embeddings::VectorRepresentations::Base
  vector_rep =
    if args[:model].present?
      truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
      representations.find_representation(args[:model]).new(truncation)
    else
      representations.current_representation
    end

  # Generates a representation for every record in +scope+, reporting
  # progress under +label+. Each item does its work inside a connection
  # checked out from the ActiveRecord pool.
  backfill = ->(scope, label) do
    Parallel.each(scope.all, in_processes: args[:concurrency].to_i, progress: label) do |record|
      ActiveRecord::Base.connection_pool.with_connection do
        vector_rep.generate_representation_from(record)
      end
    end
  end

  # Topics: not deleted, in a category that is not read-restricted, and with
  # no matching row in the topics embeddings table. Highest ids first.
  topics_table = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
  missing_topics =
    Topic
      .joins("LEFT JOIN #{topics_table} ON #{topics_table}.topic_id = topics.id")
      .where("#{topics_table}.topic_id IS NULL")
      .where("category_id IN (?)", public_category_ids)
      .where(deleted_at: nil)
      .order("topics.id DESC")

  backfill.call(missing_topics, "Topics")

  # Posts: not deleted and with no matching row in the posts embeddings
  # table. Highest ids first. Unlike the topics query above, no category
  # filter is applied here.
  #
  # NOTE(review): the topics table name is read from Schema::TOPICS_TABLE,
  # while the posts table name is still requested from the representation
  # object. Confirm that vector_rep responds to #post_table_name, or whether
  # a Schema constant is the intended source here as well.
  posts_table = vector_rep.post_table_name
  missing_posts =
    Post
      .joins("LEFT JOIN #{posts_table} ON #{posts_table}.post_id = posts.id")
      .where("#{posts_table}.post_id IS NULL")
      .where(deleted_at: nil)
      .order("posts.id DESC")

  backfill.call(missing_posts, "Posts")
end
|