DEV: port directory structure to Zeitwerk (#319)
Prior to this change we relied on explicit loading for all files in Discourse AI. This had a few downsides: - Busywork whenever you add a file (an extra require_relative) - We were not keeping to conventions internally ... some places were OpenAI, others were OpenAi - The autoloader did not work, which led to lots of broken full-application reloads when developing. This moves all of DiscourseAI into a Zeitwerk-compatible structure. It also leaves a minimal amount of manual loading (automation - which is loading into an existing namespace that may or may not be there). To avoid needing /lib/discourse_ai/... we mount a namespace, thus we are able to keep /lib pointed at ::DiscourseAi. Various files were renamed to get around Zeitwerk rules and minimize usage of custom inflections. Though we can get custom inflections to work, it is not worth it: it would require a Discourse core patch, which means we would create a hard dependency.
This commit is contained in:
parent
0b9947771c
commit
6ddc17fd61
|
@ -11,7 +11,7 @@ module Jobs
|
||||||
|
|
||||||
return if post.uploads.none? { |u| FileHelper.is_supported_image?(u.url) }
|
return if post.uploads.none? { |u| FileHelper.is_supported_image?(u.url) }
|
||||||
|
|
||||||
DiscourseAi::PostClassificator.new(DiscourseAi::NSFW::NSFWClassification.new).classify!(post)
|
DiscourseAi::PostClassificator.new(DiscourseAi::Nsfw::Classification.new).classify!(post)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
|
@ -53,6 +53,8 @@ en:
|
||||||
ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."
|
ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."
|
||||||
ai_helper_model: "Model to use for the AI helper."
|
ai_helper_model: "Model to use for the AI helper."
|
||||||
ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
|
ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
|
||||||
|
ai_helper_automatic_chat_thread_title_delay: "Delay in minutes before the AI helper automatically sets the chat thread title."
|
||||||
|
ai_helper_automatic_chat_thread_title: "Automatically set the chat thread titles based on thread contents."
|
||||||
|
|
||||||
ai_embeddings_enabled: "Enable the embeddings module."
|
ai_embeddings_enabled: "Enable the embeddings module."
|
||||||
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
module DiscourseAi
|
||||||
|
module AiBot
|
||||||
|
module Commands
|
||||||
|
class Parameter
|
||||||
|
attr_reader :item_type, :name, :description, :type, :enum, :required
|
||||||
|
def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil)
|
||||||
|
@name = name
|
||||||
|
@description = description
|
||||||
|
@type = type
|
||||||
|
@enum = enum
|
||||||
|
@required = required
|
||||||
|
@item_type = item_type
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -27,36 +27,6 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def load_files
|
|
||||||
require_relative "jobs/regular/create_ai_reply"
|
|
||||||
require_relative "jobs/regular/update_ai_bot_pm_title"
|
|
||||||
require_relative "bot"
|
|
||||||
require_relative "anthropic_bot"
|
|
||||||
require_relative "open_ai_bot"
|
|
||||||
require_relative "commands/command"
|
|
||||||
require_relative "commands/search_command"
|
|
||||||
require_relative "commands/categories_command"
|
|
||||||
require_relative "commands/tags_command"
|
|
||||||
require_relative "commands/time_command"
|
|
||||||
require_relative "commands/summarize_command"
|
|
||||||
require_relative "commands/image_command"
|
|
||||||
require_relative "commands/google_command"
|
|
||||||
require_relative "commands/read_command"
|
|
||||||
require_relative "commands/setting_context_command"
|
|
||||||
require_relative "commands/search_settings_command"
|
|
||||||
require_relative "commands/db_schema_command"
|
|
||||||
require_relative "commands/dall_e_command"
|
|
||||||
require_relative "personas/persona"
|
|
||||||
require_relative "personas/artist"
|
|
||||||
require_relative "personas/general"
|
|
||||||
require_relative "personas/sql_helper"
|
|
||||||
require_relative "personas/settings_explorer"
|
|
||||||
require_relative "personas/researcher"
|
|
||||||
require_relative "personas/creative"
|
|
||||||
require_relative "personas/dall_e_3"
|
|
||||||
require_relative "site_settings_extension"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
plugin.on(:site_setting_changed) do |name, _old_value, _new_value|
|
plugin.on(:site_setting_changed) do |name, _old_value, _new_value|
|
||||||
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled
|
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled
|
||||||
|
@ -76,7 +46,7 @@ module DiscourseAi
|
||||||
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
|
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
|
||||||
end,
|
end,
|
||||||
) do
|
) do
|
||||||
Personas
|
DiscourseAi::AiBot::Personas
|
||||||
.all(user: scope.user)
|
.all(user: scope.user)
|
||||||
.map do |persona|
|
.map do |persona|
|
||||||
{ id: persona.id, name: persona.name, description: persona.description }
|
{ id: persona.id, name: persona.name, description: persona.description }
|
||||||
|
@ -135,8 +105,8 @@ module DiscourseAi
|
||||||
post.topic.custom_fields[REQUIRE_TITLE_UPDATE] = true
|
post.topic.custom_fields[REQUIRE_TITLE_UPDATE] = true
|
||||||
post.topic.save_custom_fields
|
post.topic.save_custom_fields
|
||||||
end
|
end
|
||||||
Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_id)
|
::Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_id)
|
||||||
Jobs.enqueue_in(
|
::Jobs.enqueue_in(
|
||||||
5.minutes,
|
5.minutes,
|
||||||
:update_ai_bot_pm_title,
|
:update_ai_bot_pm_title,
|
||||||
post_id: post.id,
|
post_id: post.id,
|
|
@ -0,0 +1,46 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module AiBot
|
||||||
|
module Personas
|
||||||
|
def self.system_personas
|
||||||
|
@system_personas ||= {
|
||||||
|
Personas::General => -1,
|
||||||
|
Personas::SqlHelper => -2,
|
||||||
|
Personas::Artist => -3,
|
||||||
|
Personas::SettingsExplorer => -4,
|
||||||
|
Personas::Researcher => -5,
|
||||||
|
Personas::Creative => -6,
|
||||||
|
Personas::DallE3 => -7,
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.system_personas_by_id
|
||||||
|
@system_personas_by_id ||= system_personas.invert
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.all(user:)
|
||||||
|
# this needs to be dynamic cause site settings may change
|
||||||
|
all_available_commands = Persona.all_available_commands
|
||||||
|
|
||||||
|
AiPersona.all_personas.filter do |persona|
|
||||||
|
next false if !user.in_any_groups?(persona.allowed_group_ids)
|
||||||
|
|
||||||
|
if persona.system
|
||||||
|
instance = persona.new
|
||||||
|
(
|
||||||
|
instance.required_commands == [] ||
|
||||||
|
(instance.required_commands - all_available_commands).empty?
|
||||||
|
)
|
||||||
|
else
|
||||||
|
true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.find_by(id: nil, name: nil, user:)
|
||||||
|
all(user: user).find { |persona| persona.id == id || persona.name == name }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -3,46 +3,6 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module AiBot
|
module AiBot
|
||||||
module Personas
|
module Personas
|
||||||
def self.system_personas
|
|
||||||
@system_personas ||= {
|
|
||||||
Personas::General => -1,
|
|
||||||
Personas::SqlHelper => -2,
|
|
||||||
Personas::Artist => -3,
|
|
||||||
Personas::SettingsExplorer => -4,
|
|
||||||
Personas::Researcher => -5,
|
|
||||||
Personas::Creative => -6,
|
|
||||||
Personas::DallE3 => -7,
|
|
||||||
}
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.system_personas_by_id
|
|
||||||
@system_personas_by_id ||= system_personas.invert
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.all(user:)
|
|
||||||
personas =
|
|
||||||
AiPersona.all_personas.filter { |persona| user.in_any_groups?(persona.allowed_group_ids) }
|
|
||||||
|
|
||||||
# this needs to be dynamic cause site settings may change
|
|
||||||
all_available_commands = Persona.all_available_commands
|
|
||||||
|
|
||||||
personas.filter do |persona|
|
|
||||||
if persona.system
|
|
||||||
instance = persona.new
|
|
||||||
(
|
|
||||||
instance.required_commands == [] ||
|
|
||||||
(instance.required_commands - all_available_commands).empty?
|
|
||||||
)
|
|
||||||
else
|
|
||||||
true
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.find_by(id: nil, name: nil, user:)
|
|
||||||
all(user: user).find { |persona| persona.id == id || persona.name == name }
|
|
||||||
end
|
|
||||||
|
|
||||||
class Persona
|
class Persona
|
||||||
def self.name
|
def self.name
|
||||||
I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.name")
|
I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.name")
|
|
@ -25,7 +25,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def generate_and_send_prompt(completion_prompt, input, user)
|
def generate_and_send_prompt(completion_prompt, input, user)
|
||||||
llm = DiscourseAi::Completions::LLM.proxy(SiteSetting.ai_helper_model)
|
llm = DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model)
|
||||||
|
|
||||||
generic_prompt = completion_prompt.messages_with_input(input)
|
generic_prompt = completion_prompt.messages_with_input(input)
|
||||||
|
|
|
@ -2,15 +2,6 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module AiHelper
|
module AiHelper
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def load_files
|
|
||||||
require_relative "chat_thread_titler"
|
|
||||||
require_relative "jobs/regular/generate_chat_thread_title"
|
|
||||||
require_relative "assistant"
|
|
||||||
require_relative "painter"
|
|
||||||
require_relative "semantic_categorizer"
|
|
||||||
require_relative "topic_helper"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
plugin.register_seedfu_fixtures(
|
plugin.register_seedfu_fixtures(
|
||||||
Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_helper"),
|
Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_helper"),
|
||||||
|
@ -22,7 +13,7 @@ module DiscourseAi
|
||||||
plugin.on(:chat_thread_created) do |thread|
|
plugin.on(:chat_thread_created) do |thread|
|
||||||
next unless SiteSetting.composer_ai_helper_enabled
|
next unless SiteSetting.composer_ai_helper_enabled
|
||||||
next unless SiteSetting.ai_helper_automatic_chat_thread_title
|
next unless SiteSetting.ai_helper_automatic_chat_thread_title
|
||||||
Jobs.enqueue_in(
|
::Jobs.enqueue_in(
|
||||||
SiteSetting.ai_helper_automatic_chat_thread_title_delay.minutes,
|
SiteSetting.ai_helper_automatic_chat_thread_title_delay.minutes,
|
||||||
:generate_chat_thread_title,
|
:generate_chat_thread_title,
|
||||||
thread_id: thread.id,
|
thread_id: thread.id,
|
|
@ -35,7 +35,7 @@ module DiscourseAi
|
||||||
You'll find the post between <input></input> XML tags.
|
You'll find the post between <input></input> XML tags.
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
DiscourseAi::Completions::LLM.proxy(SiteSetting.ai_helper_model).completion!(prompt, user)
|
DiscourseAi::Completions::Llm.proxy(SiteSetting.ai_helper_model).completion!(prompt, user)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
|
@ -3,7 +3,7 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Dialects
|
module Dialects
|
||||||
class ChatGPT
|
class ChatGpt
|
||||||
def self.can_translate?(model_name)
|
def self.can_translate?(model_name)
|
||||||
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
|
@ -13,9 +13,9 @@ module DiscourseAi
|
||||||
[
|
[
|
||||||
DiscourseAi::Completions::Endpoints::AwsBedrock,
|
DiscourseAi::Completions::Endpoints::AwsBedrock,
|
||||||
DiscourseAi::Completions::Endpoints::Anthropic,
|
DiscourseAi::Completions::Endpoints::Anthropic,
|
||||||
DiscourseAi::Completions::Endpoints::OpenAI,
|
DiscourseAi::Completions::Endpoints::OpenAi,
|
||||||
DiscourseAi::Completions::Endpoints::Huggingface,
|
DiscourseAi::Completions::Endpoints::HuggingFace,
|
||||||
].detect(-> { raise DiscourseAi::Completions::LLM::UNKNOWN_MODEL }) do |ek|
|
].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
|
||||||
ek.can_contact?(model_name)
|
ek.can_contact?(model_name)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Endpoints
|
module Endpoints
|
||||||
class Huggingface < Base
|
class HuggingFace < Base
|
||||||
def self.can_contact?(model_name)
|
def self.can_contact?(model_name)
|
||||||
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf].include?(model_name)
|
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Endpoints
|
module Endpoints
|
||||||
class OpenAI < Base
|
class OpenAi < Base
|
||||||
def self.can_contact?(model_name)
|
def self.can_contact?(model_name)
|
||||||
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Completions
|
|
||||||
class EntryPoint
|
|
||||||
def load_files
|
|
||||||
require_relative "dialects/chat_gpt"
|
|
||||||
require_relative "dialects/llama2_classic"
|
|
||||||
require_relative "dialects/orca_style"
|
|
||||||
require_relative "dialects/claude"
|
|
||||||
|
|
||||||
require_relative "endpoints/canned_response"
|
|
||||||
require_relative "endpoints/base"
|
|
||||||
require_relative "endpoints/anthropic"
|
|
||||||
require_relative "endpoints/aws_bedrock"
|
|
||||||
require_relative "endpoints/open_ai"
|
|
||||||
require_relative "endpoints/hugging_face"
|
|
||||||
|
|
||||||
require_relative "llm"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inject_into(_)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
|
@ -14,7 +14,7 @@
|
||||||
#
|
#
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
class LLM
|
class Llm
|
||||||
UNKNOWN_MODEL = Class.new(StandardError)
|
UNKNOWN_MODEL = Class.new(StandardError)
|
||||||
|
|
||||||
def self.with_prepared_responses(responses)
|
def self.with_prepared_responses(responses)
|
||||||
|
@ -27,7 +27,7 @@ module DiscourseAi
|
||||||
dialects = [
|
dialects = [
|
||||||
DiscourseAi::Completions::Dialects::Claude,
|
DiscourseAi::Completions::Dialects::Claude,
|
||||||
DiscourseAi::Completions::Dialects::Llama2Classic,
|
DiscourseAi::Completions::Dialects::Llama2Classic,
|
||||||
DiscourseAi::Completions::Dialects::ChatGPT,
|
DiscourseAi::Completions::Dialects::ChatGpt,
|
||||||
DiscourseAi::Completions::Dialects::OrcaStyle,
|
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -3,21 +3,6 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Embeddings
|
module Embeddings
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def load_files
|
|
||||||
require_relative "vector_representations/base"
|
|
||||||
require_relative "vector_representations/all_mpnet_base_v2"
|
|
||||||
require_relative "vector_representations/text_embedding_ada_002"
|
|
||||||
require_relative "vector_representations/multilingual_e5_large"
|
|
||||||
require_relative "vector_representations/bge_large_en"
|
|
||||||
require_relative "strategies/truncation"
|
|
||||||
require_relative "jobs/regular/generate_embeddings"
|
|
||||||
require_relative "jobs/scheduled/embeddings_backfill"
|
|
||||||
require_relative "semantic_related"
|
|
||||||
require_relative "semantic_topic_query"
|
|
||||||
|
|
||||||
require_relative "semantic_search"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
# Include random topics in the suggested list *only* if there are no related topics.
|
# Include random topics in the suggested list *only* if there are no related topics.
|
||||||
plugin.register_modifier(
|
plugin.register_modifier(
|
|
@ -110,7 +110,7 @@ module DiscourseAi
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response =
|
llm_response =
|
||||||
DiscourseAi::Completions::LLM.proxy(
|
DiscourseAi::Completions::Llm.proxy(
|
||||||
SiteSetting.ai_embeddings_semantic_search_hyde_model,
|
SiteSetting.ai_embeddings_semantic_search_hyde_model,
|
||||||
).completion!(prompt, @guardian.user)
|
).completion!(prompt, @guardian.user)
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module NSFW
|
module Nsfw
|
||||||
class NSFWClassification
|
class Classification
|
||||||
def type
|
def type
|
||||||
:nsfw
|
:nsfw
|
||||||
end
|
end
|
|
@ -1,18 +1,13 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module NSFW
|
module Nsfw
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def load_files
|
|
||||||
require_relative "nsfw_classification"
|
|
||||||
require_relative "jobs/regular/evaluate_post_uploads"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
nsfw_detection_cb =
|
nsfw_detection_cb =
|
||||||
Proc.new do |post|
|
Proc.new do |post|
|
||||||
if SiteSetting.ai_nsfw_detection_enabled &&
|
if SiteSetting.ai_nsfw_detection_enabled &&
|
||||||
DiscourseAi::NSFW::NSFWClassification.new.can_classify?(post)
|
DiscourseAi::Nsfw::Classification.new.can_classify?(post)
|
||||||
Jobs.enqueue(:evaluate_post_uploads, post_id: post.id)
|
Jobs.enqueue(:evaluate_post_uploads, post_id: post.id)
|
||||||
end
|
end
|
||||||
end
|
end
|
|
@ -3,11 +3,6 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Sentiment
|
module Sentiment
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def load_files
|
|
||||||
require_relative "sentiment_classification"
|
|
||||||
require_relative "jobs/regular/post_sentiment_analysis"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
sentiment_analysis_cb =
|
sentiment_analysis_cb =
|
||||||
Proc.new do |post|
|
Proc.new do |post|
|
|
@ -1,103 +0,0 @@
|
||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module DiscourseAi
|
|
||||||
module Tokenizer
|
|
||||||
class BasicTokenizer
|
|
||||||
class << self
|
|
||||||
def tokenizer
|
|
||||||
raise NotImplementedError
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenize(text)
|
|
||||||
tokenizer.encode(text).tokens
|
|
||||||
end
|
|
||||||
|
|
||||||
def size(text)
|
|
||||||
tokenize(text).size
|
|
||||||
end
|
|
||||||
|
|
||||||
def truncate(text, max_length)
|
|
||||||
# Fast track the common case where the text is already short enough.
|
|
||||||
return text if text.size < max_length
|
|
||||||
|
|
||||||
tokenizer.decode(tokenizer.encode(text).ids.take(max_length))
|
|
||||||
end
|
|
||||||
|
|
||||||
def can_expand_tokens?(text, addition, max_length)
|
|
||||||
return true if text.size + addition.size < max_length
|
|
||||||
|
|
||||||
tokenizer.encode(text).ids.length + tokenizer.encode(addition).ids.length < max_length
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class BertTokenizer < BasicTokenizer
|
|
||||||
def self.tokenizer
|
|
||||||
@@tokenizer ||=
|
|
||||||
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bert-base-uncased.json")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class AnthropicTokenizer < BasicTokenizer
|
|
||||||
def self.tokenizer
|
|
||||||
@@tokenizer ||=
|
|
||||||
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/claude-v1-tokenization.json")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class AllMpnetBaseV2Tokenizer < BasicTokenizer
|
|
||||||
def self.tokenizer
|
|
||||||
@@tokenizer ||=
|
|
||||||
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/all-mpnet-base-v2.json")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class Llama2Tokenizer < BasicTokenizer
|
|
||||||
def self.tokenizer
|
|
||||||
@@tokenizer ||=
|
|
||||||
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class MultilingualE5LargeTokenizer < BasicTokenizer
|
|
||||||
def self.tokenizer
|
|
||||||
@@tokenizer ||=
|
|
||||||
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/multilingual-e5-large.json")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class BgeLargeEnTokenizer < BasicTokenizer
|
|
||||||
def self.tokenizer
|
|
||||||
@@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-large-en.json")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class OpenAiTokenizer < BasicTokenizer
|
|
||||||
class << self
|
|
||||||
def tokenizer
|
|
||||||
@@tokenizer ||= Tiktoken.get_encoding("cl100k_base")
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenize(text)
|
|
||||||
tokenizer.encode(text)
|
|
||||||
end
|
|
||||||
|
|
||||||
def truncate(text, max_length)
|
|
||||||
# Fast track the common case where the text is already short enough.
|
|
||||||
return text if text.size < max_length
|
|
||||||
|
|
||||||
tokenizer.decode(tokenize(text).take(max_length))
|
|
||||||
rescue Tiktoken::UnicodeError
|
|
||||||
max_length = max_length - 1
|
|
||||||
retry
|
|
||||||
end
|
|
||||||
|
|
||||||
def can_expand_tokens?(text, addition, max_length)
|
|
||||||
return true if text.size + addition.size < max_length
|
|
||||||
|
|
||||||
tokenizer.encode(text).length + tokenizer.encode(addition).length < max_length
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
|
@ -3,18 +3,6 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Summarization
|
module Summarization
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def load_files
|
|
||||||
require_relative "models/base"
|
|
||||||
require_relative "models/anthropic"
|
|
||||||
require_relative "models/discourse"
|
|
||||||
require_relative "models/open_ai"
|
|
||||||
require_relative "models/llama2"
|
|
||||||
require_relative "models/llama2_fine_tuned_orca_style"
|
|
||||||
|
|
||||||
require_relative "strategies/fold_content"
|
|
||||||
require_relative "strategies/truncate_content"
|
|
||||||
end
|
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
foldable_models = [
|
foldable_models = [
|
||||||
Models::OpenAi.new("gpt-4", max_tokens: 8192),
|
Models::OpenAi.new("gpt-4", max_tokens: 8192),
|
|
@ -19,7 +19,7 @@ module DiscourseAi
|
||||||
def summarize(content, user, &on_partial_blk)
|
def summarize(content, user, &on_partial_blk)
|
||||||
opts = content.except(:contents)
|
opts = content.except(:contents)
|
||||||
|
|
||||||
llm = DiscourseAi::Completions::LLM.proxy(completion_model.model)
|
llm = DiscourseAi::Completions::Llm.proxy(completion_model.model)
|
||||||
|
|
||||||
chunks = split_into_chunks(llm.tokenizer, content[:contents])
|
chunks = split_into_chunks(llm.tokenizer, content[:contents])
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class AllMpnetBaseV2Tokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||=
|
||||||
|
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/all-mpnet-base-v2.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,12 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class AnthropicTokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||=
|
||||||
|
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/claude-v1-tokenization.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,34 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class BasicTokenizer
|
||||||
|
class << self
|
||||||
|
def tokenizer
|
||||||
|
raise NotImplementedError
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenize(text)
|
||||||
|
tokenizer.encode(text).tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
def size(text)
|
||||||
|
tokenize(text).size
|
||||||
|
end
|
||||||
|
|
||||||
|
def truncate(text, max_length)
|
||||||
|
# Fast track the common case where the text is already short enough.
|
||||||
|
return text if text.size < max_length
|
||||||
|
|
||||||
|
tokenizer.decode(tokenizer.encode(text).ids.take(max_length))
|
||||||
|
end
|
||||||
|
|
||||||
|
def can_expand_tokens?(text, addition, max_length)
|
||||||
|
return true if text.size + addition.size < max_length
|
||||||
|
|
||||||
|
tokenizer.encode(text).ids.length + tokenizer.encode(addition).ids.length < max_length
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,12 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class BertTokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||=
|
||||||
|
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bert-base-uncased.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,11 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class BgeLargeEnTokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bge-large-en.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,12 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class Llama2Tokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||=
|
||||||
|
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,12 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class MultilingualE5LargeTokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||=
|
||||||
|
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/multilingual-e5-large.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue