FEATURE: Add BGE-M3 embeddings support (#569)

BAAI/bge-m3 is an interesting model, that is multilingual and with a
context size of 8192. Even with a 16x larger context, it's only 4x slower
to compute it's embeddings on the worst case scenario.

Also includes a minor refactor of the rake task, including setting model
and concurrency levels when running the backfill task.
This commit is contained in:
Rafael dos Santos Silva 2024-04-10 17:24:01 -03:00 committed by GitHub
parent 6de9c53a71
commit eb93b21769
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1000347 additions and 19 deletions

View File

@ -264,6 +264,7 @@ discourse_ai:
- multilingual-e5-large - multilingual-e5-large
- bge-large-en - bge-large-en
- gemini - gemini
- bge-m3
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator" validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
ai_embeddings_per_post_enabled: ai_embeddings_per_post_enabled:
default: false default: false

View File

@ -0,0 +1,38 @@
# frozen_string_literal: true
class AddEmbeddingsTablesforBgeM3 < ActiveRecord::Migration[7.0]
def change
create_table :ai_topic_embeddings_8_1, id: false do |t|
t.integer :topic_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1024)", null: false
t.timestamps
t.index :topic_id, unique: true
end
create_table :ai_post_embeddings_8_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1024)", null: false
t.timestamps
t.index :post_id, unique: true
end
create_table :ai_document_fragment_embeddings_8_1, id: false do |t|
t.integer :rag_document_fragment_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1024)", null: false
t.timestamps
t.index :rag_document_fragment_id,
unique: true,
name: "rag_document_fragment_id_embeddings_8_1"
end
end
end

View File

@ -11,11 +11,12 @@ module DiscourseAi
[ [
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2, DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn, DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
DiscourseAi::Embeddings::VectorRepresentations::Gemini, DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large, DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large, DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
].find { _1.name == model_name } ].find { _1.name == model_name }
end end

View File

@ -0,0 +1,56 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class BgeM3 < Base
class << self
def name
"bge-m3"
end
def correctly_configured?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
end
def dependant_setting_names
%w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint]
end
end
def vector_from(text, asymetric: false)
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
end
def dimensions
1024
end
def max_sequence_length
8192
end
def id
8
end
def version
1
end
def pg_function
"<#>"
end
def pg_index_type
"vector_ip_ops"
end
def tokenizer
DiscourseAi::Tokenizer::BgeM3Tokenizer
end
end
end
end
end

View File

@ -1,35 +1,49 @@
# frozen_string_literal: true # frozen_string_literal: true
desc "Backfill embeddings for all topics and posts" desc "Backfill embeddings for all topics and posts"
task "ai:embeddings:backfill" => [:environment] do task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
public_categories = Category.where(read_restricted: false).pluck(:id) public_categories = Category.where(read_restricted: false).pluck(:id)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) if args[:model].present?
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new(
strategy,
)
else
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
end
table_name = vector_rep.topic_table_name table_name = vector_rep.topic_table_name
topics =
Topic Topic
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id") .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
.where("#{table_name}.topic_id IS NULL") .where("#{table_name}.topic_id IS NULL")
.where("category_id IN (?)", public_categories) .where("category_id IN (?)", public_categories)
.where(deleted_at: nil) .where(deleted_at: nil)
.order("topics.id DESC") .order("topics.id DESC")
.find_each do |t|
print "." Parallel.each(topics.all, in_processes: args[:concurrency].to_i, progress: "Topics") do |t|
ActiveRecord::Base.connection_pool.with_connection do
vector_rep.generate_representation_from(t) vector_rep.generate_representation_from(t)
end end
end
table_name = vector_rep.post_table_name table_name = vector_rep.post_table_name
posts =
Post Post
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id") .joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
.where("#{table_name}.post_id IS NULL") .where("#{table_name}.post_id IS NULL")
.where(deleted_at: nil) .where(deleted_at: nil)
.order("posts.id DESC") .order("posts.id DESC")
.find_each do |t|
print "." Parallel.each(posts.all, in_processes: args[:concurrency].to_i, progress: "Posts") do |t|
ActiveRecord::Base.connection_pool.with_connection do
vector_rep.generate_representation_from(t) vector_rep.generate_representation_from(t)
end end
end end
end
desc "Creates indexes for embeddings" desc "Creates indexes for embeddings"
task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args| task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|

View File

@ -0,0 +1,11 @@
# frozen_string_literal: true
module DiscourseAi
module Tokenizer
class BgeM3Tokenizer < BasicTokenizer
def self.tokenizer
@@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-m3.json")
end
end
end
end

View File

@ -176,3 +176,32 @@ describe DiscourseAi::Tokenizer::BgeLargeEnTokenizer do
end end
end end
end end
describe DiscourseAi::Tokenizer::BgeM3Tokenizer do
describe "#size" do
describe "returns a token count" do
it "for a sentence with punctuation and capitalization and numbers" do
expect(described_class.size("Hello, World! 123")).to eq(7)
end
end
end
describe "#truncate" do
it "truncates a sentence" do
sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
expect(described_class.truncate(sentence, 3)).to eq("foo")
end
it "truncates a sentence successfully at a multibyte unicode character" do
sentence = "foo bar 👨🏿‍👩🏿‍👧🏿‍👧🏿 baz qux quux corge grault garply waldo fred plugh xyzzy thud"
expect(described_class.truncate(sentence, 7)).to eq("foo bar 👨🏿")
end
it "truncates unicode characters properly when they use more than one token per char" do
sentence = "我喜欢吃比萨"
original_size = described_class.size(sentence)
expect(described_class.size(described_class.truncate(sentence, original_size - 2))).to be <
original_size
end
end
end

View File

@ -25,3 +25,7 @@ Licensed under MIT License
## mixtral ## mixtral
Licensed under Apache 2.0 License Licensed under Apache 2.0 License
## bge-m3
Licensed under MIT License

1000174
tokenizers/bge-m3.json Normal file

File diff suppressed because one or more lines are too long