HACK: Llama3 support for summarization/AI helper. (#616)

There are still some limitations to which models we can support with the `LlmModel` class. This will enable support for Llama3 while we sort those out.
This commit is contained in:
Roman Rizzi 2024-05-13 15:54:42 -03:00 committed by GitHub
parent 62fc7d6ed0
commit e22194f321
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 157 additions and 106 deletions

View File

@ -31,7 +31,8 @@ module DiscourseAi
end
def create
if llm_model = LlmModel.new(ai_llm_params).save
llm_model = LlmModel.new(ai_llm_params)
if llm_model.save
render json: { ai_persona: llm_model }, status: :created
else
render_json_error llm_model

View File

@ -219,6 +219,7 @@ en:
open_ai: "OpenAI"
google: "Google"
azure: "Azure"
ollama: "Ollama"
related_topics:
title: "Related Topics"

View File

@ -16,6 +16,7 @@ module DiscourseAi
DiscourseAi::Completions::Dialects::Mistral,
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command,
DiscourseAi::Completions::Dialects::OpenAiCompatible,
]
end
@ -24,14 +25,15 @@ module DiscourseAi
end
def dialect_for(model_name)
dialects = all_dialects
if Rails.env.test? || Rails.env.development?
dialects << DiscourseAi::Completions::Dialects::Fake
dialects = [DiscourseAi::Completions::Dialects::Fake]
end
dialects = dialects.concat(all_dialects)
dialect = dialects.find { |d| d.can_translate?(model_name) }
raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect
dialect
end

View File

@ -0,0 +1,55 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class OpenAiCompatible < Dialect
class << self
def can_translate?(_model_name)
true
end
def tokenizer
DiscourseAi::Tokenizer::Llama3Tokenizer
end
end
def tools
@tools ||= tools_dialect.translated_tools
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
32_000
end
private
def system_msg(msg)
{ role: "system", content: msg[:content] }
end
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
end
def user_msg(msg)
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
{ role: "user", content: content }
end
end
end
end
end

View File

@ -5,11 +5,8 @@ module DiscourseAi
module Endpoints
class Anthropic < Base
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "anthropic" &&
%w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?(
model_name,
)
def can_contact?(endpoint_name)
endpoint_name == "anthropic"
end
def dependant_setting_names

View File

@ -7,21 +7,18 @@ module DiscourseAi
module Endpoints
class AwsBedrock < Base
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "aws_bedrock" &&
%w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?(
model_name,
)
def can_contact?(endpoint_name)
endpoint_name == "aws_bedrock"
end
def dependant_setting_names
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
end
def correctly_configured?(model)
def correctly_configured?(_model)
SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model)
SiteSetting.ai_bedrock_region.present?
end
def endpoint_name(model_name)

View File

@ -8,7 +8,7 @@ module DiscourseAi
TIMEOUT = 60
class << self
def endpoint_for(provider_name, model_name)
def endpoint_for(provider_name)
endpoints = [
DiscourseAi::Completions::Endpoints::AwsBedrock,
DiscourseAi::Completions::Endpoints::OpenAi,
@ -26,7 +26,7 @@ module DiscourseAi
end
endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
ek.can_contact?(provider_name, model_name)
ek.can_contact?(provider_name)
end
end
@ -55,7 +55,7 @@ module DiscourseAi
raise NotImplementedError
end
def can_contact?(_endpoint_name, _model_name)
def can_contact?(_endpoint_name)
raise NotImplementedError
end
end

View File

@ -6,7 +6,7 @@ module DiscourseAi
class CannedResponse
CANNED_RESPONSE_ERROR = Class.new(StandardError)
def self.can_contact?(_, _)
def self.can_contact?(_)
Rails.env.test?
end

View File

@ -5,17 +5,15 @@ module DiscourseAi
module Endpoints
class Cohere < Base
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "cohere"
%w[command-light command command-r command-r-plus].include?(model_name)
def can_contact?(endpoint_name)
endpoint_name == "cohere"
end
def dependant_setting_names
%w[ai_cohere_api_key]
end
def correctly_configured?(model_name)
def correctly_configured?(_model_name)
SiteSetting.ai_cohere_api_key.present?
end

View File

@ -5,8 +5,8 @@ module DiscourseAi
module Endpoints
class Fake < Base
class << self
def can_contact?(_endpoint_name, model_name)
model_name == "fake"
def can_contact?(endpoint_name)
endpoint_name == "fake"
end
def correctly_configured?(_model_name)

View File

@ -5,9 +5,8 @@ module DiscourseAi
module Endpoints
class Gemini < Base
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "google"
%w[gemini-pro gemini-1.5-pro].include?(model_name)
def can_contact?(endpoint_name)
endpoint_name == "google"
end
def dependant_setting_names

View File

@ -5,12 +5,8 @@ module DiscourseAi
module Endpoints
class HuggingFace < Base
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "hugging_face"
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
model_name,
)
def can_contact?(endpoint_name)
endpoint_name == "hugging_face"
end
def dependant_setting_names

