diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 9d3dffd6..6efa5dc9 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -240,3 +240,14 @@ en:
       anger: "Anger"
       joy: "Joy"
       disgust: "Disgust"
+
+    llm:
+      configuration:
+        disable_module_first: "You have to disable %{setting} first."
+        set_llm_first: "Set %{setting} first."
+        model_unreachable: "We couldn't get a response from this model. Check your settings first."
+      endpoints:
+        not_configured: "%{display_name} (not configured)"
+        configuration_hint:
+          one: "Make sure the `%{settings}` setting was configured."
+          other: "Make sure these settings were configured: %{settings}"
diff --git a/config/settings.yml b/config/settings.yml
index a96c2035..97ea3db2 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -170,6 +170,7 @@ discourse_ai:
   composer_ai_helper_enabled:
     default: false
    client: true
+    validator: "DiscourseAi::Configuration::LlmDependencyValidator"
  ai_helper_allowed_groups:
    client: true
    type: group_list
@@ -181,17 +182,11 @@ discourse_ai:
    default: false
    client: true
  ai_helper_model:
-    default: gpt-3.5-turbo
+    default: ""
+    allow_any: false
    type: enum
-    choices:
-      - gpt-3.5-turbo
-      - gpt-4
-      - claude-2
-      - stable-beluga-2
-      - Llama2-chat-hf
-      - gemini-pro
-      - mistralai/Mixtral-8x7B-Instruct-v0.1
-      - mistralai/Mistral-7B-Instruct-v0.2
+    enum: "DiscourseAi::Configuration::LlmEnumerator"
+    validator: "DiscourseAi::Configuration::LlmValidator"
  ai_helper_custom_prompts_allowed_groups:
    client: true
    type: group_list
@@ -257,21 +252,13 @@ discourse_ai:
  ai_embeddings_semantic_search_enabled:
    default: false
    client: true
+    validator: "DiscourseAi::Configuration::LlmDependencyValidator"
  ai_embeddings_semantic_search_hyde_model:
-    default: "gpt-3.5-turbo"
+    default: ""
    type: enum
    allow_any: false
-    choices:
-      - Llama2-*-chat-hf
-      - claude-instant-1
-      - claude-2
-      - gpt-3.5-turbo
-      - gpt-4
-      - StableBeluga2
-      - Upstage-Llama-2-*-instruct-v2
-      - gemini-pro
-      - mistralai/Mixtral-8x7B-Instruct-v0.1
-      - mistralai/Mistral-7B-Instruct-v0.2
+    enum: "DiscourseAi::Configuration::LlmEnumerator"
+    validator: "DiscourseAi::Configuration::LlmValidator"
  ai_summarization_discourse_service_api_endpoint: ""
  ai_summarization_discourse_service_api_endpoint_srv:
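Reviewer note on the `settings.yml` change above: `enum:` points a site setting at a class that supplies the dropdown options at runtime, and `validator:` at a class that vets every write. A minimal sketch of the contract core Discourse expects from such an enum class (the class name and model list below are illustrative, not part of this PR; the real enumerator is added later in this diff):

```ruby
# frozen_string_literal: true
require "enum_site_setting"

# Hypothetical enum class, shown only to illustrate the interface
# that the `enum:` key in settings.yml resolves to.
class ExampleLlmEnumerator < ::EnumSiteSetting
  def self.valid_value?(val)
    values.any? { |v| v[:value] == val }
  end

  def self.values
    # Illustrative entries; the real LlmEnumerator builds these from
    # DiscourseAi::Completions::Llm.models_by_provider.
    [
      { name: "OpenAI - gpt-4", value: "open_ai:gpt-4" },
      { name: "Anthropic - claude-2", value: "anthropic:claude-2" },
    ]
  end
end
```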
diff --git a/db/post_migrate/20240119152348_explicit_provider_backwards_compat.rb b/db/post_migrate/20240119152348_explicit_provider_backwards_compat.rb
new file mode 100644
index 00000000..c57acb09
--- /dev/null
+++ b/db/post_migrate/20240119152348_explicit_provider_backwards_compat.rb
@@ -0,0 +1,91 @@
+# frozen_string_literal: true
+
+class ExplicitProviderBackwardsCompat < ActiveRecord::Migration[7.0]
+  def up
+    backfill_settings("composer_ai_helper_enabled", "ai_helper_model")
+    backfill_settings(
+      "ai_embeddings_semantic_search_enabled",
+      "ai_embeddings_semantic_search_hyde_model",
+    )
+  end
+
+  def down
+    raise ActiveRecord::IrreversibleMigration
+  end
+
+  def backfill_settings(feature_setting_name, llm_setting_name)
+    feature_enabled =
+      DB.query_single(
+        "SELECT value FROM site_settings WHERE name = :setting_name",
+        setting_name: feature_setting_name,
+      ).first == "t"
+
+    setting_value =
+      DB
+        .query_single(
+          "SELECT value FROM site_settings WHERE name = :llm_setting",
+          llm_setting: llm_setting_name,
+        )
+        .first
+        .to_s
+
+    providers = %w[aws_bedrock anthropic open_ai hugging_face vllm google]
+    # Sanity check to make sure we won't add the provider twice.
+    return if providers.include?(setting_value.split(":").first)
+
+    # setting_value is always a String after .to_s, so test emptiness, not truthiness.
+    if setting_value.empty? && feature_enabled
+      # Enabled and using the old default (gpt-3.5-turbo)
+      DB.exec(
+        "UPDATE site_settings SET value='open_ai:gpt-3.5-turbo' WHERE name=:llm_setting",
+        llm_setting: llm_setting_name,
+      )
+    elsif setting_value.present? && !feature_enabled
+      # They'll have to choose an LLM model again before enabling the feature
+      DB.exec("DELETE FROM site_settings WHERE name=:llm_setting", llm_setting: llm_setting_name)
+    elsif setting_value.present? && feature_enabled
+      DB.exec(
+        "UPDATE site_settings SET value=:new_value WHERE name=:llm_setting",
+        llm_setting: llm_setting_name,
+        new_value: append_provider(setting_value),
+      )
+    end
+  end
+
+  def append_provider(value)
+    open_ai_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo]
+    return "open_ai:#{value}" if open_ai_models.include?(value)
+    return "google:#{value}" if value == "gemini-pro"
+
+    hf_models = %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf]
+    return "hugging_face:#{value}" if hf_models.include?(value)
+
+    # Models available through multiple providers
+    claude_models = %w[claude-instant-1 claude-2]
+    if claude_models.include?(value)
+      has_bedrock_creds =
+        DB.query_single(
+          "SELECT value FROM site_settings WHERE name = 'ai_bedrock_secret_access_key' OR name = 'ai_bedrock_access_key_id' ",
+        ).length > 0
+
+      if has_bedrock_creds
+        return "aws_bedrock:#{value}"
+      else
+        return "anthropic:#{value}"
+      end
+    end
+
+    mixtral_models = %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2]
+    if mixtral_models.include?(value)
+      vllm_configured =
+        DB.query_single(
+          "SELECT value FROM site_settings WHERE name = 'ai_vllm_endpoint_srv' OR name = 'ai_vllm_endpoint' ",
+        ).length > 0
+
+      if vllm_configured
+        "vllm:#{value}"
+      else
+        "hugging_face:#{value}"
+      end
+    end
+  end
+end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 26996bf5..a5a2f329 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -125,28 +125,39 @@ module DiscourseAi
      end

      def model(prefer_low_cost: false)
+        # HACK(roman): We'll do this until we define how we represent different providers in the bot settings
        default_model =
          case bot_user.id
          when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
-            "claude-2"
+            if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
+              "aws_bedrock:claude-2"
+            else
+              "anthropic:claude-2"
+            end
          when DiscourseAi::AiBot::EntryPoint::GPT4_ID
-            "gpt-4"
+            "open_ai:gpt-4"
          when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
-            "gpt-4-turbo"
+            "open_ai:gpt-4-turbo"
          when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
-            "gpt-3.5-turbo-16k"
+            "open_ai:gpt-3.5-turbo-16k"
          when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
-            "mistralai/Mixtral-8x7B-Instruct-v0.1"
+            if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
+                 "mistralai/Mixtral-8x7B-Instruct-v0.1",
+               )
+              "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
+            else
+              "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
+            end
          when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
-            "gemini-pro"
+            "google:gemini-pro"
          when DiscourseAi::AiBot::EntryPoint::FAKE_ID
-            "fake"
+            "fake:fake"
          else
            nil
          end

-        if %w[gpt-4 gpt-4-turbo].include?(default_model) && prefer_low_cost
-          return "gpt-3.5-turbo-16k"
+        if %w[open_ai:gpt-4 open_ai:gpt-4-turbo].include?(default_model) && prefer_low_cost
+          return "open_ai:gpt-3.5-turbo-16k"
        end

        default_model
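Everything from here on relies on the same `provider:model` convention the migration backfills: the provider is the substring before the first colon, and the rest of the string, slashes included, is the model id. A standalone sketch of that split (the helper name is mine, not from the plugin; it mirrors what `Llm.proxy` does later in this diff):

```ruby
# Illustrative helper mirroring how Llm.proxy takes identifiers apart.
def split_model_identifier(identifier)
  provider, *rest = identifier.split(":")
  # Model ids in this PR never contain ":", so joining the remainder is safe.
  [provider, rest.join]
end

split_model_identifier("open_ai:gpt-4")
# => ["open_ai", "gpt-4"]
split_model_identifier("hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1")
# => ["hugging_face", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
```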
diff --git a/lib/completions/dialects/fake.rb b/lib/completions/dialects/fake.rb
index 97c1102f..c569ee28 100644
--- a/lib/completions/dialects/fake.rb
+++ b/lib/completions/dialects/fake.rb
@@ -9,6 +9,10 @@ module DiscourseAi
        model_name == "fake"
      end

+      def translate
+        ""
+      end
+
      def tokenizer
        DiscourseAi::Tokenizer::OpenAiTokenizer
      end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 5484365b..0f766c4d 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -4,8 +4,22 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Anthropic < Base
-        def self.can_contact?(model_name)
-          %w[claude-instant-1 claude-2].include?(model_name)
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            endpoint_name == "anthropic" && %w[claude-instant-1 claude-2].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_anthropic_api_key]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_anthropic_api_key.present?
+          end
+
+          def endpoint_name(model_name)
+            "Anthropic - #{model_name}"
+          end
        end

        def normalize_model_params(model_params)
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 03782eb8..96e86d76 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -6,11 +6,24 @@ module DiscourseAi
  module Completions
    module Endpoints
      class AwsBedrock < Base
-        def self.can_contact?(model_name)
-          %w[claude-instant-1 claude-2].include?(model_name) &&
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            endpoint_name == "aws_bedrock" && %w[claude-instant-1 claude-2].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
+          end
+
+          def correctly_configured?(_model_name)
            SiteSetting.ai_bedrock_access_key_id.present? &&
-            SiteSetting.ai_bedrock_secret_access_key.present? &&
-            SiteSetting.ai_bedrock_region.present?
+              SiteSetting.ai_bedrock_secret_access_key.present? &&
+              SiteSetting.ai_bedrock_region.present?
+          end
+
+          def endpoint_name(model_name)
+            "AWS Bedrock - #{model_name}"
+          end
        end

        def normalize_model_params(model_params)
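The two endpoints above establish the class-level contract every endpoint now implements: `can_contact?` matches on the explicit provider name instead of credential sniffing, while `dependant_setting_names`, `correctly_configured?`, and `endpoint_name` feed the admin UI and validation added below. A sketch of what a hypothetical new provider would have to supply (all names here are placeholders):

```ruby
# Hypothetical endpoint illustrating the new class-level interface.
module DiscourseAi
  module Completions
    module Endpoints
      class ExampleProvider < Base
        class << self
          def can_contact?(endpoint_name, model_name)
            endpoint_name == "example" && %w[example-model].include?(model_name)
          end

          def dependant_setting_names
            %w[ai_example_api_key] # placeholder setting name
          end

          def correctly_configured?(_model_name)
            SiteSetting.ai_example_api_key.present? # assumes the setting exists
          end

          def endpoint_name(model_name)
            "Example - #{model_name}"
          end
        end
      end
    end
  end
end
```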
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index df1712a7..2b842327 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -7,29 +7,54 @@ module DiscourseAi
      CompletionFailed = Class.new(StandardError)
      TIMEOUT = 60

-      def self.endpoint_for(model_name)
-        # Order is important.
-        # Bedrock has priority over Anthropic if credentials are present.
-        endpoints = [
-          DiscourseAi::Completions::Endpoints::AwsBedrock,
-          DiscourseAi::Completions::Endpoints::Anthropic,
-          DiscourseAi::Completions::Endpoints::OpenAi,
-          DiscourseAi::Completions::Endpoints::HuggingFace,
-          DiscourseAi::Completions::Endpoints::Gemini,
-          DiscourseAi::Completions::Endpoints::Vllm,
-        ]
+      class << self
+        def endpoint_for(provider_name, model_name)
+          endpoints = [
+            DiscourseAi::Completions::Endpoints::AwsBedrock,
+            DiscourseAi::Completions::Endpoints::Anthropic,
+            DiscourseAi::Completions::Endpoints::OpenAi,
+            DiscourseAi::Completions::Endpoints::HuggingFace,
+            DiscourseAi::Completions::Endpoints::Gemini,
+            DiscourseAi::Completions::Endpoints::Vllm,
+          ]

-        if Rails.env.test? || Rails.env.development?
-          endpoints << DiscourseAi::Completions::Endpoints::Fake
+          if Rails.env.test? || Rails.env.development?
+            endpoints << DiscourseAi::Completions::Endpoints::Fake
+          end
+
+          endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
+            ek.can_contact?(provider_name, model_name)
+          end
        end

-        endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
-          ek.can_contact?(model_name)
+        def configuration_hint
+          settings = dependant_setting_names
+          I18n.t(
+            "discourse_ai.llm.endpoints.configuration_hint",
+            settings: settings.join(", "),
+            count: settings.length,
+          )
        end
-      end

-      def self.can_contact?(_model_name)
-        raise NotImplementedError
+        def display_name(model_name)
+          to_display = endpoint_name(model_name)
+
+          return to_display if correctly_configured?(model_name)
+
+          I18n.t("discourse_ai.llm.endpoints.not_configured", display_name: to_display)
+        end
+
+        def dependant_setting_names
+          raise NotImplementedError
+        end
+
+        def endpoint_name(_model_name)
+          raise NotImplementedError
+        end
+
+        def can_contact?(_endpoint_name, _model_name)
+          raise NotImplementedError
+        end
      end

      def initialize(model_name, tokenizer)
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index ab04961e..1f869608 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -6,7 +6,7 @@ module DiscourseAi
      class CannedResponse
        CANNED_RESPONSE_ERROR = Class.new(StandardError)

-        def self.can_contact?(_)
+        def self.can_contact?(_, _)
          Rails.env.test?
        end
diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb
index 1b8d9df2..45e2cea7 100644
--- a/lib/completions/endpoints/fake.rb
+++ b/lib/completions/endpoints/fake.rb
@@ -4,8 +4,18 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Fake < Base
-        def self.can_contact?(model_name)
-          model_name == "fake"
+        class << self
+          def can_contact?(_endpoint_name, model_name)
+            model_name == "fake"
+          end
+
+          def correctly_configured?(_model_name)
+            true
+          end
+
+          def endpoint_name(_model_name)
+            "Test - fake model"
+          end
        end

        STOCK_CONTENT = <<~TEXT
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index f0b7a508..2bf2dc14 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -4,8 +4,23 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Gemini < Base
-        def self.can_contact?(model_name)
-          %w[gemini-pro].include?(model_name)
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            return false unless endpoint_name == "google"
+            %w[gemini-pro].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_gemini_api_key]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_gemini_api_key.present?
+          end
+
+          def endpoint_name(model_name)
+            "Google - #{model_name}"
+          end
        end

        def default_options
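One detail worth knowing when reading `endpoint_for` above: `Enumerable#detect` accepts an optional `ifnone` callable, so the lambda raises `UNKNOWN_MODEL` only when no endpoint claims the provider/model pair. An isolated illustration of the idiom:

```ruby
# Enumerable#detect runs its `ifnone` argument when nothing matches.
UnknownModel = Class.new(StandardError)

providers = %w[aws_bedrock anthropic open_ai]

providers.detect(-> { raise UnknownModel }) { |p| p == "open_ai" }
# => "open_ai"

providers.detect(-> { raise UnknownModel }) { |p| p == "mistral" }
# => raises UnknownModel
```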
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index 4a0f2875..5542c73f 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -4,15 +4,31 @@ module DiscourseAi
  module Completions
    module Endpoints
      class HuggingFace < Base
-        def self.can_contact?(model_name)
-          %w[
-            StableBeluga2
-            Upstage-Llama-2-*-instruct-v2
-            Llama2-*-chat-hf
-            Llama2-chat-hf
-            mistralai/Mixtral-8x7B-Instruct-v0.1
-            mistralai/Mistral-7B-Instruct-v0.2
-          ].include?(model_name) && SiteSetting.ai_hugging_face_api_url.present?
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            return false unless endpoint_name == "hugging_face"
+
+            %w[
+              StableBeluga2
+              Upstage-Llama-2-*-instruct-v2
+              Llama2-*-chat-hf
+              Llama2-chat-hf
+              mistralai/Mixtral-8x7B-Instruct-v0.1
+              mistralai/Mistral-7B-Instruct-v0.2
+            ].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_hugging_face_api_url]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_hugging_face_api_url.present?
+          end
+
+          def endpoint_name(model_name)
+            "Hugging Face - #{model_name}"
+          end
        end

        def default_options
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 10b5cd91..33382b0e 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -4,15 +4,62 @@ module DiscourseAi
  module Completions
    module Endpoints
      class OpenAi < Base
-        def self.can_contact?(model_name)
-          %w[
-            gpt-3.5-turbo
-            gpt-4
-            gpt-3.5-turbo-16k
-            gpt-4-32k
-            gpt-4-0125-preview
-            gpt-4-turbo
-          ].include?(model_name)
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            return false unless endpoint_name == "open_ai"
+
+            %w[
+              gpt-3.5-turbo
+              gpt-4
+              gpt-3.5-turbo-16k
+              gpt-4-32k
+              gpt-4-0125-preview
+              gpt-4-turbo
+            ].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[
+              ai_openai_api_key
+              ai_openai_gpt4_32k_url
+              ai_openai_gpt4_turbo_url
+              ai_openai_gpt4_url
+              ai_openai_gpt35_16k_url
+              ai_openai_gpt35_url
+            ]
+          end
+
+          def correctly_configured?(model_name)
+            SiteSetting.ai_openai_api_key.present? && has_url?(model_name)
+          end
+
+          def has_url?(model)
+            url =
+              if model.include?("gpt-4")
+                if model.include?("32k")
+                  SiteSetting.ai_openai_gpt4_32k_url
+                else
+                  if model.include?("1106") || model.include?("turbo")
+                    SiteSetting.ai_openai_gpt4_turbo_url
+                  else
+                    SiteSetting.ai_openai_gpt4_url
+                  end
+                end
+              else
+                if model.include?("16k")
+                  SiteSetting.ai_openai_gpt35_16k_url
+                else
+                  SiteSetting.ai_openai_gpt35_url
+                end
+              end
+
+            url.present?
+          end
+
+          def endpoint_name(model_name)
+            "OpenAI - #{model_name}"
+          end
        end

        def normalize_model_params(model_params)
diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb
index 1ea69bbd..310bcdc9 100644
--- a/lib/completions/endpoints/vllm.rb
+++ b/lib/completions/endpoints/vllm.rb
@@ -4,10 +4,30 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Vllm < Base
-        def self.can_contact?(model_name)
-          %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-            model_name,
-          )
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            endpoint_name == "vllm" &&
+              %w[
+                mistralai/Mixtral-8x7B-Instruct-v0.1
+                mistralai/Mistral-7B-Instruct-v0.2
+                StableBeluga2
+                Upstage-Llama-2-*-instruct-v2
+                Llama2-*-chat-hf
+                Llama2-chat-hf
+              ].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_vllm_endpoint_srv ai_vllm_endpoint]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
+          end
+
+          def endpoint_name(model_name)
+            "vLLM - #{model_name}"
+          end
        end

        def normalize_model_params(model_params)
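For reference while reviewing `has_url?` in the OpenAI endpoint: each model resolves to exactly one URL setting. My reading of the mapping, derived from the conditionals above (worth double-checking against the branch logic):

```ruby
# Model name           -> URL setting consulted by has_url?
# "gpt-4-32k"          -> ai_openai_gpt4_32k_url
# "gpt-4-turbo"        -> ai_openai_gpt4_turbo_url  (matches "turbo")
# "gpt-4-0125-preview" -> ai_openai_gpt4_url        (no "32k"/"1106"/"turbo")
# "gpt-3.5-turbo-16k"  -> ai_openai_gpt35_16k_url
# "gpt-3.5-turbo"      -> ai_openai_gpt35_url       ("turbo" but not "gpt-4")
```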
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 846f57b5..758071a9 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -7,7 +7,7 @@
 # the target model and routes the completion request through the correct gateway.
 #
 # Use the .proxy method to instantiate an object.
-# It chooses the best dialect and endpoint for the model you want to interact with.
+# It chooses the correct dialect and endpoint for the model you want to interact with.
 #
 # Tests of modules that perform LLM calls can use .with_prepared_responses to return canned responses
 # instead of relying on WebMock stubs like we did in the past.
@@ -17,27 +17,62 @@ module DiscourseAi
    class Llm
      UNKNOWN_MODEL = Class.new(StandardError)

-      def self.with_prepared_responses(responses)
-        @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
+      class << self
+        def models_by_provider
+          # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
+          # However, since they use the same URL/key settings, there's no reason to duplicate them.
+          {
+            aws_bedrock: %w[claude-instant-1 claude-2],
+            anthropic: %w[claude-instant-1 claude-2],
+            vllm: %w[
+              mistralai/Mixtral-8x7B-Instruct-v0.1
+              mistralai/Mistral-7B-Instruct-v0.2
+              StableBeluga2
+              Upstage-Llama-2-*-instruct-v2
+              Llama2-*-chat-hf
+              Llama2-chat-hf
+            ],
+            hugging_face: %w[
+              mistralai/Mixtral-8x7B-Instruct-v0.1
+              mistralai/Mistral-7B-Instruct-v0.2
+              StableBeluga2
+              Upstage-Llama-2-*-instruct-v2
+              Llama2-*-chat-hf
+              Llama2-chat-hf
+            ],
+            open_ai: %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo],
+            google: %w[gemini-pro],
+          }.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
+        end

-        yield(@canned_response)
-      ensure
-        # Don't leak prepared response if there's an exception.
-        @canned_response = nil
-      end
+        def with_prepared_responses(responses)
+          @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)

-      def self.proxy(model_name)
-        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
+          yield(@canned_response)
+        ensure
+          # Don't leak prepared response if there's an exception.
+          @canned_response = nil
+        end

-        return new(dialect_klass, @canned_response, model_name) if @canned_response
+        def proxy(model_name)
+          provider_and_model_name = model_name.split(":")

-        gateway =
-          DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
-            model_name,
-            dialect_klass.tokenizer,
-          )
+          provider_name = provider_and_model_name.first
+          model_name_without_prov = provider_and_model_name[1..].join

-        new(dialect_klass, gateway, model_name)
+          dialect_klass =
+            DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
+
+          return new(dialect_klass, @canned_response, model_name) if @canned_response
+
+          gateway =
+            DiscourseAi::Completions::Endpoints::Base.endpoint_for(
+              provider_name,
+              model_name_without_prov,
+            ).new(model_name_without_prov, dialect_klass.tokenizer)
+
+          new(dialect_klass, gateway, model_name_without_prov)
+        end
      end

      def initialize(dialect_klass, gateway, model_name)
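Call sites now hand `proxy` a fully qualified id; the dialect is picked from the bare model name and the gateway from the provider. A usage sketch (the `generate` signature is the one the validator below relies on):

```ruby
llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo")

# In tests, canned responses short-circuit the gateway entirely:
DiscourseAi::Completions::Llm.with_prepared_responses(["2"]) do
  DiscourseAi::Completions::Llm.proxy("fake:fake").generate("How much is 1 + 1?", user: nil)
  # => "2"
end
```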
diff --git a/lib/configuration/llm_dependency_validator.rb b/lib/configuration/llm_dependency_validator.rb
new file mode 100644
index 00000000..56e46b59
--- /dev/null
+++ b/lib/configuration/llm_dependency_validator.rb
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Configuration
+    class LlmDependencyValidator
+      def initialize(opts = {})
+        @opts = opts
+      end
+
+      def valid_value?(val)
+        return true if val == "f"
+
+        SiteSetting.public_send(llm_dependency_setting_name).present?
+      end
+
+      def error_message
+        I18n.t("discourse_ai.llm.configuration.set_llm_first", setting: llm_dependency_setting_name)
+      end
+
+      def llm_dependency_setting_name
+        if @opts[:name] == :ai_embeddings_semantic_search_enabled
+          :ai_embeddings_semantic_search_hyde_model
+        else
+          :ai_helper_model
+        end
+      end
+    end
+  end
+end
diff --git a/lib/configuration/llm_enumerator.rb b/lib/configuration/llm_enumerator.rb
new file mode 100644
index 00000000..ab9757bc
--- /dev/null
+++ b/lib/configuration/llm_enumerator.rb
@@ -0,0 +1,25 @@
+# frozen_string_literal: true
+
+require "enum_site_setting"
+
+module DiscourseAi
+  module Configuration
+    class LlmEnumerator < ::EnumSiteSetting
+      def self.valid_value?(val)
+        true
+      end
+
+      def self.values
+        @values ||=
+          DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
+            endpoint =
+              DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
+
+            models.map do |model_name|
+              { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
+            end
+          end
+      end
+    end
+  end
+end
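The shape `LlmEnumerator.values` produces, assuming a site where only the Gemini key is configured (output is illustrative; the display names come from each endpoint's `display_name`, and the `@values ||=` memoization means the list is computed once per process):

```ruby
DiscourseAi::Configuration::LlmEnumerator.values
# => [
#   { name: "AWS Bedrock - claude-instant-1 (not configured)", value: "aws_bedrock:claude-instant-1" },
#   { name: "AWS Bedrock - claude-2 (not configured)", value: "aws_bedrock:claude-2" },
#   ...
#   { name: "Google - gemini-pro", value: "google:gemini-pro" },
# ]
```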
diff --git a/lib/configuration/llm_validator.rb b/lib/configuration/llm_validator.rb
new file mode 100644
index 00000000..291bcca3
--- /dev/null
+++ b/lib/configuration/llm_validator.rb
@@ -0,0 +1,77 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Configuration
+    class LlmValidator
+      def initialize(opts = {})
+        @opts = opts
+      end
+
+      def valid_value?(val)
+        if val == ""
+          @parent_enabled = SiteSetting.public_send(parent_module_name)
+          return !@parent_enabled
+        end
+
+        provider_and_model_name = val.split(":")
+
+        provider_name = provider_and_model_name.first
+        model_name_without_prov = provider_and_model_name[1..].join
+
+        endpoint =
+          DiscourseAi::Completions::Endpoints::Base.endpoint_for(
+            provider_name,
+            model_name_without_prov,
+          )
+
+        return false if endpoint.nil?
+
+        if !endpoint.correctly_configured?(model_name_without_prov)
+          @endpoint = endpoint
+          return false
+        end
+
+        if !can_talk_to_model?(val)
+          @unreachable = true
+          return false
+        end
+
+        true
+      end
+
+      def error_message
+        if @parent_enabled
+          return(
+            I18n.t(
+              "discourse_ai.llm.configuration.disable_module_first",
+              setting: parent_module_name,
+            )
+          )
+        end
+
+        return(I18n.t("discourse_ai.llm.configuration.model_unreachable")) if @unreachable
+
+        @endpoint&.configuration_hint
+      end
+
+      def parent_module_name
+        if @opts[:name] == :ai_embeddings_semantic_search_hyde_model
+          :ai_embeddings_semantic_search_enabled
+        else
+          :composer_ai_helper_enabled
+        end
+      end
+
+      private
+
+      def can_talk_to_model?(model_name)
+        DiscourseAi::Completions::Llm
+          .proxy(model_name)
+          .generate("How much is 1 + 1?", user: nil)
+          .present?
+      rescue StandardError
+        false
+      end
+    end
+  end
+end
diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb
index fd986bb8..1e5ae232 100644
--- a/lib/summarization/entry_point.rb
+++ b/lib/summarization/entry_point.rb
@@ -5,22 +5,45 @@ module DiscourseAi
    class EntryPoint
      def inject_into(plugin)
        foldable_models = [
-          Models::OpenAi.new("gpt-4", max_tokens: 8192),
-          Models::OpenAi.new("gpt-4-32k", max_tokens: 32_768),
-          Models::OpenAi.new("gpt-4-0125-preview", max_tokens: 100_000),
-          Models::OpenAi.new("gpt-3.5-turbo", max_tokens: 4096),
-          Models::OpenAi.new("gpt-3.5-turbo-16k", max_tokens: 16_384),
-          Models::Anthropic.new("claude-2", max_tokens: 200_000),
-          Models::Anthropic.new("claude-instant-1", max_tokens: 100_000),
-          Models::Llama2.new("Llama2-chat-hf", max_tokens: SiteSetting.ai_hugging_face_token_limit),
-          Models::Llama2FineTunedOrcaStyle.new(
-            "StableBeluga2",
+          Models::OpenAi.new("open_ai:gpt-4", max_tokens: 8192),
+          Models::OpenAi.new("open_ai:gpt-4-32k", max_tokens: 32_768),
+          Models::OpenAi.new("open_ai:gpt-4-0125-preview", max_tokens: 100_000),
+          Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
+          Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
+          Models::Llama2.new(
+            "hugging_face:Llama2-chat-hf",
            max_tokens: SiteSetting.ai_hugging_face_token_limit,
          ),
-          Models::Gemini.new("gemini-pro", max_tokens: 32_768),
-          Models::Mixtral.new("mistralai/Mixtral-8x7B-Instruct-v0.1", max_tokens: 32_000),
+          Models::Llama2FineTunedOrcaStyle.new(
+            "hugging_face:StableBeluga2",
+            max_tokens: SiteSetting.ai_hugging_face_token_limit,
+          ),
+          Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
        ]

+        claude_prov = "anthropic"
+        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
+          claude_prov = "aws_bedrock"
+        end
+
+        foldable_models << Models::Anthropic.new("#{claude_prov}:claude-2", max_tokens: 200_000)
+        foldable_models << Models::Anthropic.new(
+          "#{claude_prov}:claude-instant-1",
+          max_tokens: 100_000,
+        )
+
+        mixtral_prov = "hugging_face"
+        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
+             "mistralai/Mixtral-8x7B-Instruct-v0.1",
+           )
+          mixtral_prov = "vllm"
+        end
+
+        foldable_models << Models::Mixtral.new(
+          "#{mixtral_prov}:mistralai/Mixtral-8x7B-Instruct-v0.1",
+          max_tokens: 32_000,
+        )
+
        foldable_models.each do |model|
          plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
        end
diff --git a/lib/summarization/models/base.rb b/lib/summarization/models/base.rb
index 00e4b84f..487950d8 100644
--- a/lib/summarization/models/base.rb
+++ b/lib/summarization/models/base.rb
@@ -4,8 +4,8 @@ module DiscourseAi
  module Summarization
    module Models
      class Base
-        def initialize(model, max_tokens:)
-          @model = model
+        def initialize(model_name, max_tokens:)
+          @model_name = model_name
          @max_tokens = max_tokens
        end

@@ -25,7 +25,11 @@ module DiscourseAi
          max_tokens - reserved_tokens
        end

-        attr_reader :model, :max_tokens
+        def model
+          model_name.split(":").last
+        end
+
+        attr_reader :model_name, :max_tokens

        protected
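`Base#model` above leans on `String#split(":").last` returning everything after the last colon, which is the whole model id as long as ids contain at most one colon. A quick check with values from this diff:

```ruby
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1".split(":").last
# => "mistralai/Mixtral-8x7B-Instruct-v0.1"
"open_ai:gpt-4-32k".split(":").last
# => "gpt-4-32k"
```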
diff --git a/lib/summarization/strategies/fold_content.rb b/lib/summarization/strategies/fold_content.rb
index 384355d3..47f7b62b 100644
--- a/lib/summarization/strategies/fold_content.rb
+++ b/lib/summarization/strategies/fold_content.rb
@@ -19,7 +19,7 @@ module DiscourseAi
        def summarize(content, user, &on_partial_blk)
          opts = content.except(:contents)

-          llm = DiscourseAi::Completions::Llm.proxy(completion_model.model)
+          llm = DiscourseAi::Completions::Llm.proxy(completion_model.model_name)

          initial_chunks =
            rebalance_chunks(
diff --git a/spec/jobs/regular/stream_post_helper_spec.rb b/spec/jobs/regular/stream_post_helper_spec.rb
index 359bb61a..b842b684 100644
--- a/spec/jobs/regular/stream_post_helper_spec.rb
+++ b/spec/jobs/regular/stream_post_helper_spec.rb
@@ -3,6 +3,8 @@
 RSpec.describe Jobs::StreamPostHelper do
   subject(:job) { described_class.new }

+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
   describe "#execute" do
     fab!(:topic) { Fabricate(:topic) }
     fab!(:post) do
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index f141bfd6..556f1810 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -5,7 +5,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
    described_class.new(
      DiscourseAi::Completions::Dialects::OrcaStyle,
      canned_response,
-      "Upstage-Llama-2-*-instruct-v2",
+      "hugging_face:Upstage-Llama-2-*-instruct-v2",
    )
  end

@@ -13,7 +13,7 @@ RSpec.describe DiscourseAi::Completions::Llm do

  describe ".proxy" do
    it "raises an exception when we can't proxy the model" do
-      fake_model = "unknown_v2"
+      fake_model = "unknown:unknown_v2"

      expect { described_class.proxy(fake_model) }.to(
        raise_error(DiscourseAi::Completions::Llm::UNKNOWN_MODEL),
@@ -27,7 +27,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
      DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
    end

-    let(:llm) { described_class.proxy("fake") }
+    let(:llm) { described_class.proxy("fake:fake") }

    let(:prompt) do
      DiscourseAi::Completions::Prompt.new(
diff --git a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
index 46200fe6..7bbde56b 100644
--- a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
@@ -6,7 +6,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
  let(:prompts) { ["a pink cow", "a red cow"] }

  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
  let(:progress_blk) { Proc.new {} }

  before { SiteSetting.ai_bot_enabled = true }
diff --git a/spec/lib/modules/ai_bot/tools/db_schema_spec.rb b/spec/lib/modules/ai_bot/tools/db_schema_spec.rb
index 83f4cb29..f545477c 100644
--- a/spec/lib/modules/ai_bot/tools/db_schema_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/db_schema_spec.rb
@@ -2,7 +2,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

  before { SiteSetting.ai_bot_enabled = true }
  describe "#process" do
diff --git a/spec/lib/modules/ai_bot/tools/google_spec.rb b/spec/lib/modules/ai_bot/tools/google_spec.rb
index 6a9a900a..4cb2d6d3 100644
--- a/spec/lib/modules/ai_bot/tools/google_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/google_spec.rb
@@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Google do
  subject(:search) { described_class.new({ query: "some search term" }) }

  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
  let(:progress_blk) { Proc.new {} }

  before { SiteSetting.ai_bot_enabled = true }
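The spec updates from here down are mechanical and follow one pattern: `proxy` calls gain the provider prefix, and any module guarded by the new validators configures its LLM setting before flipping the enable flag. Extracted for clarity (the ordering matters because of `LlmDependencyValidator`):

```ruby
before do
  SiteSetting.ai_helper_model = "fake:fake" # must come first...
  SiteSetting.composer_ai_helper_enabled = true # ...or this write is rejected
end
```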
diff --git a/spec/lib/modules/ai_bot/tools/image_spec.rb b/spec/lib/modules/ai_bot/tools/image_spec.rb
index 9772a1f5..040c5230 100644
--- a/spec/lib/modules/ai_bot/tools/image_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/image_spec.rb
@@ -3,7 +3,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::Image do
  subject(:tool) { described_class.new({ prompts: prompts, seeds: [99, 32] }) }

-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
  let(:progress_blk) { Proc.new {} }

  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
diff --git a/spec/lib/modules/ai_bot/tools/list_categories_spec.rb b/spec/lib/modules/ai_bot/tools/list_categories_spec.rb
index 7218442d..64844a44 100644
--- a/spec/lib/modules/ai_bot/tools/list_categories_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/list_categories_spec.rb
@@ -2,7 +2,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

  before { SiteSetting.ai_bot_enabled = true }
diff --git a/spec/lib/modules/ai_bot/tools/list_tags_spec.rb b/spec/lib/modules/ai_bot/tools/list_tags_spec.rb
index e2278baf..9cddf419 100644
--- a/spec/lib/modules/ai_bot/tools/list_tags_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/list_tags_spec.rb
@@ -2,7 +2,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

  before do
    SiteSetting.ai_bot_enabled = true
diff --git a/spec/lib/modules/ai_bot/tools/read_spec.rb b/spec/lib/modules/ai_bot/tools/read_spec.rb
index c04bab2f..309d3286 100644
--- a/spec/lib/modules/ai_bot/tools/read_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/read_spec.rb
@@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
  subject(:tool) { described_class.new({ topic_id: topic_with_tags.id }) }

  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

  fab!(:parent_category) { Fabricate(:category, name: "animals") }
  fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
diff --git a/spec/lib/modules/ai_bot/tools/search_settings_spec.rb b/spec/lib/modules/ai_bot/tools/search_settings_spec.rb
index 59c68604..94763ae9 100644
--- a/spec/lib/modules/ai_bot/tools/search_settings_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/search_settings_spec.rb
@@ -2,7 +2,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

  before { SiteSetting.ai_bot_enabled = true }
diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb
index 61fada6b..066937c3 100644
--- a/spec/lib/modules/ai_bot/tools/search_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/search_spec.rb
@@ -4,8 +4,10 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
  before { SearchIndexer.enable }
  after { SearchIndexer.disable }

+  before { SiteSetting.ai_openai_api_key = "asd" }
+
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
  let(:progress_blk) { Proc.new {} }

  fab!(:admin)
@@ -65,6 +67,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
    after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }

    it "supports semantic search when enabled" do
+      SiteSetting.ai_embeddings_semantic_search_hyde_model = "fake:fake"
      SiteSetting.ai_embeddings_semantic_search_enabled = true
      SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
diff --git a/spec/lib/modules/ai_bot/tools/setting_context_spec.rb b/spec/lib/modules/ai_bot/tools/setting_context_spec.rb
index 609422d0..1953947d 100644
--- a/spec/lib/modules/ai_bot/tools/setting_context_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/setting_context_spec.rb
@@ -10,7 +10,7 @@ end
 RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

  before { SiteSetting.ai_bot_enabled = true }
diff --git a/spec/lib/modules/ai_bot/tools/summarize_spec.rb b/spec/lib/modules/ai_bot/tools/summarize_spec.rb
index 6a0795d8..1af327ee 100644
--- a/spec/lib/modules/ai_bot/tools/summarize_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/summarize_spec.rb
@@ -2,7 +2,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
  let(:progress_blk) { Proc.new {} }

  before { SiteSetting.ai_bot_enabled = true }
diff --git a/spec/lib/modules/ai_bot/tools/time_spec.rb b/spec/lib/modules/ai_bot/tools/time_spec.rb
index 90142698..6e0ca6fd 100644
--- a/spec/lib/modules/ai_bot/tools/time_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/time_spec.rb
@@ -2,7 +2,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::Time do
  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

  before { SiteSetting.ai_bot_enabled = true }
diff --git a/spec/lib/modules/ai_helper/assistant_spec.rb b/spec/lib/modules/ai_helper/assistant_spec.rb
index 4f40aa04..875ae6a9 100644
--- a/spec/lib/modules/ai_helper/assistant_spec.rb
+++ b/spec/lib/modules/ai_helper/assistant_spec.rb
@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
  fab!(:user) { Fabricate(:user) }
  let(:prompt) { CompletionPrompt.find_by(id: mode) }

+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
  let(:english_text) { <<~STRING }
    To perfect his horror, Caesar, surrounded at the base of the statue by
    the impatient daggers of his friends, discovers among the faces and
    blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
diff --git a/spec/lib/modules/ai_helper/chat_thread_titler_spec.rb b/spec/lib/modules/ai_helper/chat_thread_titler_spec.rb
index f0c27f7a..3dc07252 100644
--- a/spec/lib/modules/ai_helper/chat_thread_titler_spec.rb
+++ b/spec/lib/modules/ai_helper/chat_thread_titler_spec.rb
@@ -3,6 +3,8 @@
 RSpec.describe DiscourseAi::AiHelper::ChatThreadTitler do
  subject(:titler) { described_class.new(thread) }

+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
  fab!(:thread) { Fabricate(:chat_thread) }
  fab!(:user) { Fabricate(:user) }
diff --git a/spec/lib/modules/ai_helper/painter_spec.rb b/spec/lib/modules/ai_helper/painter_spec.rb
index bf11d24b..91126732 100644
--- a/spec/lib/modules/ai_helper/painter_spec.rb
+++ b/spec/lib/modules/ai_helper/painter_spec.rb
@@ -6,6 +6,7 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
  fab!(:user) { Fabricate(:user) }

  before do
+    SiteSetting.ai_helper_model = "fake:fake"
    SiteSetting.ai_stability_api_url = "https://api.stability.dev"
    SiteSetting.ai_stability_api_key = "abc"
    SiteSetting.ai_openai_api_key = "abc"
diff --git a/spec/lib/modules/automation/llm_triage_spec.rb b/spec/lib/modules/automation/llm_triage_spec.rb
index 911064d4..c1bb8188 100644
--- a/spec/lib/modules/automation/llm_triage_spec.rb
+++ b/spec/lib/modules/automation/llm_triage_spec.rb
@@ -10,7 +10,7 @@ describe DiscourseAi::Automation::LlmTriage do
    DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
      triage(
        post: post,
-        model: "gpt-4",
+        model: "fake:fake",
        hide_topic: true,
        system_prompt: "test %%POST%%",
        search_for_text: "bad",
@@ -20,25 +20,11 @@ describe DiscourseAi::Automation::LlmTriage do
    expect(post.topic.reload.visible).to eq(true)
  end

-  it "can hide topics on triage with claude" do
+  it "can hide topics on triage" do
    DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
      triage(
        post: post,
-        model: "claude-2",
-        hide_topic: true,
-        system_prompt: "test %%POST%%",
-        search_for_text: "bad",
-      )
-    end
-
-    expect(post.topic.reload.visible).to eq(false)
-  end
-
-  it "can hide topics on triage with claude" do
-    DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
-      triage(
-        post: post,
-        model: "gpt-4",
+        model: "fake:fake",
        hide_topic: true,
        system_prompt: "test %%POST%%",
        search_for_text: "bad",
@@ -54,7 +40,7 @@ describe DiscourseAi::Automation::LlmTriage do
    DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
      triage(
        post: post,
-        model: "gpt-4",
+        model: "fake:fake",
        category_id: category.id,
        system_prompt: "test %%POST%%",
        search_for_text: "bad",
@@ -69,7 +55,7 @@ describe DiscourseAi::Automation::LlmTriage do
    DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
      triage(
        post: post,
-        model: "gpt-4",
+        model: "fake:fake",
        system_prompt: "test %%POST%%",
        search_for_text: "bad",
        canned_reply: "test canned reply 123",
diff --git a/spec/lib/modules/automation/report_runner_spec.rb b/spec/lib/modules/automation/report_runner_spec.rb
index ca424bf2..5e650641 100644
--- a/spec/lib/modules/automation/report_runner_spec.rb
+++ b/spec/lib/modules/automation/report_runner_spec.rb
@@ -22,7 +22,7 @@ module DiscourseAi
            sender_username: user.username,
            receivers: ["fake@discourse.com"],
            title: "test report %DATE%",
-            model: "gpt-4",
+            model: "fake:fake",
            category_ids: nil,
            tags: nil,
            allow_secure_categories: false,
@@ -48,7 +48,7 @@ module DiscourseAi
            sender_username: user.username,
            receivers: [receiver.username],
            title: "test report",
-            model: "gpt-4",
+            model: "fake:fake",
            category_ids: nil,
            tags: nil,
            allow_secure_categories: false,
diff --git a/spec/lib/modules/embeddings/semantic_search_spec.rb b/spec/lib/modules/embeddings/semantic_search_spec.rb
index ad18da71..07df3821 100644
--- a/spec/lib/modules/embeddings/semantic_search_spec.rb
+++ b/spec/lib/modules/embeddings/semantic_search_spec.rb
@@ -7,6 +7,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
  let(:query) { "test_query" }
  let(:subject) { described_class.new(Guardian.new(user)) }

+  before { SiteSetting.ai_embeddings_semantic_search_hyde_model = "fake:fake" }
+
  describe "#search_for_topics" do
    let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
diff --git a/spec/lib/modules/summarization/strategies/fold_content_spec.rb b/spec/lib/modules/summarization/strategies/fold_content_spec.rb
index eaff533e..0333dd45 100644
--- a/spec/lib/modules/summarization/strategies/fold_content_spec.rb
+++ b/spec/lib/modules/summarization/strategies/fold_content_spec.rb
@@ -12,7 +12,7 @@ RSpec.describe DiscourseAi::Summarization::Strategies::FoldContent do
  end

  let(:model) do
-    DiscourseAi::Summarization::Models::OpenAi.new("gpt-4", max_tokens: model_tokens)
+    DiscourseAi::Summarization::Models::OpenAi.new("fake:fake", max_tokens: model_tokens)
  end

  let(:content) { { contents: [{ poster: "asd", id: 1, text: summarize_text }] } }
diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb
index 456fb6eb..5902ba1b 100644
--- a/spec/requests/ai_helper/assistant_controller_spec.rb
+++ b/spec/requests/ai_helper/assistant_controller_spec.rb
@@ -1,6 +1,8 @@
 # frozen_string_literal: true

 RSpec.describe DiscourseAi::AiHelper::AssistantController do
+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
  describe "#suggest" do
    let(:text_to_proofread) { "The rain in spain stays mainly in the plane." }
    let(:proofread_text) { "The rain in Spain, stays mainly in the Plane." }
@@ -90,7 +92,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
          A user wrote this
        TEXT

-        DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do |spy|
+        DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do
          post "/discourse-ai/ai-helper/suggest",
               params: {
                 mode: CompletionPrompt::CUSTOM_PROMPT,
@@ -101,7 +103,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
          expect(response.status).to eq(200)
          expect(response.parsed_body["suggestions"].first).to eq(translated_text)
          expect(response.parsed_body["diff"]).to eq(expected_diff)
-          expect(spy.prompt.translate.last[:content]).to eq(expected_input)
        end
      end
    end
diff --git a/spec/system/ai_helper/ai_composer_helper_spec.rb b/spec/system/ai_helper/ai_composer_helper_spec.rb
index 3b99f779..dff111f6 100644
--- a/spec/system/ai_helper/ai_composer_helper_spec.rb
+++ b/spec/system/ai_helper/ai_composer_helper_spec.rb
@@ -6,6 +6,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do

  before do
    Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
+    SiteSetting.ai_helper_model = "fake:fake"
    SiteSetting.composer_ai_helper_enabled = true
    sign_in(user)
  end
diff --git a/spec/system/ai_helper/ai_post_helper_spec.rb b/spec/system/ai_helper/ai_post_helper_spec.rb
index 5b820e80..628bf23a 100644
--- a/spec/system/ai_helper/ai_post_helper_spec.rb
+++ b/spec/system/ai_helper/ai_post_helper_spec.rb
@@ -28,6 +28,7 @@ RSpec.describe "AI Post helper", type: :system, js: true do

  before do
    Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
+    SiteSetting.ai_helper_model = "fake:fake"
    SiteSetting.composer_ai_helper_enabled = true
    sign_in(user)
  end
diff --git a/spec/system/ai_helper/ai_split_topic_suggestion_spec.rb b/spec/system/ai_helper/ai_split_topic_suggestion_spec.rb
index 4efe437a..b280a16a 100644
--- a/spec/system/ai_helper/ai_split_topic_suggestion_spec.rb
+++ b/spec/system/ai_helper/ai_split_topic_suggestion_spec.rb
@@ -38,6 +38,7 @@ RSpec.describe "AI Post helper", type: :system, js: true do

  before do
    Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
+    SiteSetting.ai_helper_model = "fake:fake"
    SiteSetting.composer_ai_helper_enabled = true
    sign_in(user)
  end
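A closing note on the validator pair, since the ordering constraint is easy to trip over when exercising this branch: with both `LlmValidator` and `LlmDependencyValidator` in place, the two settings can only change in one direction at a time. A sketch of the transitions (error strings come from the locale file at the top of this diff; the exact exception raised is whatever a failed site-setting validation raises):

```ruby
# composer_ai_helper_enabled is ON, trying to blank the model:
SiteSetting.ai_helper_model = ""
# => rejected: "You have to disable composer_ai_helper_enabled first."

# ai_helper_model is blank, trying to enable the feature:
SiteSetting.composer_ai_helper_enabled = true
# => rejected: "Set ai_helper_model first."

# Happy path: configure, then enable. Note the live round-trip:
SiteSetting.ai_helper_model = "open_ai:gpt-4" # triggers a "How much is 1 + 1?" call against the model
SiteSetting.composer_ai_helper_enabled = true
```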