FEATURE: Bge-large-en embeddings via Cloudflare Workers AI API (#241)
* FEATURE: Bge-large-en embeddings via Cloudflare Workers AI API
* forgot a file
* lint
parent 05c256f65b
commit 84cc369552
@@ -82,6 +82,12 @@ en:
    ai_google_custom_search_api_key: "API key for the Google Custom Search API see: https://developers.google.com/custom-search"
    ai_google_custom_search_cx: "CX for Google Custom Search API"

    ai_bedrock_access_key_id: "Access key ID for the Bedrock API"
    ai_bedrock_secret_access_key: "Secret access key for the Bedrock API"
    ai_bedrock_region: "AWS region for the Bedrock API"
    ai_cloudflare_workers_account_id: "Cloudflare account ID for the Cloudflare Workers AI API"
    ai_cloudflare_workers_api_token: "API token for the Cloudflare Workers AI API"

  reviewables:
    reasons:
      flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.

@@ -122,7 +128,6 @@ en:
    creative:
      name: Creative
      description: "AI Bot with no external integrations specialized in creative tasks"
    default_pm_prefix: "[Untitled AI bot PM]"
    topic_not_found: "Summary unavailable, topic not found!"
    searching: "Searching for: '%{query}'"
    command_summary:

@@ -133,6 +133,12 @@ discourse_ai:
    secret: true
  ai_bedrock_region:
    default: "us-east-1"
  ai_cloudflare_workers_account_id:
    default: ""
    secret: true
  ai_cloudflare_workers_api_token:
    default: ""
    secret: true

  composer_ai_helper_enabled:
    default: false

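The two new site settings above are what the inference client further down reads at request time. A minimal console sketch of configuring them, with placeholder values standing in for real credentials:

    # Placeholder values; the real account ID and token come from the Cloudflare dashboard.
    SiteSetting.ai_cloudflare_workers_account_id = "your-cloudflare-account-id"
    SiteSetting.ai_cloudflare_workers_api_token = "your-cloudflare-api-token"
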
@@ -179,6 +185,7 @@ discourse_ai:
      - all-mpnet-base-v2
      - text-embedding-ada-002
      - multilingual-e5-large
      - bge-large-en
  ai_embeddings_generate_for_pms: false
  ai_embeddings_semantic_related_topics_enabled:
    default: false

@@ -0,0 +1,19 @@
# frozen_string_literal: true

class CreateBgeTopicEmbeddingsTable < ActiveRecord::Migration[7.0]
  def change
    truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn.new(truncation)

    create_table vector_rep.table_name.to_sym, id: false do |t|
      t.integer :topic_id, null: false
      t.integer :model_version, null: false
      t.integer :strategy_version, null: false
      t.text :digest, null: false
      t.column :embeddings, "vector(#{vector_rep.dimensions})", null: false
      t.timestamps

      t.index :topic_id, unique: true
    end
  end
end

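The migration derives both the table name and the size of the pgvector column from the new vector representation class. A hedged Rails-console sketch of checking the result after migrating (table_name is computed inside the base class and is not shown in this diff):

    truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn.new(truncation)
    vector_rep.dimensions # => 1024, the size of the vector(...) column
    ActiveRecord::Base.connection.table_exists?(vector_rep.table_name) # => true once migrated
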
@@ -8,6 +8,7 @@ module DiscourseAi
    require_relative "vector_representations/all_mpnet_base_v2"
    require_relative "vector_representations/text_embedding_ada_002"
    require_relative "vector_representations/multilingual_e5_large"
    require_relative "vector_representations/bge_large_en"
    require_relative "strategies/truncation"
    require_relative "jobs/regular/generate_embeddings"
    require_relative "semantic_related"

@@ -0,0 +1,52 @@
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      class BgeLargeEn < Base
        def vector_from(text)
          DiscourseAi::Inference::CloudflareWorkersAi
            .perform!(inference_model_name, { text: text })
            .dig(:result, :data)
            .first
        end

        def name
          "bge-large-en"
        end

        def inference_model_name
          "baai/bge-large-en-v1.5"
        end

        def dimensions
          1024
        end

        def max_sequence_length
          512
        end

        def id
          4
        end

        def version
          1
        end

        def pg_function
          "<#>"
        end

        def pg_index_type
          "vector_ip_ops"
        end

        def tokenizer
          DiscourseAi::Tokenizer::BgeLargeEnTokenizer
        end
      end
    end
  end
end

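Taken together with the migration above, this class turns a text string into a 1024-dimension embedding via Cloudflare Workers AI and tells the embeddings framework how to store and index it (inner-product distance via the "<#>" operator with a "vector_ip_ops" index). A minimal usage sketch, assuming valid Cloudflare credentials are configured:

    truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn.new(truncation)
    embedding = vector_rep.vector_from("Discourse is a platform for community discussion.")
    embedding.size # => 1024, matching vector_rep.dimensions
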
@@ -0,0 +1,25 @@
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    class CloudflareWorkersAi
      def self.perform!(model, content)
        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

        account_id = SiteSetting.ai_cloudflare_workers_account_id
        token = SiteSetting.ai_cloudflare_workers_api_token

        base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
        headers["Authorization"] = "Bearer #{token}"

        endpoint = "#{base_url}#{model}"

        response = Faraday.post(endpoint, content.to_json, headers)

        raise Net::HTTPBadResponse if ![200].include?(response.status)

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end

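A hedged sketch of calling the new inference client directly; the model slug is taken from BgeLargeEn#inference_model_name, and the response shape ({ result: { data: [...] } }) is inferred from how vector_from digs into it rather than documented here:

    # POSTs to https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/run/@cf/baai/bge-large-en-v1.5
    response = DiscourseAi::Inference::CloudflareWorkersAi.perform!(
      "baai/bge-large-en-v1.5",
      { text: "Hello, world" },
    )
    response.dig(:result, :data).first # => the embedding vector for the input text
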
@@ -66,6 +66,12 @@ module DiscourseAi
      end
    end

    class BgeLargeEnTokenizer < BasicTokenizer
      def self.tokenizer
        @@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-large-en.json")
      end
    end

    class OpenAiTokenizer < BasicTokenizer
      class << self
        def tokenizer

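The new tokenizer loads the bge-large-en vocabulary shipped with the plugin (presumably the large tokenizers/bge-large-en.json file whose diff is suppressed at the end of this page) and reuses the size and truncate helpers it inherits from BasicTokenizer. A short usage sketch, with expected values taken from the spec further down:

    DiscourseAi::Tokenizer::BgeLargeEnTokenizer.size("Hello, World! 123") # => 7
    DiscourseAi::Tokenizer::BgeLargeEnTokenizer.truncate(
      "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud",
      3,
    ) # => "foo bar"
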
@@ -35,6 +35,7 @@ after_initialize do
  require_relative "lib/shared/inference/stability_generator"
  require_relative "lib/shared/inference/hugging_face_text_generation"
  require_relative "lib/shared/inference/amazon_bedrock_inference"
  require_relative "lib/shared/inference/cloudflare_workers_ai"
  require_relative "lib/shared/inference/function"
  require_relative "lib/shared/inference/function_list"

@@ -134,3 +134,20 @@ describe DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer do
    end
  end
end

describe DiscourseAi::Tokenizer::BgeLargeEnTokenizer do
  describe "#size" do
    describe "returns a token count" do
      it "for a sentence with punctuation and capitalization and numbers" do
        expect(described_class.size("Hello, World! 123")).to eq(7)
      end
    end
  end

  describe "#truncate" do
    it "truncates a sentence" do
      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
      expect(described_class.truncate(sentence, 3)).to eq("foo bar")
    end
  end
end

@@ -17,3 +17,7 @@ Licensed under LLAMA 2 COMMUNITY LICENSE AGREEMENT
## multilingual-e5-large

Licensed under MIT License

## bge-large-en

Licensed under MIT License

File diff suppressed because it is too large