View File

@ -5,8 +5,8 @@ module DiscourseAi
module Endpoints
class Ollama < Base
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "ollama" && %w[mistral].include?(model_name)
def can_contact?(endpoint_name)
endpoint_name == "ollama"
end
def dependant_setting_names

View File

@ -5,17 +5,8 @@ module DiscourseAi
module Endpoints
class OpenAi < Base
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "open_ai"
%w[
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-turbo
gpt-4-vision-preview
].include?(model_name)
def can_contact?(endpoint_name)
endpoint_name == "open_ai"
end
def dependant_setting_names

View File

@ -5,11 +5,8 @@ module DiscourseAi
module Endpoints
class Vllm < Base
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "vllm" &&
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
model_name,
)
def can_contact?(endpoint_name)
endpoint_name == "vllm"
end
def dependant_setting_names

View File

@ -19,7 +19,9 @@ module DiscourseAi
class << self
def provider_names
%w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
providers << "ollama" if Rails.env.development?
providers
end
def tokenizer_names
@ -120,11 +122,7 @@ module DiscourseAi
opts = {}
opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
gateway_klass =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(
provider_name,
model_name_without_prov,
)
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
end

View File

@ -13,8 +13,7 @@ module DiscourseAi
begin
llm_models =
DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
endpoint =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s)
models.map do |model_name|
{ name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }

View File

@ -14,15 +14,17 @@ module DiscourseAi
end
provider_and_model_name = val.split(":")
provider_name = provider_and_model_name.first
model_name_without_prov = provider_and_model_name[1..].join
is_custom_model = provider_name == "custom"
endpoint =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(
provider_name,
model_name_without_prov,
)
if is_custom_model
llm_model = LlmModel.find(model_name_without_prov)
provider_name = llm_model.provider
model_name_without_prov = llm_model.name
end
endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
return false if endpoint.nil?

View File

@ -50,9 +50,24 @@ module DiscourseAi
max_tokens: 32_000,
)
LlmModel.all.each do |model|
foldable_models << Models::CustomLlm.new(
"custom:#{model.id}",
max_tokens: model.max_prompt_tokens,
)
end
foldable_models.each do |model|
plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
end
plugin.add_model_callback(LlmModel, :after_create) do
new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens)
if ::Summarization::Base.find_strategy("custom:#{self.id}").nil?
plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model))
end
end
end
end
end

View File

@ -0,0 +1,41 @@
# frozen_string_literal: true
module DiscourseAi
module Summarization
module Models
class CustomLlm < Base
def display_name
custom_llm.display_name
end
def correctly_configured?
if Rails.env.development?
SiteSetting.ai_ollama_endpoint.present?
else
SiteSetting.ai_hugging_face_api_url.present? ||
SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
end
end
def configuration_hint
I18n.t(
"discourse_ai.summarization.configuration_hint",
count: 1,
setting: "ai_hugging_face_api_url",
)
end
def model
model_name
end
private
def custom_llm
id = model.split(":").last
@llm ||= LlmModel.find_by(id: id)
end
end
end
end
end

View File

@ -1,25 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Summarization
module Models
class Llama2 < Base
def display_name
"Llama2's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
end
def correctly_configured?
SiteSetting.ai_hugging_face_api_url.present?
end
def configuration_hint
I18n.t(
"discourse_ai.summarization.configuration_hint",
count: 1,
setting: "ai_hugging_face_api_url",
)
end
end
end
end
end

View File

@ -1,13 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Summarization
module Models
class Llama2FineTunedOrcaStyle < Llama2
def display_name
"Llama2FineTunedOrcaStyle's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
end
end
end
end
end