# frozen_string_literal: true

# Generates embeddings for every non-deleted topic in a non-read-restricted
# category that does not yet have a row in the active vector representation's
# table.
#
# @param start_topic [String, nil] optional minimum topic id; when omitted
#   (`nil.to_i` is 0) processing starts from the lowest topic id.
desc "Backfill embeddings for all topics"
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
  public_categories = Category.where(read_restricted: false).pluck(:id)

  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
  vector_rep =
    DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)
  table_name = vector_rep.table_name

  start_topic_id = args[:start_topic].to_i

  # LEFT JOIN ... IS NULL keeps only topics with no row in the embeddings table.
  # `find_each` always batches in primary-key (topics.id ASC) order; ActiveRecord
  # ignores a scoped `order` here (warning, or raising under
  # error_on_ignored_order), so none is given.
  Topic
    .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
    .where("#{table_name}.topic_id IS NULL")
    .where("topics.id >= ?", start_topic_id)
    .where(category_id: public_categories)
    .where(deleted_at: nil)
    .find_each do |topic|
      print "."
      vector_rep.generate_topic_representation_from(topic)
    end
end

# Builds the ivfflat index for the active vector representation.
#
# @param work_mem [String, nil] PostgreSQL memory setting used for the build
#   (e.g. "2GB"); defaults to "1GB".
desc "Creates indexes for embeddings"
task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
  # Sizing follows the pgvector maintainer's recommendation for ivfflat indexes:
  # lists = rows / 1000 up to 1M rows, sqrt(rows) above; probes = lists / 10 up
  # to 1M rows, sqrt(lists) above.
  # ivfflat is approximate: recall is lower than an exact scan without the index,
  # but queries are much faster.
  # Disk usage is ~1x the size of the table, so this doubles the table's total size.
  count = Topic.count
  lists = count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i
  probes = count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i
  # Integer division yields 0 on small sites, but ivfflat requires at least
  # one list and one probe.
  lists = [lists, 1].max
  probes = [probes, 1].max

  strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
  vector_rep =
    DiscourseAi::Embeddings::VectorRepresentations::Base.find_vector_representation.new(strategy)

  memory = args[:work_mem] || "1GB"
  # pgvector's ivfflat build is bounded by maintenance_work_mem; work_mem is
  # kept for any ordinary queries create_index may presumably run.
  DB.exec("SET work_mem TO '#{memory}';")
  DB.exec("SET maintenance_work_mem TO '#{memory}';")
  begin
    vector_rep.create_index(lists, probes)
  ensure
    # Restore the session defaults even if the index build fails.
    DB.exec("RESET maintenance_work_mem;")
    DB.exec("RESET work_mem;")
  end

  # NOTE(review): a plain SET only affects the current session, which ends with
  # this task; a persistent setting would need ALTER DATABASE/ROLE ... SET.
  DB.exec("SET ivfflat.probes = #{probes};")
end