2023-07-13 12:41:36 -03:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
  module Embeddings
    module Strategies
      # Builds the text representation of a Topic or a Post and truncates it
      # so that it fits in the token budget of the given model.
      #
      # Usage: build with a target and a model, call {#process!}, then read
      # {#processed_target} and {#digest}.
      class Truncation
        # @return [String, nil] the truncated text; nil until {#process!} runs
        # @return [String, nil] SHA1 hex digest of {#processed_target}; nil
        #   until {#process!} runs
        attr_reader :processed_target, :digest

        # @return [Integer] numeric identifier of this strategy
        def self.id
          1
        end

        # @return [Integer] same value as {.id}, exposed on instances
        def id
          self.class.id
        end

        # @return [Integer] version number of this strategy
        def version
          1
        end

        # @param target [Topic, Post] the record to turn into text
        # @param model [#tokenizer, #max_sequence_length] the model whose
        #   tokenizer and sequence limit drive the truncation
        def initialize(target, model)
          @model = model
          @target = target
          @tokenizer = @model.tokenizer
          # NOTE(review): the 2 tokens held back are presumably reserved for
          # special tokens the model adds around the input - confirm.
          @max_length = @model.max_sequence_length - 2
          @processed_target = nil
        end

        # Computes {#processed_target} and {#digest} for the target.
        #
        # TODO: need a better name for this method.
        #
        # @return [String] the SHA1 hex digest of the processed text
        # @raise [ArgumentError] when the target is neither a Topic nor a Post
        def process!
          @processed_target =
            case @target
            when Topic
              topic_truncation(@target)
            when Post
              post_truncation(@target)
            else
              raise ArgumentError, "Invalid target type"
            end

          @digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
        end

        # Text for a topic: its title/category/tags preamble followed by the
        # raw content of its posts, truncated to the token budget.
        #
        # @param topic [Topic]
        # @return [String] whatever the tokenizer's truncate returns
        def topic_truncation(topic)
          text = topic_information(topic)

          topic.posts.find_each do |post|
            text << post.raw
            # Stop pulling in posts once the budget is reached; the final
            # truncate call below trims any overshoot.
            # OPTIMIZE: this re-tokenizes the whole buffer on every
            # iteration; maybe keep a partial counter to speed this up?
            break if @tokenizer.size(text) >= @max_length
            text << "\n\n"
          end

          @tokenizer.truncate(text, @max_length)
        end

        # Text for a single post: the preamble of its topic followed by the
        # post's raw content, truncated to the token budget.
        #
        # @param post [Post]
        # @return [String] whatever the tokenizer's truncate returns
        def post_truncation(post)
          text = topic_information(post.topic)
          text << post.raw

          @tokenizer.truncate(text, @max_length)
        end

        private

        # Builds the shared preamble: title, category name and (when tagging
        # is enabled) the comma-separated tag names, each followed by a blank
        # line.
        #
        # A missing topic yields an empty preamble and a missing category is
        # skipped together with its separator, instead of raising
        # NoMethodError on nil.
        #
        # @param topic [Topic, nil]
        # @return [String] a new, mutable string
        def topic_information(topic)
          return +"" if topic.nil?

          sections = [topic.title.to_s]

          category_name = topic.category&.name
          sections << category_name if category_name

          sections << topic.tags.pluck(:name).join(", ") if SiteSetting.tagging_enabled

          # Array#join returns a fresh unfrozen string, so callers can keep
          # appending to the result.
          sections.join("\n\n") << "\n\n"
        end
      end
    end
  end
end
|