HACK: Llama3 support for summarization/AI helper. (#616)

There are still some limitations to which models we can support with the `LlmModel` class. This will enable support for Llama3 while we sort those out.

parent 62fc7d6ed0
commit e22194f321
@@ -31,7 +31,8 @@ module DiscourseAi
       end
 
       def create
-        if llm_model = LlmModel.new(ai_llm_params).save
+        llm_model = LlmModel.new(ai_llm_params)
+        if llm_model.save
           render json: { ai_persona: llm_model }, status: :created
         else
           render_json_error llm_model
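Why the split matters: in the old one-liner, llm_model captured save's boolean return value rather than the record, so the success branch rendered true instead of the model. A tiny standalone illustration of the Ruby semantics (FakeModel is a stand-in, not ActiveRecord):

# Assignment in a condition binds the method's return value, not the receiver.
class FakeModel
  def save
    true
  end
end

llm_model = FakeModel.new.save
p llm_model # => true — a boolean, which is what the old render would have sent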
@@ -219,6 +219,7 @@ en:
           open_ai: "OpenAI"
           google: "Google"
           azure: "Azure"
+          ollama: "Ollama"
 
       related_topics:
         title: "Related Topics"
@@ -16,6 +16,7 @@ module DiscourseAi
           DiscourseAi::Completions::Dialects::Mistral,
           DiscourseAi::Completions::Dialects::Claude,
           DiscourseAi::Completions::Dialects::Command,
+          DiscourseAi::Completions::Dialects::OpenAiCompatible,
         ]
       end
 
@@ -24,14 +25,15 @@ module DiscourseAi
       end
 
       def dialect_for(model_name)
-        dialects = all_dialects
 
         if Rails.env.test? || Rails.env.development?
-          dialects << DiscourseAi::Completions::Dialects::Fake
+          dialects = [DiscourseAi::Completions::Dialects::Fake]
         end
 
+        dialects = dialects.concat(all_dialects)
+
         dialect = dialects.find { |d| d.can_translate?(model_name) }
         raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect
 
         dialect
       end
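Because the new OpenAiCompatible dialect's can_translate? returns true unconditionally (see the new file below), it acts as a catch-all: any model name no earlier dialect claims — Llama3 included — falls through to it, which is why it sits last in all_dialects. A minimal standalone sketch of this find-first dispatch, using stand-in classes rather than the plugin's real dialects:

# Stand-in dialect classes; only the class-level predicate matters here.
class MistralDialect
  def self.can_translate?(model_name)
    model_name.include?("mistral")
  end
end

class OpenAiCompatibleDialect
  # Catch-all, mirroring OpenAiCompatible.can_translate? in the new file below.
  def self.can_translate?(_model_name)
    true
  end
end

def dialect_for(model_name)
  dialects = [MistralDialect, OpenAiCompatibleDialect] # catch-all stays last
  dialects.find { |d| d.can_translate?(model_name) } || raise("unknown model")
end

p dialect_for("mistral-7b") # => MistralDialect
p dialect_for("llama3")     # => OpenAiCompatibleDialect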
lib/completions/dialects/open_ai_compatible.rb (new file, 55 lines)
@@ -0,0 +1,55 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class OpenAiCompatible < Dialect
+        class << self
+          def can_translate?(_model_name)
+            true
+          end
+
+          def tokenizer
+            DiscourseAi::Tokenizer::Llama3Tokenizer
+          end
+        end
+
+        def tools
+          @tools ||= tools_dialect.translated_tools
+        end
+
+        def max_prompt_tokens
+          return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+
+          32_000
+        end
+
+        private
+
+        def system_msg(msg)
+          { role: "system", content: msg[:content] }
+        end
+
+        def model_msg(msg)
+          { role: "assistant", content: msg[:content] }
+        end
+
+        def tool_call_msg(msg)
+          tools_dialect.from_raw_tool_call(msg)
+        end
+
+        def tool_msg(msg)
+          tools_dialect.from_raw_tool(msg)
+        end
+
+        def user_msg(msg)
+          content = +""
+          content << "#{msg[:id]}: " if msg[:id]
+          content << msg[:content]
+
+          { role: "user", content: content }
+        end
+      end
+    end
+  end
+end
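The user_msg translation above is the only message type with extra logic: it prefixes an optional participant id onto the text. A plain-Ruby sketch of just that hash shape, detached from the Dialect base class:

def user_msg(msg)
  content = +"" # unfrozen string, mutable despite frozen_string_literal
  content << "#{msg[:id]}: " if msg[:id]
  content << msg[:content]
  { role: "user", content: content }
end

p user_msg(id: "user1", content: "Summarize this topic")
# => {:role=>"user", :content=>"user1: Summarize this topic"}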
@@ -5,11 +5,8 @@ module DiscourseAi
     module Endpoints
       class Anthropic < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "anthropic" &&
-              %w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?(
-                model_name,
-              )
+          def can_contact?(endpoint_name)
+            endpoint_name == "anthropic"
           end
 
           def dependant_setting_names
@@ -7,21 +7,18 @@ module DiscourseAi
     module Endpoints
       class AwsBedrock < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "aws_bedrock" &&
-              %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?(
-                model_name,
-              )
+          def can_contact?(endpoint_name)
+            endpoint_name == "aws_bedrock"
           end
 
           def dependant_setting_names
             %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
           end
 
-          def correctly_configured?(model)
+          def correctly_configured?(_model)
             SiteSetting.ai_bedrock_access_key_id.present? &&
               SiteSetting.ai_bedrock_secret_access_key.present? &&
-              SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model)
+              SiteSetting.ai_bedrock_region.present?
           end
 
           def endpoint_name(model_name)
@@ -8,7 +8,7 @@ module DiscourseAi
       TIMEOUT = 60
 
       class << self
-        def endpoint_for(provider_name, model_name)
+        def endpoint_for(provider_name)
           endpoints = [
             DiscourseAi::Completions::Endpoints::AwsBedrock,
             DiscourseAi::Completions::Endpoints::OpenAi,
@@ -26,7 +26,7 @@ module DiscourseAi
           end
 
           endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
-            ek.can_contact?(provider_name, model_name)
+            ek.can_contact?(provider_name)
           end
         end
 
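The lookup keeps Enumerable#detect's seldom-used ifnone argument: if no endpoint class claims the provider, the lambda fires and raises UNKNOWN_MODEL instead of letting nil leak out. A standalone sketch of the same pattern, with hash stand-ins in place of the real endpoint classes:

UNKNOWN_MODEL = Class.new(StandardError)

# Stand-ins for the Endpoints classes; only the name check matters here.
endpoints = [{ name: "anthropic" }, { name: "open_ai" }]

endpoint_for = ->(provider_name) do
  endpoints.detect(-> { raise UNKNOWN_MODEL }) { |ek| ek[:name] == provider_name }
end

p endpoint_for.call("open_ai") # => {:name=>"open_ai"}
endpoint_for.call("nonsense")  # raises UNKNOWN_MODEL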
@@ -55,7 +55,7 @@ module DiscourseAi
         raise NotImplementedError
       end
 
-      def can_contact?(_endpoint_name, _model_name)
+      def can_contact?(_endpoint_name)
         raise NotImplementedError
       end
     end
@@ -6,7 +6,7 @@ module DiscourseAi
     class CannedResponse
       CANNED_RESPONSE_ERROR = Class.new(StandardError)
 
-      def self.can_contact?(_, _)
+      def self.can_contact?(_)
         Rails.env.test?
       end
 
@@ -5,17 +5,15 @@ module DiscourseAi
     module Endpoints
       class Cohere < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "cohere"
-
-            %w[command-light command command-r command-r-plus].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "cohere"
           end
 
           def dependant_setting_names
             %w[ai_cohere_api_key]
           end
 
-          def correctly_configured?(model_name)
+          def correctly_configured?(_model_name)
             SiteSetting.ai_cohere_api_key.present?
           end
 
@@ -5,8 +5,8 @@ module DiscourseAi
     module Endpoints
       class Fake < Base
         class << self
-          def can_contact?(_endpoint_name, model_name)
-            model_name == "fake"
+          def can_contact?(endpoint_name)
+            endpoint_name == "fake"
           end
 
           def correctly_configured?(_model_name)
@@ -5,9 +5,8 @@ module DiscourseAi
     module Endpoints
       class Gemini < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "google"
-            %w[gemini-pro gemini-1.5-pro].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "google"
           end
 
           def dependant_setting_names
@@ -5,12 +5,8 @@ module DiscourseAi
     module Endpoints
       class HuggingFace < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "hugging_face"
-
-            %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-              model_name,
-            )
+          def can_contact?(endpoint_name)
+            endpoint_name == "hugging_face"
           end
 
           def dependant_setting_names
@@ -5,8 +5,8 @@ module DiscourseAi
     module Endpoints
       class Ollama < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "ollama" && %w[mistral].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "ollama"
           end
 
           def dependant_setting_names
@@ -5,17 +5,8 @@ module DiscourseAi
     module Endpoints
       class OpenAi < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            return false unless endpoint_name == "open_ai"
-
-            %w[
-              gpt-3.5-turbo
-              gpt-4
-              gpt-3.5-turbo-16k
-              gpt-4-32k
-              gpt-4-turbo
-              gpt-4-vision-preview
-            ].include?(model_name)
+          def can_contact?(endpoint_name)
+            endpoint_name == "open_ai"
           end
 
           def dependant_setting_names
@@ -5,11 +5,8 @@ module DiscourseAi
     module Endpoints
       class Vllm < Base
         class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "vllm" &&
-              %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-                model_name,
-              )
+          def can_contact?(endpoint_name)
+            endpoint_name == "vllm"
           end
 
           def dependant_setting_names
@@ -19,7 +19,9 @@ module DiscourseAi
 
       class << self
         def provider_names
-          %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+          providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+          providers << "ollama" if Rails.env.development?
+          providers
         end
 
         def tokenizer_names
@@ -120,11 +122,7 @@ module DiscourseAi
         opts = {}
         opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
 
-        gateway_klass =
-          DiscourseAi::Completions::Endpoints::Base.endpoint_for(
-            provider_name,
-            model_name_without_prov,
-          )
+        gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
 
         new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
       end
@@ -13,8 +13,7 @@ module DiscourseAi
       begin
         llm_models =
           DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
-            endpoint =
-              DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
+            endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s)
 
             models.map do |model_name|
               { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
@@ -14,15 +14,17 @@ module DiscourseAi
       end
 
       provider_and_model_name = val.split(":")
 
       provider_name = provider_and_model_name.first
       model_name_without_prov = provider_and_model_name[1..].join
+      is_custom_model = provider_name == "custom"
 
-      endpoint =
-        DiscourseAi::Completions::Endpoints::Base.endpoint_for(
-          provider_name,
-          model_name_without_prov,
-        )
+      if is_custom_model
+        llm_model = LlmModel.find(model_name_without_prov)
+        provider_name = llm_model.provider
+        model_name_without_prov = llm_model.name
+      end
+
+      endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
 
       return false if endpoint.nil?
 
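The validator now special-cases the custom provider: a value like "custom:42" carries an LlmModel database id rather than a literal model name. A standalone sketch of the parsing (the id 42 is an invented example; the LlmModel.find swap is ActiveRecord and only noted in comments):

val = "custom:42" # hypothetical value pointing at LlmModel id 42
provider_and_model_name = val.split(":")

provider_name = provider_and_model_name.first               # => "custom"
model_name_without_prov = provider_and_model_name[1..].join # => "42"
is_custom_model = provider_name == "custom"                 # => true

# When is_custom_model is true, the real code above replaces provider_name
# and model_name_without_prov with LlmModel.find(id).provider and .name
# before resolving the endpoint.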
@@ -50,9 +50,24 @@ module DiscourseAi
           max_tokens: 32_000,
         )
 
+        LlmModel.all.each do |model|
+          foldable_models << Models::CustomLlm.new(
+            "custom:#{model.id}",
+            max_tokens: model.max_prompt_tokens,
+          )
+        end
+
         foldable_models.each do |model|
           plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
         end
+
+        plugin.add_model_callback(LlmModel, :after_create) do
+          new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens)
+
+          if ::Summarization::Base.find_strategy("custom:#{self.id}").nil?
+            plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model))
+          end
+        end
       end
     end
   end
 end
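Net effect of the hunk above: every persisted LlmModel is exposed to summarization as a foldable model keyed "custom:<id>", and the after_create callback keeps that list current without a restart. A plain-Ruby sketch of the keying convention (the Struct and registry are illustrative stand-ins, not the plugin API):

LlmModelStub = Struct.new(:id, :max_prompt_tokens)

registry = {} # stands in for the registered summarization strategies

[LlmModelStub.new(1, 8_000), LlmModelStub.new(2, 32_000)].each do |model|
  registry["custom:#{model.id}"] = { max_tokens: model.max_prompt_tokens }
end

p registry.keys # => ["custom:1", "custom:2"]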
lib/summarization/models/custom_llm.rb (new file, 41 lines)
@@ -0,0 +1,41 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Summarization
+    module Models
+      class CustomLlm < Base
+        def display_name
+          custom_llm.display_name
+        end
+
+        def correctly_configured?
+          if Rails.env.development?
+            SiteSetting.ai_ollama_endpoint.present?
+          else
+            SiteSetting.ai_hugging_face_api_url.present? ||
+              SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
+          end
+        end
+
+        def configuration_hint
+          I18n.t(
+            "discourse_ai.summarization.configuration_hint",
+            count: 1,
+            setting: "ai_hugging_face_api_url",
+          )
+        end
+
+        def model
+          model_name
+        end
+
+        private
+
+        def custom_llm
+          id = model.split(":").last
+          @llm ||= LlmModel.find_by(id: id)
+        end
+      end
+    end
+  end
+end
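CustomLlm leans on the same "custom:<id>" convention to find its backing record. A minimal sketch of the id recovery in custom_llm above (string handling only; LlmModel.find_by is ActiveRecord and omitted):

model_name = "custom:42"        # assumed example
id = model_name.split(":").last # => "42"
# The real method memoizes the lookup: @llm ||= LlmModel.find_by(id: id)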
(deleted file, 25 lines)
@@ -1,25 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Summarization
-    module Models
-      class Llama2 < Base
-        def display_name
-          "Llama2's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
-        end
-
-        def correctly_configured?
-          SiteSetting.ai_hugging_face_api_url.present?
-        end
-
-        def configuration_hint
-          I18n.t(
-            "discourse_ai.summarization.configuration_hint",
-            count: 1,
-            setting: "ai_hugging_face_api_url",
-          )
-        end
-      end
-    end
-  end
-end
(deleted file, 13 lines)
@@ -1,13 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Summarization
-    module Models
-      class Llama2FineTunedOrcaStyle < Llama2
-        def display_name
-          "Llama2FineTunedOrcaStyle's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
-        end
-      end
-    end
-  end
-end