From e22194f321014d865c52402bc0175226356ad014 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Mon, 13 May 2024 15:54:42 -0300 Subject: [PATCH] HACK: Llama3 support for summarization/AI helper. (#616) There are still some limitations to which models we can support with the `LlmModel` class. This will enable support for Llama3 while we sort those out. --- .../discourse_ai/admin/ai_llms_controller.rb | 3 +- config/locales/client.en.yml | 1 + lib/completions/dialects/dialect.rb | 8 ++- .../dialects/open_ai_compatible.rb | 55 +++++++++++++++++++ lib/completions/endpoints/anthropic.rb | 7 +-- lib/completions/endpoints/aws_bedrock.rb | 11 ++-- lib/completions/endpoints/base.rb | 6 +- lib/completions/endpoints/canned_response.rb | 2 +- lib/completions/endpoints/cohere.rb | 8 +-- lib/completions/endpoints/fake.rb | 4 +- lib/completions/endpoints/gemini.rb | 5 +- lib/completions/endpoints/hugging_face.rb | 8 +-- lib/completions/endpoints/ollama.rb | 4 +- lib/completions/endpoints/open_ai.rb | 13 +---- lib/completions/endpoints/vllm.rb | 7 +-- lib/completions/llm.rb | 10 ++-- lib/configuration/llm_enumerator.rb | 3 +- lib/configuration/llm_validator.rb | 14 +++-- lib/summarization/entry_point.rb | 15 +++++ lib/summarization/models/custom_llm.rb | 41 ++++++++++++++ lib/summarization/models/llama2.rb | 25 --------- .../models/llama2_fine_tuned_orca_style.rb | 13 ----- 22 files changed, 157 insertions(+), 106 deletions(-) create mode 100644 lib/completions/dialects/open_ai_compatible.rb create mode 100644 lib/summarization/models/custom_llm.rb delete mode 100644 lib/summarization/models/llama2.rb delete mode 100644 lib/summarization/models/llama2_fine_tuned_orca_style.rb diff --git a/app/controllers/discourse_ai/admin/ai_llms_controller.rb b/app/controllers/discourse_ai/admin/ai_llms_controller.rb index 6e57f841..5e9859ce 100644 --- a/app/controllers/discourse_ai/admin/ai_llms_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_llms_controller.rb @@ -31,7 +31,8 @@ module DiscourseAi end def create - if llm_model = LlmModel.new(ai_llm_params).save + llm_model = LlmModel.new(ai_llm_params) + if llm_model.save render json: { ai_persona: llm_model }, status: :created else render_json_error llm_model diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 044d00f3..fe95f782 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -219,6 +219,7 @@ en: open_ai: "OpenAI" google: "Google" azure: "Azure" + ollama: "Ollama" related_topics: title: "Related Topics" diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 13275ad4..af26ab58 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -16,6 +16,7 @@ module DiscourseAi DiscourseAi::Completions::Dialects::Mistral, DiscourseAi::Completions::Dialects::Claude, DiscourseAi::Completions::Dialects::Command, + DiscourseAi::Completions::Dialects::OpenAiCompatible, ] end @@ -24,14 +25,15 @@ module DiscourseAi end def dialect_for(model_name) - dialects = all_dialects - if Rails.env.test? || Rails.env.development? - dialects << DiscourseAi::Completions::Dialects::Fake + dialects = [DiscourseAi::Completions::Dialects::Fake] end + dialects = dialects.concat(all_dialects) + dialect = dialects.find { |d| d.can_translate?(model_name) } raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect + dialect end diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb new file mode 100644 index 00000000..55d4e596 --- /dev/null +++ b/lib/completions/dialects/open_ai_compatible.rb @@ -0,0 +1,55 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class OpenAiCompatible < Dialect + class << self + def can_translate?(_model_name) + true + end + + def tokenizer + DiscourseAi::Tokenizer::Llama3Tokenizer + end + end + + def tools + @tools ||= tools_dialect.translated_tools + end + + def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? + + 32_000 + end + + private + + def system_msg(msg) + { role: "system", content: msg[:content] } + end + + def model_msg(msg) + { role: "assistant", content: msg[:content] } + end + + def tool_call_msg(msg) + tools_dialect.from_raw_tool_call(msg) + end + + def tool_msg(msg) + tools_dialect.from_raw_tool(msg) + end + + def user_msg(msg) + content = +"" + content << "#{msg[:id]}: " if msg[:id] + content << msg[:content] + + { role: "user", content: content } + end + end + end + end +end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index ee9d4f17..a8fdbc04 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -5,11 +5,8 @@ module DiscourseAi module Endpoints class Anthropic < Base class << self - def can_contact?(endpoint_name, model_name) - endpoint_name == "anthropic" && - %w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?( - model_name, - ) + def can_contact?(endpoint_name) + endpoint_name == "anthropic" end def dependant_setting_names diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 9d15081d..daa6236e 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -7,21 +7,18 @@ module DiscourseAi module Endpoints class AwsBedrock < Base class << self - def can_contact?(endpoint_name, model_name) - endpoint_name == "aws_bedrock" && - %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?( - model_name, - ) + def can_contact?(endpoint_name) + endpoint_name == "aws_bedrock" end def dependant_setting_names %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region] end - def correctly_configured?(model) + def correctly_configured?(_model) SiteSetting.ai_bedrock_access_key_id.present? && SiteSetting.ai_bedrock_secret_access_key.present? && - SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model) + SiteSetting.ai_bedrock_region.present? end def endpoint_name(model_name) diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 9f34aa31..eede850f 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -8,7 +8,7 @@ module DiscourseAi TIMEOUT = 60 class << self - def endpoint_for(provider_name, model_name) + def endpoint_for(provider_name) endpoints = [ DiscourseAi::Completions::Endpoints::AwsBedrock, DiscourseAi::Completions::Endpoints::OpenAi, @@ -26,7 +26,7 @@ module DiscourseAi end endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek| - ek.can_contact?(provider_name, model_name) + ek.can_contact?(provider_name) end end @@ -55,7 +55,7 @@ module DiscourseAi raise NotImplementedError end - def can_contact?(_endpoint_name, _model_name) + def can_contact?(_endpoint_name) raise NotImplementedError end end diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index 1f869608..ab04961e 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -6,7 +6,7 @@ module DiscourseAi class CannedResponse CANNED_RESPONSE_ERROR = Class.new(StandardError) - def self.can_contact?(_, _) + def self.can_contact?(_) Rails.env.test? end diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb index 47190742..f300dc64 100644 --- a/lib/completions/endpoints/cohere.rb +++ b/lib/completions/endpoints/cohere.rb @@ -5,17 +5,15 @@ module DiscourseAi module Endpoints class Cohere < Base class << self - def can_contact?(endpoint_name, model_name) - return false unless endpoint_name == "cohere" - - %w[command-light command command-r command-r-plus].include?(model_name) + def can_contact?(endpoint_name) + endpoint_name == "cohere" end def dependant_setting_names %w[ai_cohere_api_key] end - def correctly_configured?(model_name) + def correctly_configured?(_model_name) SiteSetting.ai_cohere_api_key.present? end diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index f322bb25..982d4242 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -5,8 +5,8 @@ module DiscourseAi module Endpoints class Fake < Base class << self - def can_contact?(_endpoint_name, model_name) - model_name == "fake" + def can_contact?(endpoint_name) + endpoint_name == "fake" end def correctly_configured?(_model_name) diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 55df3e18..8db86033 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -5,9 +5,8 @@ module DiscourseAi module Endpoints class Gemini < Base class << self - def can_contact?(endpoint_name, model_name) - return false unless endpoint_name == "google" - %w[gemini-pro gemini-1.5-pro].include?(model_name) + def can_contact?(endpoint_name) + endpoint_name == "google" end def dependant_setting_names diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb index d6237c05..32d5605e 100644 --- a/lib/completions/endpoints/hugging_face.rb +++ b/lib/completions/endpoints/hugging_face.rb @@ -5,12 +5,8 @@ module DiscourseAi module Endpoints class HuggingFace < Base class << self - def can_contact?(endpoint_name, model_name) - return false unless endpoint_name == "hugging_face" - - %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( - model_name, - ) + def can_contact?(endpoint_name) + endpoint_name == "hugging_face" end def dependant_setting_names diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb index 0fd748d4..97ad5f16 100644 --- a/lib/completions/endpoints/ollama.rb +++ b/lib/completions/endpoints/ollama.rb @@ -5,8 +5,8 @@ module DiscourseAi module Endpoints class Ollama < Base class << self - def can_contact?(endpoint_name, model_name) - endpoint_name == "ollama" && %w[mistral].include?(model_name) + def can_contact?(endpoint_name) + endpoint_name == "ollama" end def dependant_setting_names diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 2ccd817e..9aa00ea1 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -5,17 +5,8 @@ module DiscourseAi module Endpoints class OpenAi < Base class << self - def can_contact?(endpoint_name, model_name) - return false unless endpoint_name == "open_ai" - - %w[ - gpt-3.5-turbo - gpt-4 - gpt-3.5-turbo-16k - gpt-4-32k - gpt-4-turbo - gpt-4-vision-preview - ].include?(model_name) + def can_contact?(endpoint_name) + endpoint_name == "open_ai" end def dependant_setting_names diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 7db1452d..4c808a6a 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -5,11 +5,8 @@ module DiscourseAi module Endpoints class Vllm < Base class << self - def can_contact?(endpoint_name, model_name) - endpoint_name == "vllm" && - %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( - model_name, - ) + def can_contact?(endpoint_name) + endpoint_name == "vllm" end def dependant_setting_names diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index a2f70172..d5548890 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -19,7 +19,9 @@ module DiscourseAi class << self def provider_names - %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure] + providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure] + providers << "ollama" if Rails.env.development? + providers end def tokenizer_names @@ -120,11 +122,7 @@ module DiscourseAi opts = {} opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model - gateway_klass = - DiscourseAi::Completions::Endpoints::Base.endpoint_for( - provider_name, - model_name_without_prov, - ) + gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name) new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts) end diff --git a/lib/configuration/llm_enumerator.rb b/lib/configuration/llm_enumerator.rb index fd870b57..d0c065af 100644 --- a/lib/configuration/llm_enumerator.rb +++ b/lib/configuration/llm_enumerator.rb @@ -13,8 +13,7 @@ module DiscourseAi begin llm_models = DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models| - endpoint = - DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first) + endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s) models.map do |model_name| { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" } diff --git a/lib/configuration/llm_validator.rb b/lib/configuration/llm_validator.rb index 291bcca3..8391f085 100644 --- a/lib/configuration/llm_validator.rb +++ b/lib/configuration/llm_validator.rb @@ -14,15 +14,17 @@ module DiscourseAi end provider_and_model_name = val.split(":") - provider_name = provider_and_model_name.first model_name_without_prov = provider_and_model_name[1..].join + is_custom_model = provider_name == "custom" - endpoint = - DiscourseAi::Completions::Endpoints::Base.endpoint_for( - provider_name, - model_name_without_prov, - ) + if is_custom_model + llm_model = LlmModel.find(model_name_without_prov) + provider_name = llm_model.provider + model_name_without_prov = llm_model.name + end + + endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name) return false if endpoint.nil? diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb index 37c72725..8e4a18c1 100644 --- a/lib/summarization/entry_point.rb +++ b/lib/summarization/entry_point.rb @@ -50,9 +50,24 @@ module DiscourseAi max_tokens: 32_000, ) + LlmModel.all.each do |model| + foldable_models << Models::CustomLlm.new( + "custom:#{model.id}", + max_tokens: model.max_prompt_tokens, + ) + end + foldable_models.each do |model| plugin.register_summarization_strategy(Strategies::FoldContent.new(model)) end + + plugin.add_model_callback(LlmModel, :after_create) do + new_model = Models::CustomLlm.new("custom:#{self.id}", max_tokens: self.max_prompt_tokens) + + if ::Summarization::Base.find_strategy("custom:#{self.id}").nil? + plugin.register_summarization_strategy(Strategies::FoldContent.new(new_model)) + end + end end end end diff --git a/lib/summarization/models/custom_llm.rb b/lib/summarization/models/custom_llm.rb new file mode 100644 index 00000000..67798326 --- /dev/null +++ b/lib/summarization/models/custom_llm.rb @@ -0,0 +1,41 @@ +# frozen_string_literal: true + +module DiscourseAi + module Summarization + module Models + class CustomLlm < Base + def display_name + custom_llm.display_name + end + + def correctly_configured? + if Rails.env.development? + SiteSetting.ai_ollama_endpoint.present? + else + SiteSetting.ai_hugging_face_api_url.present? || + SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present? + end + end + + def configuration_hint + I18n.t( + "discourse_ai.summarization.configuration_hint", + count: 1, + setting: "ai_hugging_face_api_url", + ) + end + + def model + model_name + end + + private + + def custom_llm + id = model.split(":").last + @llm ||= LlmModel.find_by(id: id) + end + end + end + end +end diff --git a/lib/summarization/models/llama2.rb b/lib/summarization/models/llama2.rb deleted file mode 100644 index 4942ae5c..00000000 --- a/lib/summarization/models/llama2.rb +++ /dev/null @@ -1,25 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class Llama2 < Base - def display_name - "Llama2's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}" - end - - def correctly_configured? - SiteSetting.ai_hugging_face_api_url.present? - end - - def configuration_hint - I18n.t( - "discourse_ai.summarization.configuration_hint", - count: 1, - setting: "ai_hugging_face_api_url", - ) - end - end - end - end -end diff --git a/lib/summarization/models/llama2_fine_tuned_orca_style.rb b/lib/summarization/models/llama2_fine_tuned_orca_style.rb deleted file mode 100644 index 81ff6bda..00000000 --- a/lib/summarization/models/llama2_fine_tuned_orca_style.rb +++ /dev/null @@ -1,13 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Summarization - module Models - class Llama2FineTunedOrcaStyle < Llama2 - def display_name - "Llama2FineTunedOrcaStyle's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}" - end - end - end - end -end