FEATURE: Add BGE-M3 embeddings support (#569)
BAAI/bge-m3 is an interesting model, that is multilingual and with a context size of 8192. Even with a 16x larger context, it's only 4x slower to compute it's embeddings on the worst case scenario. Also includes a minor refactor of the rake task, including setting model and concurrency levels when running the backfill task.
This commit is contained in:
parent
6de9c53a71
commit
eb93b21769
|
@ -264,6 +264,7 @@ discourse_ai:
|
|||
- multilingual-e5-large
|
||||
- bge-large-en
|
||||
- gemini
|
||||
- bge-m3
|
||||
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
|
||||
ai_embeddings_per_post_enabled:
|
||||
default: false
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class AddEmbeddingsTablesforBgeM3 < ActiveRecord::Migration[7.0]
|
||||
def change
|
||||
create_table :ai_topic_embeddings_8_1, id: false do |t|
|
||||
t.integer :topic_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1024)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :topic_id, unique: true
|
||||
end
|
||||
create_table :ai_post_embeddings_8_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1024)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
create_table :ai_document_fragment_embeddings_8_1, id: false do |t|
|
||||
t.integer :rag_document_fragment_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1024)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :rag_document_fragment_id,
|
||||
unique: true,
|
||||
name: "rag_document_fragment_id_embeddings_8_1"
|
||||
end
|
||||
end
|
||||
end
|
|
@ -11,11 +11,12 @@ module DiscourseAi
|
|||
[
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
||||
].find { _1.name == model_name }
|
||||
end
|
||||
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class BgeM3 < Base
|
||||
class << self
|
||||
def name
|
||||
"bge-m3"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint]
|
||||
end
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8192
|
||||
end
|
||||
|
||||
def id
|
||||
8
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<#>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_ip_ops"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,34 +1,48 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
desc "Backfill embeddings for all topics and posts"
|
||||
task "ai:embeddings:backfill" => [:environment] do
|
||||
task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
|
||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
if args[:model].present?
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new(
|
||||
strategy,
|
||||
)
|
||||
else
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
end
|
||||
table_name = vector_rep.topic_table_name
|
||||
|
||||
Topic
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
||||
.where("#{table_name}.topic_id IS NULL")
|
||||
.where("category_id IN (?)", public_categories)
|
||||
.where(deleted_at: nil)
|
||||
.order("topics.id DESC")
|
||||
.find_each do |t|
|
||||
print "."
|
||||
topics =
|
||||
Topic
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
||||
.where("#{table_name}.topic_id IS NULL")
|
||||
.where("category_id IN (?)", public_categories)
|
||||
.where(deleted_at: nil)
|
||||
.order("topics.id DESC")
|
||||
|
||||
Parallel.each(topics.all, in_processes: args[:concurrency].to_i, progress: "Topics") do |t|
|
||||
ActiveRecord::Base.connection_pool.with_connection do
|
||||
vector_rep.generate_representation_from(t)
|
||||
end
|
||||
end
|
||||
|
||||
table_name = vector_rep.post_table_name
|
||||
Post
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
||||
.where("#{table_name}.post_id IS NULL")
|
||||
.where(deleted_at: nil)
|
||||
.order("posts.id DESC")
|
||||
.find_each do |t|
|
||||
print "."
|
||||
posts =
|
||||
Post
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
||||
.where("#{table_name}.post_id IS NULL")
|
||||
.where(deleted_at: nil)
|
||||
.order("posts.id DESC")
|
||||
|
||||
Parallel.each(posts.all, in_processes: args[:concurrency].to_i, progress: "Posts") do |t|
|
||||
ActiveRecord::Base.connection_pool.with_connection do
|
||||
vector_rep.generate_representation_from(t)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
desc "Creates indexes for embeddings"
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Tokenizer
|
||||
class BgeM3Tokenizer < BasicTokenizer
|
||||
def self.tokenizer
|
||||
@@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-m3.json")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -176,3 +176,32 @@ describe DiscourseAi::Tokenizer::BgeLargeEnTokenizer do
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe DiscourseAi::Tokenizer::BgeM3Tokenizer do
|
||||
describe "#size" do
|
||||
describe "returns a token count" do
|
||||
it "for a sentence with punctuation and capitalization and numbers" do
|
||||
expect(described_class.size("Hello, World! 123")).to eq(7)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "#truncate" do
|
||||
it "truncates a sentence" do
|
||||
sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
|
||||
expect(described_class.truncate(sentence, 3)).to eq("foo")
|
||||
end
|
||||
|
||||
it "truncates a sentence successfully at a multibyte unicode character" do
|
||||
sentence = "foo bar 👨🏿👩🏿👧🏿👧🏿 baz qux quux corge grault garply waldo fred plugh xyzzy thud"
|
||||
expect(described_class.truncate(sentence, 7)).to eq("foo bar 👨🏿")
|
||||
end
|
||||
|
||||
it "truncates unicode characters properly when they use more than one token per char" do
|
||||
sentence = "我喜欢吃比萨"
|
||||
original_size = described_class.size(sentence)
|
||||
expect(described_class.size(described_class.truncate(sentence, original_size - 2))).to be <
|
||||
original_size
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -25,3 +25,7 @@ Licensed under MIT License
|
|||
## mixtral
|
||||
|
||||
Licensed under Apache 2.0 License
|
||||
|
||||
## bge-m3
|
||||
|
||||
Licensed under MIT License
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue