Refinements to embeddings and tokenizers (#61)

* Refinements to embeddings and tokenizers

* lint

* Truncate with tokenizers for summary

* fix
This commit is contained in:
Rafael dos Santos Silva 2023-05-15 15:10:42 -03:00 committed by GitHub
parent 93d9d9ea91
commit 3c9513e754
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 165 additions and 38 deletions

View File

@ -152,6 +152,10 @@ plugins:
- msmarco-distilbert-base-v4 - msmarco-distilbert-base-v4
- msmarco-distilbert-base-tas-b - msmarco-distilbert-base-tas-b
- text-embedding-ada-002 - text-embedding-ada-002
ai_embeddings_semantic_related_instruction:
default: "Represent the Discourse topic for retrieving relevant topics:"
hidden: true
client: false
ai_embeddings_generate_for_pms: false ai_embeddings_generate_for_pms: false
ai_embeddings_semantic_related_topics_enabled: false ai_embeddings_semantic_related_topics_enabled: false
ai_embeddings_semantic_related_topics: 5 ai_embeddings_semantic_related_topics: 5

View File

@ -40,6 +40,10 @@ module DiscourseAi
&blk &blk
) )
end end
# Tokenizes `text` for this bot by delegating to
# DiscourseAi::Tokenizer::AnthropicTokenizer.tokenize and returning its result.
#
# @param text [String] the text to tokenize
def tokenize(text)
DiscourseAi::Tokenizer::AnthropicTokenizer.tokenize(text)
end
end end
end end
end end

View File

@ -81,7 +81,7 @@ module DiscourseAi
conversation.reduce([]) do |memo, (raw, username)| conversation.reduce([]) do |memo, (raw, username)|
break(memo) if total_prompt_tokens >= prompt_limit break(memo) if total_prompt_tokens >= prompt_limit
tokens = DiscourseAi::Tokenizer.tokenize(raw) tokens = tokenize(raw)
if tokens.length + total_prompt_tokens > prompt_limit if tokens.length + total_prompt_tokens > prompt_limit
tokens = tokens[0...(prompt_limit - total_prompt_tokens)] tokens = tokens[0...(prompt_limit - total_prompt_tokens)]
@ -139,6 +139,10 @@ module DiscourseAi
user_ids: bot_reply_post.topic.allowed_user_ids, user_ids: bot_reply_post.topic.allowed_user_ids,
) )
end end
# Abstract tokenization hook for the base bot.
#
# The conversation-trimming code in this class calls `tokenize(raw)` and then
# uses the result's `length` and range slicing, so concrete bots must override
# this method and return a list of tokens for `text`.
#
# @param text [String] the text to tokenize
# @raise [NotImplementedError] always; this base implementation is a placeholder
def tokenize(text)
  # NotImplementedError is Ruby's built-in exception for unimplemented methods.
  # A bare `NotImplemented` constant does not exist and would surface as a
  # NameError instead of signalling "subclass must override".
  raise NotImplementedError
end
end end
end end
end end

View File

@ -43,6 +43,10 @@ module DiscourseAi
&blk &blk
) )
end end
# Tokenizes `text` for this bot by delegating to
# DiscourseAi::Tokenizer::OpenAiTokenizer.tokenize and returning its result.
#
# @param text [String] the text to tokenize
def tokenize(text)
DiscourseAi::Tokenizer::OpenAiTokenizer.tokenize(text)
end
end end
end end
end end

View File

@ -14,8 +14,9 @@ module DiscourseAi
%i[symmetric], %i[symmetric],
"discourse", "discourse",
], ],
"msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
"msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"], "msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
"msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
"instructor-xl" => [768, 512, %i[cosine], %i[symmetric asymmetric], "discourse"],
"text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"], "text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
} }
@ -66,16 +67,27 @@ module DiscourseAi
private private
def discourse_embeddings(input) def discourse_embeddings(input)
# Cut the input down with the BERT tokenizer before it is sent to the service.
# `max_sequence_lenght` keeps its existing spelling because its definition is
# outside this view.
truncated_input = DiscourseAi::Tokenizer::BertTokenizer.truncate(input, max_sequence_lenght)
# Instructor models receive an [instruction, text] pair; every other model
# receives the truncated text itself. Assigning in both branches guarantees
# `instructed_input` is never nil when it is passed to the classifier below.
instructed_input =
  if name.start_with?("instructor")
    [SiteSetting.ai_embeddings_semantic_related_instruction, truncated_input]
  else
    truncated_input
  end
DiscourseAi::Inference::DiscourseClassifier.perform!( DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name.to_s, name.to_s,
input, instructed_input,
SiteSetting.ai_embeddings_discourse_service_api_key, SiteSetting.ai_embeddings_discourse_service_api_key,
) )
end end
def openai_embeddings(input) def openai_embeddings(input)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(input) truncated_input =
DiscourseAi::Tokenizer::OpenAiTokenizer.truncate(input, max_sequence_lenght)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(truncated_input)
response[:data].first[:embedding] response[:data].first[:embedding]
end end
end end

