FEATURE: Bge-large-en embeddings via Cloudflare Workers AI API (#241)
* FEATURE: Bge-large-en embeddings via Cloudflare Workers AI API * forgot a file * lint
This commit is contained in:
parent
05c256f65b
commit
84cc369552
|
@ -82,6 +82,12 @@ en:
|
||||||
ai_google_custom_search_api_key: "API key for the Google Custom Search API see: https://developers.google.com/custom-search"
|
ai_google_custom_search_api_key: "API key for the Google Custom Search API see: https://developers.google.com/custom-search"
|
||||||
ai_google_custom_search_cx: "CX for Google Custom Search API"
|
ai_google_custom_search_cx: "CX for Google Custom Search API"
|
||||||
|
|
||||||
|
ai_bedrock_access_key_id: "Access key ID for the Bedrock API"
|
||||||
|
ai_bedrock_secret_access_key: "Secret access key for the Bedrock API"
|
||||||
|
ai_bedrock_region: "AWS region for the Bedrock API"
|
||||||
|
ai_cloudflare_workers_account_id: "Cloudflare account ID for the Cloudflare Workers AI API"
|
||||||
|
ai_cloudflare_workers_api_token: "API token for the Cloudflare Workers AI API"
|
||||||
|
|
||||||
reviewables:
|
reviewables:
|
||||||
reasons:
|
reasons:
|
||||||
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
|
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
|
||||||
|
@ -122,7 +128,6 @@ en:
|
||||||
creative:
|
creative:
|
||||||
name: Creative
|
name: Creative
|
||||||
description: "AI Bot with no external integrations specialized in creative tasks"
|
description: "AI Bot with no external integrations specialized in creative tasks"
|
||||||
default_pm_prefix: "[Untitled AI bot PM]"
|
|
||||||
topic_not_found: "Summary unavailable, topic not found!"
|
topic_not_found: "Summary unavailable, topic not found!"
|
||||||
searching: "Searching for: '%{query}'"
|
searching: "Searching for: '%{query}'"
|
||||||
command_summary:
|
command_summary:
|
||||||
|
|
|
@ -133,6 +133,12 @@ discourse_ai:
|
||||||
secret: true
|
secret: true
|
||||||
ai_bedrock_region:
|
ai_bedrock_region:
|
||||||
default: "us-east-1"
|
default: "us-east-1"
|
||||||
|
ai_cloudflare_workers_account_id:
|
||||||
|
default: ""
|
||||||
|
secret: true
|
||||||
|
ai_cloudflare_workers_api_token:
|
||||||
|
default: ""
|
||||||
|
secret: true
|
||||||
|
|
||||||
composer_ai_helper_enabled:
|
composer_ai_helper_enabled:
|
||||||
default: false
|
default: false
|
||||||
|
@ -179,6 +185,7 @@ discourse_ai:
|
||||||
- all-mpnet-base-v2
|
- all-mpnet-base-v2
|
||||||
- text-embedding-ada-002
|
- text-embedding-ada-002
|
||||||
- multilingual-e5-large
|
- multilingual-e5-large
|
||||||
|
- bge-large-en
|
||||||
ai_embeddings_generate_for_pms: false
|
ai_embeddings_generate_for_pms: false
|
||||||
ai_embeddings_semantic_related_topics_enabled:
|
ai_embeddings_semantic_related_topics_enabled:
|
||||||
default: false
|
default: false
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
# Creates the table that stores one embedding row per topic for the
# bge-large-en vector representation.
class CreateBgeTopicEmbeddingsTable < ActiveRecord::Migration[7.0]
  def change
    # The table name and the vector width are both read from the
    # representation object rather than being hard-coded here.
    representation =
      DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn.new(
        DiscourseAi::Embeddings::Strategies::Truncation.new,
      )

    # id: false -- no surrogate primary key column is generated; topic_id
    # carries the unique index below.
    create_table(representation.table_name.to_sym, id: false) do |table|
      table.integer(:topic_id, null: false)
      table.integer(:model_version, null: false)
      table.integer(:strategy_version, null: false)
      table.text(:digest, null: false)
      table.column(:embeddings, "vector(#{representation.dimensions})", null: false)
      table.timestamps

      table.index(:topic_id, unique: true)
    end
  end
end
|
|
@ -8,6 +8,7 @@ module DiscourseAi
|
||||||
require_relative "vector_representations/all_mpnet_base_v2"
|
require_relative "vector_representations/all_mpnet_base_v2"
|
||||||
require_relative "vector_representations/text_embedding_ada_002"
|
require_relative "vector_representations/text_embedding_ada_002"
|
||||||
require_relative "vector_representations/multilingual_e5_large"
|
require_relative "vector_representations/multilingual_e5_large"
|
||||||
|
require_relative "vector_representations/bge_large_en"
|
||||||
require_relative "strategies/truncation"
|
require_relative "strategies/truncation"
|
||||||
require_relative "jobs/regular/generate_embeddings"
|
require_relative "jobs/regular/generate_embeddings"
|
||||||
require_relative "semantic_related"
|
require_relative "semantic_related"
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Vector representation for the "bge-large-en" embeddings model, with
      # vectors obtained through DiscourseAi::Inference::CloudflareWorkersAi.
      class BgeLargeEn < Base
        # Requests an embedding for +text+ and returns the first entry of the
        # response's [:result][:data] list.
        def vector_from(text)
          model = inference_model_name
          payload = { text: text }

          response = DiscourseAi::Inference::CloudflareWorkersAi.perform!(model, payload)
          response.dig(:result, :data).first
        end

        # Name of this representation ("bge-large-en" is also the value added
        # to the embeddings model choices in the plugin settings).
        def name
          "bge-large-en"
        end

        # Model identifier handed to the inference client by #vector_from.
        def inference_model_name
          "baai/bge-large-en-v1.5"
        end

        # Width of the embedding vector; the migration uses this value to
        # size the "vector(n)" column.
        def dimensions
          1024
        end

        # NOTE(review): presumably the maximum input length in tokens as
        # counted by #tokenizer -- confirm against Base / the strategy.
        def max_sequence_length
          512
        end

        # NOTE(review): presumably must be unique across all vector
        # representations -- confirm against the sibling classes.
        def id
          4
        end

        # Version number of this representation.
        def version
          1
        end

        # Distance operator used in Postgres queries. NOTE(review): in
        # pgvector "<#>" is the (negative) inner product operator -- confirm.
        def pg_function
          "<#>"
        end

        # Index operator class paired with #pg_function. NOTE(review):
        # "vector_ip_ops" is pgvector's inner-product operator class -- confirm.
        def pg_index_type
          "vector_ip_ops"
        end

        # Tokenizer class associated with this model.
        def tokenizer
          DiscourseAi::Tokenizer::BgeLargeEnTokenizer
        end
      end
    end
  end
end
|
|
@ -0,0 +1,25 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module ::DiscourseAi
  module Inference
    # Minimal HTTP client for running a model through the Cloudflare
    # Workers AI REST API.
    class CloudflareWorkersAi
      # POSTs +content+ (serialized with #to_json) to the Workers AI "run"
      # endpoint for +model+, authenticating with the account id and API
      # token read from the site settings.
      #
      # @param model [String] model path appended after ".../ai/run/@cf/"
      # @param content [#to_json] request payload
      # @return [Hash] the parsed JSON response body, with symbolized keys
      # @raise [Net::HTTPBadResponse] when the response status is not 200;
      #   the message names the model and the status that was received
      def self.perform!(model, content)
        headers = {
          "Referer" => Discourse.base_url,
          "Content-Type" => "application/json",
          "Authorization" => "Bearer #{SiteSetting.ai_cloudflare_workers_api_token}",
        }

        account_id = SiteSetting.ai_cloudflare_workers_account_id
        endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"

        response = Faraday.post(endpoint, content.to_json, headers)

        # Only a plain 200 counts as success. Include the status in the error
        # so failures (bad token, unknown model, rate limit) can be told apart.
        unless response.status == 200
          raise Net::HTTPBadResponse,
                "Cloudflare Workers AI request for #{model} failed with status #{response.status}"
        end

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end
|
|
@ -66,6 +66,12 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Tokenizer for the bge-large-en embeddings model.
class BgeLargeEnTokenizer < BasicTokenizer
  class << self
    # Loads the tokenizer definition from the bundled JSON file on first use
    # and returns the memoized instance on every later call.
    # NOTE(review): the path is relative ("./plugins/..."), so it presumably
    # resolves against the process working directory -- confirm.
    def tokenizer
      @@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-large-en.json")
    end
  end
end
|
||||||
|
|
||||||
class OpenAiTokenizer < BasicTokenizer
|
class OpenAiTokenizer < BasicTokenizer
|
||||||
class << self
|
class << self
|
||||||
def tokenizer
|
def tokenizer
|
||||||
|
|
|
@ -35,6 +35,7 @@ after_initialize do
|
||||||
require_relative "lib/shared/inference/stability_generator"
|
require_relative "lib/shared/inference/stability_generator"
|
||||||
require_relative "lib/shared/inference/hugging_face_text_generation"
|
require_relative "lib/shared/inference/hugging_face_text_generation"
|
||||||
require_relative "lib/shared/inference/amazon_bedrock_inference"
|
require_relative "lib/shared/inference/amazon_bedrock_inference"
|
||||||
|
require_relative "lib/shared/inference/cloudflare_workers_ai"
|
||||||
require_relative "lib/shared/inference/function"
|
require_relative "lib/shared/inference/function"
|
||||||
require_relative "lib/shared/inference/function_list"
|
require_relative "lib/shared/inference/function_list"
|
||||||
|
|
||||||
|
|
|
@ -134,3 +134,20 @@ describe DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Exercises the token counting and truncation entry points of the
# bge-large-en tokenizer.
describe DiscourseAi::Tokenizer::BgeLargeEnTokenizer do
  describe "#size" do
    context "returns a token count" do
      it "for a sentence with punctuation and capitalization and numbers" do
        token_count = described_class.size("Hello, World! 123")

        expect(token_count).to eq(7)
      end
    end
  end

  describe "#truncate" do
    it "truncates a sentence" do
      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"

      # A limit of 3 is expected to keep only the first two words.
      truncated = described_class.truncate(sentence, 3)

      expect(truncated).to eq("foo bar")
    end
  end
end
|
||||||
|
|
|
@ -17,3 +17,7 @@ Licensed under LLAMA 2 COMMUNITY LICENSE AGREEMENT
|
||||||
## multilingual-e5-large
|
## multilingual-e5-large
|
||||||
|
|
||||||
Licensed under MIT License
|
Licensed under MIT License
|
||||||
|
|
||||||
|
## bge-large-en
|
||||||
|
|
||||||
|
Licensed under MIT License
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue