FEATURE: Add BGE-M3 embeddings support (#569)
BAAI/bge-m3 is an interesting model: it is multilingual and has a context size of 8192. Even with a 16x larger context, it is only 4x slower to compute its embeddings in the worst-case scenario. This commit also includes a minor refactor of the rake task, which adds support for setting the model and concurrency level when running the backfill task.
This commit is contained in:
parent
6de9c53a71
commit
eb93b21769
|
@ -264,6 +264,7 @@ discourse_ai:
|
||||||
- multilingual-e5-large
|
- multilingual-e5-large
|
||||||
- bge-large-en
|
- bge-large-en
|
||||||
- gemini
|
- gemini
|
||||||
|
- bge-m3
|
||||||
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
|
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
|
||||||
ai_embeddings_per_post_enabled:
|
ai_embeddings_per_post_enabled:
|
||||||
default: false
|
default: false
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class AddEmbeddingsTablesforBgeM3 < ActiveRecord::Migration[7.0]
|
||||||
|
def change
|
||||||
|
create_table :ai_topic_embeddings_8_1, id: false do |t|
|
||||||
|
t.integer :topic_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1024)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :topic_id, unique: true
|
||||||
|
end
|
||||||
|
create_table :ai_post_embeddings_8_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1024)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
create_table :ai_document_fragment_embeddings_8_1, id: false do |t|
|
||||||
|
t.integer :rag_document_fragment_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1024)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :rag_document_fragment_id,
|
||||||
|
unique: true,
|
||||||
|
name: "rag_document_fragment_id_embeddings_8_1"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -11,11 +11,12 @@ module DiscourseAi
|
||||||
[
|
[
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
|
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
|
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
||||||
].find { _1.name == model_name }
|
].find { _1.name == model_name }
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
module VectorRepresentations
|
||||||
|
class BgeM3 < Base
|
||||||
|
class << self
|
||||||
|
def name
|
||||||
|
"bge-m3"
|
||||||
|
end
|
||||||
|
|
||||||
|
def correctly_configured?
|
||||||
|
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||||
|
end
|
||||||
|
|
||||||
|
def dependant_setting_names
|
||||||
|
%w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def vector_from(text, asymetric: false)
|
||||||
|
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||||
|
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||||
|
end
|
||||||
|
|
||||||
|
def dimensions
|
||||||
|
1024
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_sequence_length
|
||||||
|
8192
|
||||||
|
end
|
||||||
|
|
||||||
|
def id
|
||||||
|
8
|
||||||
|
end
|
||||||
|
|
||||||
|
def version
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_function
|
||||||
|
"<#>"
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_index_type
|
||||||
|
"vector_ip_ops"
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,35 +1,49 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
desc "Backfill embeddings for all topics and posts"
|
desc "Backfill embeddings for all topics and posts"
|
||||||
task "ai:embeddings:backfill" => [:environment] do
|
task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
|
||||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||||
|
|
||||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
if args[:model].present?
|
||||||
|
vector_rep =
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new(
|
||||||
|
strategy,
|
||||||
|
)
|
||||||
|
else
|
||||||
|
vector_rep =
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
|
end
|
||||||
table_name = vector_rep.topic_table_name
|
table_name = vector_rep.topic_table_name
|
||||||
|
|
||||||
|
topics =
|
||||||
Topic
|
Topic
|
||||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
||||||
.where("#{table_name}.topic_id IS NULL")
|
.where("#{table_name}.topic_id IS NULL")
|
||||||
.where("category_id IN (?)", public_categories)
|
.where("category_id IN (?)", public_categories)
|
||||||
.where(deleted_at: nil)
|
.where(deleted_at: nil)
|
||||||
.order("topics.id DESC")
|
.order("topics.id DESC")
|
||||||
.find_each do |t|
|
|
||||||
print "."
|
Parallel.each(topics.all, in_processes: args[:concurrency].to_i, progress: "Topics") do |t|
|
||||||
|
ActiveRecord::Base.connection_pool.with_connection do
|
||||||
vector_rep.generate_representation_from(t)
|
vector_rep.generate_representation_from(t)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
table_name = vector_rep.post_table_name
|
table_name = vector_rep.post_table_name
|
||||||
|
posts =
|
||||||
Post
|
Post
|
||||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
||||||
.where("#{table_name}.post_id IS NULL")
|
.where("#{table_name}.post_id IS NULL")
|
||||||
.where(deleted_at: nil)
|
.where(deleted_at: nil)
|
||||||
.order("posts.id DESC")
|
.order("posts.id DESC")
|
||||||
.find_each do |t|
|
|
||||||
print "."
|
Parallel.each(posts.all, in_processes: args[:concurrency].to_i, progress: "Posts") do |t|
|
||||||
|
ActiveRecord::Base.connection_pool.with_connection do
|
||||||
vector_rep.generate_representation_from(t)
|
vector_rep.generate_representation_from(t)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
desc "Creates indexes for embeddings"
|
desc "Creates indexes for embeddings"
|
||||||
task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
|
task "ai:embeddings:index", [:work_mem] => [:environment] do |_, args|
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class BgeM3Tokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-m3.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -176,3 +176,32 @@ describe DiscourseAi::Tokenizer::BgeLargeEnTokenizer do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe DiscourseAi::Tokenizer::BgeM3Tokenizer do
|
||||||
|
describe "#size" do
|
||||||
|
describe "returns a token count" do
|
||||||
|
it "for a sentence with punctuation and capitalization and numbers" do
|
||||||
|
expect(described_class.size("Hello, World! 123")).to eq(7)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#truncate" do
|
||||||
|
it "truncates a sentence" do
|
||||||
|
sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
|
||||||
|
expect(described_class.truncate(sentence, 3)).to eq("foo")
|
||||||
|
end
|
||||||
|
|
||||||
|
it "truncates a sentence successfully at a multibyte unicode character" do
|
||||||
|
sentence = "foo bar 👨🏿👩🏿👧🏿👧🏿 baz qux quux corge grault garply waldo fred plugh xyzzy thud"
|
||||||
|
expect(described_class.truncate(sentence, 7)).to eq("foo bar 👨🏿")
|
||||||
|
end
|
||||||
|
|
||||||
|
it "truncates unicode characters properly when they use more than one token per char" do
|
||||||
|
sentence = "我喜欢吃比萨"
|
||||||
|
original_size = described_class.size(sentence)
|
||||||
|
expect(described_class.size(described_class.truncate(sentence, original_size - 2))).to be <
|
||||||
|
original_size
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
|
@ -25,3 +25,7 @@ Licensed under MIT License
|
||||||
## mixtral
|
## mixtral
|
||||||
|
|
||||||
Licensed under Apache 2.0 License
|
Licensed under Apache 2.0 License
|
||||||
|
|
||||||
|
## bge-m3
|
||||||
|
|
||||||
|
Licensed under MIT License
|
||||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue