FEATURE: Bge-large-en embeddings via Cloudflare Workers AI API (#241)

* FEATURE: Bge-large-en embeddings via Cloudflare Workers AI API

* forgot a file

* lint
Rafael dos Santos Silva 2023-10-04 13:47:51 -03:00 committed by GitHub
parent 05c256f65b
commit 84cc369552
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 30810 additions and 1 deletion

View File

@@ -82,6 +82,12 @@ en:
ai_google_custom_search_api_key: "API key for the Google Custom Search API see: https://developers.google.com/custom-search"
ai_google_custom_search_cx: "CX for Google Custom Search API"
ai_bedrock_access_key_id: "Access key ID for the Bedrock API"
ai_bedrock_secret_access_key: "Secret access key for the Bedrock API"
ai_bedrock_region: "AWS region for the Bedrock API"
ai_cloudflare_workers_account_id: "Cloudflare account ID for the Cloudflare Workers AI API"
ai_cloudflare_workers_api_token: "API token for the Cloudflare Workers AI API"
reviewables:
reasons:
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
@@ -122,7 +128,6 @@ en:
creative:
name: Creative
description: "AI Bot with no external integrations specialized in creative tasks"
default_pm_prefix: "[Untitled AI bot PM]"
topic_not_found: "Summary unavailable, topic not found!"
searching: "Searching for: '%{query}'"
command_summary:

View File

@@ -133,6 +133,12 @@ discourse_ai:
    secret: true
  ai_bedrock_region:
    default: "us-east-1"
  ai_cloudflare_workers_account_id:
    default: ""
    secret: true
  ai_cloudflare_workers_api_token:
    default: ""
    secret: true
  composer_ai_helper_enabled:
    default: false
@@ -179,6 +185,7 @@ discourse_ai:
      - all-mpnet-base-v2
      - text-embedding-ada-002
      - multilingual-e5-large
      - bge-large-en
  ai_embeddings_generate_for_pms: false
  ai_embeddings_semantic_related_topics_enabled:
    default: false
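
The two new secret settings above hold the Cloudflare credentials, and "bge-large-en" joins the list of selectable embedding models. A minimal configuration sketch from a Rails console, with placeholder values; the name of the embeddings model setting itself is not shown in this hunk, so the last line is an assumption:

# Configuration sketch; the values are placeholders.
SiteSetting.ai_cloudflare_workers_account_id = "your-cloudflare-account-id"
SiteSetting.ai_cloudflare_workers_api_token = "your-cloudflare-api-token"
# Assumed setting name for the embeddings model choice (not shown in this hunk).
SiteSetting.ai_embeddings_model = "bge-large-en"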

View File

@@ -0,0 +1,19 @@
# frozen_string_literal: true

class CreateBgeTopicEmbeddingsTable < ActiveRecord::Migration[7.0]
  def change
    truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn.new(truncation)

    create_table vector_rep.table_name.to_sym, id: false do |t|
      t.integer :topic_id, null: false
      t.integer :model_version, null: false
      t.integer :strategy_version, null: false
      t.text :digest, null: false
      t.column :embeddings, "vector(#{vector_rep.dimensions})", null: false
      t.timestamps

      t.index :topic_id, unique: true
    end
  end
end

View File

@@ -8,6 +8,7 @@ module DiscourseAi
require_relative "vector_representations/all_mpnet_base_v2"
require_relative "vector_representations/text_embedding_ada_002"
require_relative "vector_representations/multilingual_e5_large"
require_relative "vector_representations/bge_large_en"
require_relative "strategies/truncation"
require_relative "jobs/regular/generate_embeddings"
require_relative "semantic_related"

View File

@@ -0,0 +1,52 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      class BgeLargeEn < Base
        def vector_from(text)
          DiscourseAi::Inference::CloudflareWorkersAi
            .perform!(inference_model_name, { text: text })
            .dig(:result, :data)
            .first
        end

        def name
          "bge-large-en"
        end

        def inference_model_name
          "baai/bge-large-en-v1.5"
        end

        def dimensions
          1024
        end

        def max_sequence_length
          512
        end

        def id
          4
        end

        def version
          1
        end

        def pg_function
          "<#>"
        end

        def pg_index_type
          "vector_ip_ops"
        end

        def tokenizer
          DiscourseAi::Tokenizer::BgeLargeEnTokenizer
        end
      end
    end
  end
end
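
For context, a rough usage sketch of the new representation, assuming the Cloudflare settings are configured; the constructor taking a truncation strategy is inferred from the migration earlier in this commit:

# Usage sketch; requires the Cloudflare site settings to be set.
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn.new(truncation)
embedding = vector_rep.vector_from("Hello, World!")
embedding.size # expected to equal 1024, matching #dimensions

The "<#>" pg_function and "vector_ip_ops" index type mean nearest-neighbour lookups for this model rely on pgvector's inner-product operator, the usual choice for normalized embeddings such as BGE.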

View File

@@ -0,0 +1,25 @@
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    class CloudflareWorkersAi
      def self.perform!(model, content)
        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

        account_id = SiteSetting.ai_cloudflare_workers_account_id
        token = SiteSetting.ai_cloudflare_workers_api_token

        base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
        headers["Authorization"] = "Bearer #{token}"

        endpoint = "#{base_url}#{model}"

        response = Faraday.post(endpoint, content.to_json, headers)

        raise Net::HTTPBadResponse if ![200].include?(response.status)

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end
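
A sketch of how the new client is expected to be called; the model name and response shape are taken from BgeLargeEn#vector_from above. The call results in a POST to https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/run/@cf/baai/bge-large-en-v1.5 with a JSON body:

# Usage sketch; assumes ai_cloudflare_workers_account_id and
# ai_cloudflare_workers_api_token are configured.
result = DiscourseAi::Inference::CloudflareWorkersAi.perform!(
  "baai/bge-large-en-v1.5",
  { text: "Discourse is a forum platform" },
)
embedding = result.dig(:result, :data).first # array of 1024 floats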

View File

@@ -66,6 +66,12 @@ module DiscourseAi
      end
    end

    class BgeLargeEnTokenizer < BasicTokenizer
      def self.tokenizer
        @@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-large-en.json")
      end
    end

    class OpenAiTokenizer < BasicTokenizer
      class << self
        def tokenizer

View File

@@ -35,6 +35,7 @@ after_initialize do
  require_relative "lib/shared/inference/stability_generator"
  require_relative "lib/shared/inference/hugging_face_text_generation"
  require_relative "lib/shared/inference/amazon_bedrock_inference"
  require_relative "lib/shared/inference/cloudflare_workers_ai"
  require_relative "lib/shared/inference/function"
  require_relative "lib/shared/inference/function_list"

View File

@@ -134,3 +134,20 @@ describe DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
    end
  end
end

describe DiscourseAi::Tokenizer::BgeLargeEnTokenizer do
  describe "#size" do
    describe "returns a token count" do
      it "for a sentence with punctuation and capitalization and numbers" do
        expect(described_class.size("Hello, World! 123")).to eq(7)
      end
    end
  end

  describe "#truncate" do
    it "truncates a sentence" do
      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
      expect(described_class.truncate(sentence, 3)).to eq("foo bar")
    end
  end
end

View File

@@ -17,3 +17,7 @@ Licensed under LLAMA 2 COMMUNITY LICENSE AGREEMENT
## multilingual-e5-large

Licensed under MIT License

## bge-large-en

Licensed under MIT License

tokenizers/bge-large-en.json (new file, 30672 additions)

File diff suppressed because it is too large