FIX: Fix embeddings truncation strategy (#139)
commit 0738f67fa4
parent 525c8b0913
@@ -22,17 +22,17 @@ module DiscourseAi
           @model = model
           @target = target
           @tokenizer = @model.tokenizer
-          @max_length = @model.max_sequence_length
-          @processed_target = +""
+          @max_length = @model.max_sequence_length - 2
+          @processed_target = nil
         end

         # Need a better name for this method
         def process!
           case @target
           when Topic
-            topic_truncation(@target)
+            @processed_target = topic_truncation(@target)
           when Post
-            post_truncation(@target)
+            @processed_target = post_truncation(@target)
           else
             raise ArgumentError, "Invalid target type"
           end
@@ -41,7 +41,7 @@ module DiscourseAi
         end

         def topic_truncation(topic)
-          t = @processed_target
+          t = +""

           t << topic.title
           t << "\n\n"
@@ -54,7 +54,7 @@ module DiscourseAi

           topic.posts.find_each do |post|
             t << post.raw
-            break if @tokenizer.size(t) >= @max_length
+            break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
             t << "\n\n"
           end

@@ -62,7 +62,7 @@ module DiscourseAi
         end

         def post_truncation(post)
-          t = processed_target
+          t = +""

           t << post.topic.title
           t << "\n\n"
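The comment added in the posts loop above (`#maybe keep a partial counter to speed this up?`) points at a possible follow-up: tokenize each appended chunk once and keep a running total, rather than re-tokenizing the whole buffer on every iteration. Below is a hypothetical, simplified sketch of that idea, not part of this commit; it assumes the surrounding class's @tokenizer and @max_length, omits whatever category/tag handling the real method does, and accepts that summed per-chunk counts only approximate tokenizing the concatenated text.

# Hypothetical sketch only, not part of this commit: the "partial counter"
# idea from the in-diff comment. Each chunk is tokenized once and a running
# total is kept, instead of re-tokenizing the whole buffer per iteration.
def topic_truncation(topic)
  t = +""
  tokens = 0 # running token total (assumption)

  chunk = "#{topic.title}\n\n"
  t << chunk
  tokens += @tokenizer.size(chunk)

  topic.posts.find_each do |post|
    chunk = "#{post.raw}\n\n"
    t << chunk
    tokens += @tokenizer.size(chunk)
    break if tokens >= @max_length
  end

  t
end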
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
+  describe "#process!" do
+    context "when the model uses OpenAI to create embeddings" do
+      before { SiteSetting.max_post_length = 100_000 }
+
+      fab!(:topic) { Fabricate(:topic) }
+      fab!(:post) do
+        Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
+      end
+      fab!(:post) do
+        Fabricate(
+          :post,
+          topic: topic,
+          raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
+        )
+      end
+      fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
+
+      let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample(1).first }
+      let(:truncation) { described_class.new(topic, model) }
+
+      it "truncates a topic" do
+        truncation.process!
+
+        expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
+      end
+    end
+  end
+end
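For reference, a minimal usage sketch mirroring the call pattern the spec above exercises (described_class.new(topic, model), then process!, then processed_target). The topic lookup and model choice here are illustrative assumptions, not part of the commit.

# Minimal usage sketch, mirroring the spec above. Topic lookup and model
# selection are illustrative assumptions.
topic = Topic.first
model = DiscourseAi::Embeddings::Models::Base.descendants.first

truncation = DiscourseAi::Embeddings::Strategies::Truncation.new(topic, model)
truncation.process!

truncated = truncation.processed_target
model.tokenizer.size(truncated) <= model.max_sequence_length # => true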