HACK: Llama3 support for summarization/AI helper. (#616)
There are still some limitations to which models we can support with the `LlmModel` class. This will enable support for Llama3 while we sort those out.
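The gist of the hack, as a quick sketch (illustrative calls against the hunks below, not code shipped in this commit): endpoints now route on the provider name alone, and a new catch-all OpenAiCompatible dialect claims any model name the specialized dialects don't, tokenizing with Llama3Tokenizer.

# Sketch only: these calls mirror the class methods added/changed in the diff.
DiscourseAi::Completions::Dialects::OpenAiCompatible.can_translate?("llama3")
# => true (accepts any model name)

DiscourseAi::Completions::Dialects::OpenAiCompatible.tokenizer
# => DiscourseAi::Tokenizer::Llama3Tokenizer

DiscourseAi::Completions::Endpoints::Vllm.can_contact?("vllm")
# => true (endpoints no longer keep a per-model allow-list)

Summarization follows the same pattern: each LlmModel record is wrapped in a Models::CustomLlm registered under a "custom:#{id}" key, so models created through the admin UI become selectable strategies.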
parent 62fc7d6ed0
commit e22194f321
@@ -31,7 +31,8 @@ module DiscourseAi
       end

       def create
-        if llm_model = LlmModel.new(ai_llm_params).save
+        llm_model = LlmModel.new(ai_llm_params)
+        if llm_model.save
           render json: { ai_persona: llm_model }, status: :created
         else
           render_json_error llm_model
|
@@ -219,6 +219,7 @@ en:
       open_ai: "OpenAI"
       google: "Google"
       azure: "Azure"
+      ollama: "Ollama"

     related_topics:
       title: "Related Topics"
|
@@ -16,6 +16,7 @@ module DiscourseAi
             DiscourseAi::Completions::Dialects::Mistral,
             DiscourseAi::Completions::Dialects::Claude,
             DiscourseAi::Completions::Dialects::Command,
+            DiscourseAi::Completions::Dialects::OpenAiCompatible,
           ]
         end

@@ -24,14 +25,15 @@ module DiscourseAi
         end

         def dialect_for(model_name)
-          dialects = all_dialects

           if Rails.env.test? || Rails.env.development?
-            dialects << DiscourseAi::Completions::Dialects::Fake
+            dialects = [DiscourseAi::Completions::Dialects::Fake]
           end

+          dialects = dialects.concat(all_dialects)
+
           dialect = dialects.find { |d| d.can_translate?(model_name) }
           raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect

           dialect
         end

@@ -0,0 +1,55 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class OpenAiCompatible < Dialect
+        class << self
+          def can_translate?(_model_name)
+            true
+          end
+
+          def tokenizer
+            DiscourseAi::Tokenizer::Llama3Tokenizer
+          end
+        end
+
+        def tools
+          @tools ||= tools_dialect.translated_tools
+        end
+
+        def max_prompt_tokens
+          return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+
+          32_000
+        end
+
+        private
+
+        def system_msg(msg)
+          { role: "system", content: msg[:content] }
+        end
+
+        def model_msg(msg)
+          { role: "assistant", content: msg[:content] }
+        end
+
+        def tool_call_msg(msg)
+          tools_dialect.from_raw_tool_call(msg)
+        end
+
+        def tool_msg(msg)
+          tools_dialect.from_raw_tool(msg)
+        end
+
+        def user_msg(msg)
+          content = +""
+          content << "#{msg[:id]}: " if msg[:id]
+          content << msg[:content]
+
+          { role: "user", content: content }
+        end
+      end
+    end
+  end
+end

@@ -5,11 +5,8 @@ module DiscourseAi
     module Endpoints
       class Anthropic < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "anthropic" &&
-              %w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?(
-                model_name,
-              )
+          def can_contact?(endpoint_name)
+            endpoint_name == "anthropic"
           end

           def dependant_setting_names

@@ -7,21 +7,18 @@ module DiscourseAi
     module Endpoints
       class AwsBedrock < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "aws_bedrock" &&
-              %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?(
-                model_name,
-              )
+          def can_contact?(endpoint_name)
+            endpoint_name == "aws_bedrock"
           end

           def dependant_setting_names
             %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
           end

-          def correctly_configured?(model)
+          def correctly_configured?(_model)
             SiteSetting.ai_bedrock_access_key_id.present? &&
               SiteSetting.ai_bedrock_secret_access_key.present? &&
-              SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model)
+              SiteSetting.ai_bedrock_region.present?
           end

           def endpoint_name(model_name)

@@ -8,7 +8,7 @@ module DiscourseAi
       TIMEOUT = 60

       class << self
-        def endpoint_for(provider_name, model_name)
+        def endpoint_for(provider_name)
           endpoints = [
             DiscourseAi::Completions::Endpoints::AwsBedrock,
             DiscourseAi::Completions::Endpoints::OpenAi,

@@ -26,7 +26,7 @@ module DiscourseAi
           end

           endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
-            ek.can_contact?(provider_name, model_name)
+            ek.can_contact?(provider_name)
           end
         end

@@ -55,7 +55,7 @@ module DiscourseAi
           raise NotImplementedError
         end

-        def can_contact?(_endpoint_name, _model_name)
+        def can_contact?(_endpoint_name)
           raise NotImplementedError
         end
       end

@@ -6,7 +6,7 @@ module DiscourseAi
       class CannedResponse
         CANNED_RESPONSE_ERROR = Class.new(StandardError)

-        def self.can_contact?(_, _)
+        def self.can_contact?(_)
          Rails.env.test?
        end

@@ -5,17 +5,15 @@ module DiscourseAi
     module Endpoints
       class Cohere < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "cohere"
-
-            %w[command-light command command-r command-r-plus].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "cohere"
           end

           def dependant_setting_names
             %w[ai_cohere_api_key]
           end

-          def correctly_configured?(model_name)
+          def correctly_configured?(_model_name)
             SiteSetting.ai_cohere_api_key.present?
           end

@@ -5,8 +5,8 @@ module DiscourseAi
     module Endpoints
       class Fake < Base
         class << self
-          def can_contact?(_endpoint_name, model_name)
-            model_name == "fake"
+          def can_contact?(endpoint_name)
+            endpoint_name == "fake"
           end

           def correctly_configured?(_model_name)

@@ -5,9 +5,8 @@ module DiscourseAi
     module Endpoints
       class Gemini < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "google"
-            %w[gemini-pro gemini-1.5-pro].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "google"
           end

           def dependant_setting_names

@@ -5,12 +5,8 @@ module DiscourseAi
     module Endpoints
       class HuggingFace < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "hugging_face"
-
-            %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-              model_name,
-            )
+          def can_contact?(endpoint_name)
+            endpoint_name == "hugging_face"
           end

           def dependant_setting_names

@@ -5,8 +5,8 @@ module DiscourseAi
     module Endpoints
       class Ollama < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "ollama" && %w[mistral].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "ollama"
           end

           def dependant_setting_names

@@ -5,17 +5,8 @@ module DiscourseAi
     module Endpoints
       class OpenAi < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "open_ai"
-
-            %w[
-              gpt-3.5-turbo
-              gpt-4
-              gpt-3.5-turbo-16k
-              gpt-4-32k
-              gpt-4-turbo
-              gpt-4-vision-preview
-            ].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "open_ai"
           end

           def dependant_setting_names

@@ -5,11 +5,8 @@ module DiscourseAi
     module Endpoints
       class Vllm < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "vllm" &&
-              %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-                model_name,
-              )
+          def can_contact?(endpoint_name)
+            endpoint_name == "vllm"
           end

           def dependant_setting_names

@@ -19,7 +19,9 @@ module DiscourseAi

     class << self
       def provider_names
-        %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+        providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+        providers << "ollama" if Rails.env.development?
+        providers
       end

       def tokenizer_names

@@ -120,11 +122,7 @@ module DiscourseAi
         opts = {}
         opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model

-        gateway_klass =
-          DiscourseAi::Completions::Endpoints::Base.endpoint_for(
-            provider_name,
-            model_name_without_prov,
-          )
+        gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)

         new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
       end

@@ -13,8 +13,7 @@ module DiscourseAi
       begin
         llm_models =
           DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
-            endpoint =
-              DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
+            endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s)

             models.map do |model_name|
               { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }

@@ -14,15 +14,17 @@ module DiscourseAi
       end

       provider_and_model_name = val.split(":")

       provider_name = provider_and_model_name.first
       model_name_without_prov = provider_and_model_name[1..].join
       is_custom_model = provider_name == "custom"

-      endpoint =
-        DiscourseAi::Completions::Endpoints::Base.endpoint_for(
-          provider_name,
-          model_name_without_prov,
-        )
+      if is_custom_model
+        llm_model = LlmModel.find(model_name_without_prov)
+        provider_name = llm_model.provider
+        model_name_without_prov = llm_model.name
+      end
+
+      endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)

       return false if endpoint.nil?

@@ -50,9 +50,24 @@ module DiscourseAi
             max_tokens: 32_000,
           )

+        LlmModel.all.each do |model|
+          foldable_models << Models::CustomLlm.new(
+            "custom:#{model.id}",
+            max_tokens: model.max_prompt_tokens,
+          )
+        end
+
         foldable_models.each do |model|
           plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
         end
+
+        plugin.add_model_callback(LlmModel, :after_create) do
+          new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens)
+
+          if ::Summarization::Base.find_strategy("custom:#{self.id}").nil?
+            plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model))
+          end
+        end
       end
     end
   end

@@ -0,0 +1,41 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Summarization
+    module Models
+      class CustomLlm < Base
+        def display_name
+          custom_llm.display_name
+        end
+
+        def correctly_configured?
+          if Rails.env.development?
+            SiteSetting.ai_ollama_endpoint.present?
+          else
+            SiteSetting.ai_hugging_face_api_url.present? ||
+              SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
+          end
+        end
+
+        def configuration_hint
+          I18n.t(
+            "discourse_ai.summarization.configuration_hint",
+            count: 1,
+            setting: "ai_hugging_face_api_url",
+          )
+        end
+
+        def model
+          model_name
+        end
+
+        private
+
+        def custom_llm
+          id = model.split(":").last
+          @llm ||= LlmModel.find_by(id: id)
+        end
+      end
+    end
+  end
+end

@@ -1,25 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Summarization
-    module Models
-      class Llama2 < Base
-        def display_name
-          "Llama2's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
-        end
-
-        def correctly_configured?
-          SiteSetting.ai_hugging_face_api_url.present?
-        end
-
-        def configuration_hint
-          I18n.t(
-            "discourse_ai.summarization.configuration_hint",
-            count: 1,
-            setting: "ai_hugging_face_api_url",
-          )
-        end
-      end
-    end
-  end
-end

@@ -1,13 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Summarization
-    module Models
-      class Llama2FineTunedOrcaStyle < Llama2
-        def display_name
-          "Llama2FineTunedOrcaStyle's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
-        end
-      end
-    end
-  end
-end