View File

@ -11,7 +11,7 @@ module DiscourseAi
def summarize!(content_since) def summarize!(content_since)
content = get_content(content_since) content = get_content(content_since)
send("#{summarization_provider}_summarization", content[0..(max_length - 1)]) send("#{summarization_provider}_summarization", content)
end end
private private
@ -63,17 +63,22 @@ module DiscourseAi
end end
def discourse_summarization(content) def discourse_summarization(content)
truncated_content = DiscourseAi::Tokenizer::BertTokenizer.truncate(content, max_length)
::DiscourseAi::Inference::DiscourseClassifier.perform!( ::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_summarization_discourse_service_api_endpoint}/api/v1/classify", "#{SiteSetting.ai_summarization_discourse_service_api_endpoint}/api/v1/classify",
model, model,
content, truncated_content,
SiteSetting.ai_summarization_discourse_service_api_key, SiteSetting.ai_summarization_discourse_service_api_key,
).dig(:summary_text) ).dig(:summary_text)
end end
def openai_summarization(content) def openai_summarization(content)
truncated_content =
DiscourseAi::Tokenizer::OpenAiTokenizer.truncate(content, max_length - 50)
messages = [{ role: "system", content: <<~TEXT }] messages = [{ role: "system", content: <<~TEXT }]
Summarize the following article:\n\n#{content} Summarize the following article:\n\n#{truncated_content}
TEXT TEXT
::DiscourseAi::Inference::OpenAiCompletions.perform!(messages, model).dig( ::DiscourseAi::Inference::OpenAiCompletions.perform!(messages, model).dig(
@ -85,11 +90,14 @@ module DiscourseAi
end end
def anthropic_summarization(content) def anthropic_summarization(content)
truncated_content =
DiscourseAi::Tokenizer::AnthropicTokenizer.truncate(content, max_length - 50)
messages = messages =
"Human: Summarize the following article that is inside <input> tags. "Human: Summarize the following article that is inside <input> tags.
Plese include only the summary inside <ai> tags. Please include only the summary inside <ai> tags.
<input>##{content}</input> <input>##{truncated_content}</input>
Assistant: Assistant:
@ -107,13 +115,13 @@ module DiscourseAi
def max_length def max_length
lengths = { lengths = {
"bart-large-cnn-samsum" => 1024 * 4, "bart-large-cnn-samsum" => 1024,
"flan-t5-base-samsum" => 512 * 4, "flan-t5-base-samsum" => 512,
"long-t5-tglobal-base-16384-book-summary" => 16_384 * 4, "long-t5-tglobal-base-16384-book-summary" => 16_384,
"gpt-3.5-turbo" => 4096 * 4, "gpt-3.5-turbo" => 4096,
"gpt-4" => 8192 * 4, "gpt-4" => 8192,
"claude-v1" => 9000 * 4, "claude-v1" => 9000,
"claude-v1-100k" => 100_000 * 4, "claude-v1-100k" => 100_000,
} }
lengths[model] lengths[model]

View File

@ -60,8 +60,9 @@ module ::DiscourseAi
log.update!( log.update!(
raw_response_payload: response_body, raw_response_payload: response_body,
request_tokens: DiscourseAi::Tokenizer.size(prompt), request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
response_tokens: DiscourseAi::Tokenizer.size(parsed_response[:completion]), response_tokens:
DiscourseAi::Tokenizer::AnthropicTokenizer.size(parsed_response[:completion]),
) )
return parsed_response return parsed_response
end end
@ -97,8 +98,8 @@ module ::DiscourseAi
ensure ensure
log.update!( log.update!(
raw_response_payload: response_raw, raw_response_payload: response_raw,
request_tokens: DiscourseAi::Tokenizer.size(prompt), request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
response_tokens: DiscourseAi::Tokenizer.size(response_data), response_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(response_data),
) )
end end
end end

View File

@ -98,8 +98,9 @@ module ::DiscourseAi
ensure ensure
log.update!( log.update!(
raw_response_payload: response_raw, raw_response_payload: response_raw,
request_tokens: DiscourseAi::Tokenizer.size(extract_prompt(messages)), request_tokens:
response_tokens: DiscourseAi::Tokenizer.size(response_data), DiscourseAi::Tokenizer::OpenAiTokenizer.size(extract_prompt(messages)),
response_tokens: DiscourseAi::Tokenizer::OpenAiTokenizer.size(response_data),
) )
end end
end end

View File

@ -1,17 +1,49 @@
# frozen_string_literal: true # frozen_string_literal: true
module DiscourseAi module DiscourseAi
class Tokenizer module Tokenizer
def self.tokenizer class BasicTokenizer
@@tokenizer ||= def self.tokenizer
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bert-base-uncased.json") raise NotImplementedError
end
def self.tokenize(text)
tokenizer.encode(text).tokens
end
def self.size(text)
tokenize(text).size
end
def self.truncate(text, max_length)
tokenizer.decode(tokenizer.encode(text).ids.take(max_length))
end
end end
def self.tokenize(text) class BertTokenizer < BasicTokenizer
tokenizer.encode(text).tokens def self.tokenizer
@@tokenizer ||=
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bert-base-uncased.json")
end
end end
def self.size(text)
tokenize(text).size class AnthropicTokenizer < BasicTokenizer
def self.tokenizer
@@tokenizer ||=
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/claude-v1-tokenization.json")
end
end
class OpenAiTokenizer < BasicTokenizer
def self.tokenizer
@@tokenizer ||= Tiktoken.get_encoding("cl100k_base")
end
def self.tokenize(text)
tokenizer.encode(text)
end
def self.truncate(text, max_length)
tokenizer.decode(tokenize(text).take(max_length))
end
end end
end end
end end

View File

@ -8,6 +8,7 @@
# required_version: 2.7.0 # required_version: 2.7.0
gem "tokenizers", "0.3.2", platform: RUBY_PLATFORM gem "tokenizers", "0.3.2", platform: RUBY_PLATFORM
gem "tiktoken_ruby", "0.0.5", platform: RUBY_PLATFORM
enabled_site_setting :discourse_ai_enabled enabled_site_setting :discourse_ai_enabled

View File

@ -2,38 +2,79 @@
require "rails_helper" require "rails_helper"
describe DiscourseAi::Tokenizer do describe DiscourseAi::Tokenizer::BertTokenizer do
describe "#size" do describe "#size" do
describe "returns a token count" do describe "returns a token count" do
it "for a single word" do it "for a single word" do
expect(DiscourseAi::Tokenizer.size("hello")).to eq(3) expect(described_class.size("hello")).to eq(3)
end end
it "for a sentence" do it "for a sentence" do
expect(DiscourseAi::Tokenizer.size("hello world")).to eq(4) expect(described_class.size("hello world")).to eq(4)
end end
it "for a sentence with punctuation" do it "for a sentence with punctuation" do
expect(DiscourseAi::Tokenizer.size("hello, world!")).to eq(6) expect(described_class.size("hello, world!")).to eq(6)
end end
it "for a sentence with punctuation and capitalization" do it "for a sentence with punctuation and capitalization" do
expect(DiscourseAi::Tokenizer.size("Hello, World!")).to eq(6) expect(described_class.size("Hello, World!")).to eq(6)
end end
it "for a sentence with punctuation and capitalization and numbers" do it "for a sentence with punctuation and capitalization and numbers" do
expect(DiscourseAi::Tokenizer.size("Hello, World! 123")).to eq(7) expect(described_class.size("Hello, World! 123")).to eq(7)
end end
end end
end end
describe "#tokenizer" do describe "#tokenizer" do
it "returns a tokenizer" do it "returns a tokenizer" do
expect(DiscourseAi::Tokenizer.tokenizer).to be_a(Tokenizers::Tokenizer) expect(described_class.tokenizer).to be_a(Tokenizers::Tokenizer)
end end
it "returns the same tokenizer" do it "returns the same tokenizer" do
expect(DiscourseAi::Tokenizer.tokenizer).to eq(DiscourseAi::Tokenizer.tokenizer) expect(described_class.tokenizer).to eq(described_class.tokenizer)
end
end
describe "#truncate" do
it "truncates a sentence" do
sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
expect(described_class.truncate(sentence, 3)).to eq("foo bar")
end
end
end
describe DiscourseAi::Tokenizer::AnthropicTokenizer do
describe "#size" do
describe "returns a token count" do
it "for a sentence with punctuation and capitalization and numbers" do
expect(described_class.size("Hello, World! 123")).to eq(5)
end
end
end
describe "#truncate" do
it "truncates a sentence" do
sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
expect(described_class.truncate(sentence, 3)).to eq("foo bar baz")
end
end
end
describe DiscourseAi::Tokenizer::OpenAiTokenizer do
describe "#size" do
describe "returns a token count" do
it "for a sentence with punctuation and capitalization and numbers" do
expect(described_class.size("Hello, World! 123")).to eq(6)
end
end
end
describe "#truncate" do
it "truncates a sentence" do
sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
expect(described_class.truncate(sentence, 3)).to eq("foo bar baz")
end end
end end
end end

7
tokenizers/MIT License Normal file
View File

@ -0,0 +1,7 @@
Copyright 2022 Anthropic, PBC.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

7
tokenizers/README.md Normal file
View File

@ -0,0 +1,7 @@
## bert-base-uncased.json
Licensed under Apache License
## claude-v1-tokenization.json
Licensed under MIT License

File diff suppressed because one or more lines are too long