FIX: Fix embeddings truncation strategy (#139)

Rafael dos Santos Silva 2023-08-16 15:09:41 -03:00 committed by GitHub
parent 525c8b0913
commit 0738f67fa4
2 changed files with 38 additions and 7 deletions


@@ -22,17 +22,17 @@ module DiscourseAi
           @model = model
           @target = target
           @tokenizer = @model.tokenizer
-          @max_length = @model.max_sequence_length
-          @processed_target = +""
+          @max_length = @model.max_sequence_length - 2
+          @processed_target = nil
         end
 
         # Need a better name for this method
         def process!
           case @target
           when Topic
-            topic_truncation(@target)
+            @processed_target = topic_truncation(@target)
           when Post
-            post_truncation(@target)
+            @processed_target = post_truncation(@target)
           else
             raise ArgumentError, "Invalid target type"
           end
@@ -41,7 +41,7 @@ module DiscourseAi
         end
 
         def topic_truncation(topic)
-          t = @processed_target
+          t = +""
 
           t << topic.title
           t << "\n\n"
@@ -54,7 +54,7 @@ module DiscourseAi
           topic.posts.find_each do |post|
             t << post.raw
 
-            break if @tokenizer.size(t) >= @max_length
+            break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
 
             t << "\n\n"
           end
@@ -62,7 +62,7 @@ module DiscourseAi
         end
 
         def post_truncation(post)
-          t = processed_target
+          t = +""
 
           t << post.topic.title
           t << "\n\n"

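In short, the fix does two things: process! now assigns its result to @processed_target rather than relying on in-place mutation of a string created in the constructor, and the token budget drops to max_sequence_length - 2, presumably to leave headroom for the special tokens (e.g. [CLS]/[SEP]) that BERT-style tokenizers wrap around the input. A minimal usage sketch of the fixed strategy, with a hypothetical model class standing in for any concrete entry under DiscourseAi::Embeddings::Models:

    # The model class itself is passed (not an instance), exposing
    # .tokenizer and .max_sequence_length, as the spec below does.
    # AllMpnetBaseV2 is a hypothetical stand-in for a concrete model.
    model = DiscourseAi::Embeddings::Models::AllMpnetBaseV2
    strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(topic, model)
    strategy.process!
    strategy.processed_target # truncated text that fits the model's limit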

@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
+  describe "#process!" do
+    context "when the model uses OpenAI to create embeddings" do
+      before { SiteSetting.max_post_length = 100_000 }
+
+      fab!(:topic) { Fabricate(:topic) }
+      fab!(:post) do
+        Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
+      end
+      fab!(:post) do
+        Fabricate(
+          :post,
+          topic: topic,
+          raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
+        )
+      end
+      fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
+
+      let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
+      let(:truncation) { described_class.new(topic, model) }
+
+      it "truncates a topic" do
+        truncation.process!
+
+        expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
+      end
+    end
+  end
+end
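A note on the spec's design: sampling a random Models::Base descendant exercises the truncation invariant against whichever tokenizer that model ships with, instead of pinning the test to one model. The same check can be reproduced by hand in a Rails console on a site with the plugin enabled (a sketch; the Topic.first lookup is purely illustrative):

    # Mirrors the spec's core assertion against a real topic.
    model = DiscourseAi::Embeddings::Models::Base.descendants.sample
    strategy = DiscourseAi::Embeddings::Strategies::Truncation.new(Topic.first, model)
    strategy.process!
    model.tokenizer.size(strategy.processed_target) <= model.max_sequence_length
    # => true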