UX: Validations to LLM-backed features (except AI Bot) (#436)

* UX: Validations to LLM-backed features (except AI Bot)

This change is part of an ongoing effort to prevent enabling a broken feature due to missing configuration. We also want to make explicit which provider we are going to use. For example, Claude models are available through both AWS Bedrock and Anthropic, but the configuration differs.

Validations are:

* You must choose a model before enabling the feature.
* You must turn off the feature before setting the model to blank.
* You must configure each model's settings before you can select it.
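
Concretely, these model settings now store a provider-qualified "provider:model" value. A minimal sketch of the admin flow the validators enforce (the error class and exact wording are illustrative):

SiteSetting.composer_ai_helper_enabled = true
# => raises Discourse::InvalidParameters: "Set ai_helper_model first."

SiteSetting.ai_helper_model = "open_ai:gpt-3.5-turbo" # accepted once the OpenAI settings are present
SiteSetting.composer_ai_helper_enabled = true         # now allowed

SiteSetting.ai_helper_model = ""
# => raises Discourse::InvalidParameters: "You have to disable composer_ai_helper_enabled first."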

* Add provider name to summarization options

* vLLM can technically support the same models as HF

* Check we can talk to the selected model

* Check for Bedrock instead of Anthropic, as a site could have both credentials set up
Roman Rizzi 2024-01-29 16:04:25 -03:00 committed by GitHub
parent b2b01185f2
commit 0634b85a81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 613 additions and 154 deletions

View File

@ -240,3 +240,14 @@ en:
anger: "Anger"
joy: "Joy"
disgust: "Disgust"
llm:
configuration:
disable_module_first: "You have to disable %{setting} first."
set_llm_first: "Set %{setting} first."
model_unreachable: "We couldn't get a response from this model. Check your settings first."
endpoints:
not_configured: "%{display_name} (not configured)"
configuration_hint:
one: "Make sure the `%{settings}` setting was configured."
other: "Make sure these settings were configured: %{settings}"

View File

@ -170,6 +170,7 @@ discourse_ai:
composer_ai_helper_enabled:
default: false
client: true
validator: "DiscourseAi::Configuration::LlmDependencyValidator"
ai_helper_allowed_groups:
client: true
type: group_list
@ -181,17 +182,11 @@ discourse_ai:
default: false
client: true
ai_helper_model:
default: gpt-3.5-turbo
default: ""
allow_any: false
type: enum
choices:
- gpt-3.5-turbo
- gpt-4
- claude-2
- stable-beluga-2
- Llama2-chat-hf
- gemini-pro
- mistralai/Mixtral-8x7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
enum: "DiscourseAi::Configuration::LlmEnumerator"
validator: "DiscourseAi::Configuration::LlmValidator"
ai_helper_custom_prompts_allowed_groups:
client: true
type: group_list
@ -257,21 +252,13 @@ discourse_ai:
ai_embeddings_semantic_search_enabled:
default: false
client: true
validator: "DiscourseAi::Configuration::LlmDependencyValidator"
ai_embeddings_semantic_search_hyde_model:
default: "gpt-3.5-turbo"
default: ""
type: enum
allow_any: false
choices:
- Llama2-*-chat-hf
- claude-instant-1
- claude-2
- gpt-3.5-turbo
- gpt-4
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
- gemini-pro
- mistralai/Mixtral-8x7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
enum: "DiscourseAi::Configuration::LlmEnumerator"
validator: "DiscourseAi::Configuration::LlmValidator"
ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_endpoint_srv:

View File

@ -0,0 +1,91 @@
# frozen_string_literal: true
class ExplicitProviderBackwardsCompat < ActiveRecord::Migration[7.0]
def up
backfill_settings("composer_ai_helper_enabled", "ai_helper_model")
backfill_settings(
"ai_embeddings_semantic_search_enabled",
"ai_embeddings_semantic_search_hyde_model",
)
end
def down
raise ActiveRecord::IrreversibleMigration
end
def backfill_settings(feature_setting_name, llm_setting_name)
feature_enabled =
DB.query_single(
"SELECT value FROM site_settings WHERE name = :setting_name",
setting_name: feature_setting_name,
).first == "t"
setting_value =
DB
.query_single(
"SELECT value FROM site_settings WHERE name = :llm_setting",
llm_setting: llm_setting_name,
)
.first
.to_s
providers = %w[aws_bedrock anthropic open_ai hugging_face vllm google]
# Sanity check to make sure we won't add the provider twice.
return if providers.include?(setting_value.split(":").first)
if setting_value.empty? && feature_enabled
# Enabled and using old default (gpt-3.5-turbo)
DB.exec(
"UPDATE site_settings SET value='open_ai:gpt-3.5-turbo' WHERE name=:llm_setting",
llm_setting: llm_setting_name,
)
elsif !setting_value.empty? && !feature_enabled
# They'll have to choose an LLM model again before enabling the feature
DB.exec("DELETE FROM site_settings WHERE name=:llm_setting", llm_setting: llm_setting_name)
elsif !setting_value.empty? && feature_enabled
DB.exec(
"UPDATE site_settings SET value=:new_value WHERE name=:llm_setting",
llm_setting: llm_setting_name,
new_value: append_provider(setting_value),
)
end
end
def append_provider(value)
open_ai_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo]
return "open_ai:#{value}" if open_ai_models.include?(value)
return "google:#{value}" if value == "gemini-pro"
hf_models = %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf]
return "hugging_face:#{value}" if hf_models.include?(value)
# Models available through multiple providers
claude_models = %w[claude-instant-1 claude-2]
if claude_models.include?(value)
has_bedrock_creds =
DB.query_single(
"SELECT value FROM site_settings WHERE name = 'ai_bedrock_secret_access_key' OR name = 'ai_bedrock_access_key_id' ",
).length > 0
if has_bedrock_creds
return "aws_bedrock:#{value}"
else
return "anthropic:#{value}"
end
end
mixtral_models = %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2]
if mixtral_models.include?(value)
vllm_configured =
DB.query_single(
"SELECT value FROM site_settings WHERE name = 'ai_vllm_endpoint_srv' OR name = 'ai_vllm_endpoint' ",
).length > 0
if vllm_configured
"vllm:#{value}"
else
"hugging_face:#{value}"
end
end
end
end
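
For concreteness, these are the mappings append_provider produces for the old enum values (the Claude and Mixtral branches depend on which credentials are configured):

append_provider("gpt-4")         # => "open_ai:gpt-4"
append_provider("gemini-pro")    # => "google:gemini-pro"
append_provider("StableBeluga2") # => "hugging_face:StableBeluga2"
append_provider("claude-2")      # => "aws_bedrock:claude-2" with Bedrock creds, otherwise "anthropic:claude-2"
append_provider("mistralai/Mixtral-8x7B-Instruct-v0.1")
# => "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1" with a vLLM endpoint set, otherwise "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"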

View File

@ -125,28 +125,39 @@ module DiscourseAi
end
def model(prefer_low_cost: false)
# HACK(roman): We'll do this until we define how we represent different providers in the bot settings
default_model =
case bot_user.id
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
"claude-2"
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
"aws_bedrock:claude-2"
else
"anthropic:claude-2"
end
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
"gpt-4"
"open_ai:gpt-4"
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
"gpt-4-turbo"
"open_ai:gpt-4-turbo"
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"gpt-3.5-turbo-16k"
"open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
"mistralai/Mixtral-8x7B-Instruct-v0.1"
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
)
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
else
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"gemini-pro"
"google:gemini-pro"
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake"
"fake:fake"
else
nil
end
if %w[gpt-4 gpt-4-turbo].include?(default_model) && prefer_low_cost
return "gpt-3.5-turbo-16k"
if %w[open_ai:gpt-4 open_ai:gpt-4-turbo].include?(default_model) && prefer_low_cost
return "open_ai:gpt-3.5-turbo-16k"
end
default_model

View File

@ -9,6 +9,10 @@ module DiscourseAi
model_name == "fake"
end
def translate
""
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

View File

@ -4,8 +4,22 @@ module DiscourseAi
module Completions
module Endpoints
class Anthropic < Base
def self.can_contact?(model_name)
%w[claude-instant-1 claude-2].include?(model_name)
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "anthropic" && %w[claude-instant-1 claude-2].include?(model_name)
end
def dependant_setting_names
%w[ai_anthropic_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_anthropic_api_key.present?
end
def endpoint_name(model_name)
"Anthropic - #{model_name}"
end
end
def normalize_model_params(model_params)

View File

@ -6,11 +6,24 @@ module DiscourseAi
module Completions
module Endpoints
class AwsBedrock < Base
def self.can_contact?(model_name)
%w[claude-instant-1 claude-2].include?(model_name) &&
SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present?
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "aws_bedrock" && %w[claude-instant-1 claude-2].include?(model_name)
end
def dependant_setting_names
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
end
def correctly_configured?(_model_name)
SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present?
end
def endpoint_name(model_name)
"AWS Bedrock - #{model_name}"
end
end
def normalize_model_params(model_params)

View File

@ -7,29 +7,54 @@ module DiscourseAi
CompletionFailed = Class.new(StandardError)
TIMEOUT = 60
def self.endpoint_for(model_name)
# Order is important.
# Bedrock has priority over Anthropic if credentials are present.
endpoints = [
DiscourseAi::Completions::Endpoints::AwsBedrock,
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace,
DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm,
]
if Rails.env.test? || Rails.env.development?
endpoints << DiscourseAi::Completions::Endpoints::Fake
end
endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
ek.can_contact?(model_name)
end
end
def self.can_contact?(_model_name)
raise NotImplementedError
end
class << self
def endpoint_for(provider_name, model_name)
endpoints = [
DiscourseAi::Completions::Endpoints::AwsBedrock,
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace,
DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm,
]
if Rails.env.test? || Rails.env.development?
endpoints << DiscourseAi::Completions::Endpoints::Fake
end
endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
ek.can_contact?(provider_name, model_name)
end
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.llm.endpoints.configuration_hint",
settings: settings.join(", "),
count: settings.length,
)
end
def display_name(model_name)
to_display = endpoint_name(model_name)
return to_display if correctly_configured?(model_name)
I18n.t("discourse_ai.llm.endpoints.not_configured", display_name: to_display)
end
def dependant_setting_names
raise NotImplementedError
end
def endpoint_name(_model_name)
raise NotImplementedError
end
def can_contact?(_endpoint_name, _model_name)
raise NotImplementedError
end
end
def initialize(model_name, tokenizer)
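
Taken together, each endpoint subclass now implements a small class-level contract. A sketch of a new endpoint under these rules; the provider name, model, and setting are hypothetical, not part of this commit:

class MyProvider < DiscourseAi::Completions::Endpoints::Base
  class << self
    def can_contact?(endpoint_name, model_name)
      # Claim only requests explicitly routed to this provider.
      endpoint_name == "my_provider" && %w[my-model].include?(model_name)
    end

    def dependant_setting_names
      %w[ai_my_provider_api_key] # hypothetical setting, surfaced in configuration hints
    end

    def correctly_configured?(_model_name)
      SiteSetting.ai_my_provider_api_key.present?
    end

    def endpoint_name(model_name)
      "MyProvider - #{model_name}" # shown in the admin dropdown via display_name
    end
  end
end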

View File

@ -6,7 +6,7 @@ module DiscourseAi
class CannedResponse
CANNED_RESPONSE_ERROR = Class.new(StandardError)
def self.can_contact?(_)
def self.can_contact?(_, _)
Rails.env.test?
end

View File

@ -4,8 +4,18 @@ module DiscourseAi
module Completions
module Endpoints
class Fake < Base
def self.can_contact?(model_name)
model_name == "fake"
class << self
def can_contact?(_endpoint_name, model_name)
model_name == "fake"
end
def correctly_configured?(_model_name)
true
end
def endpoint_name(_model_name)
"Test - fake model"
end
end
STOCK_CONTENT = <<~TEXT

View File

@ -4,8 +4,23 @@ module DiscourseAi
module Completions
module Endpoints
class Gemini < Base
def self.can_contact?(model_name)
%w[gemini-pro].include?(model_name)
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "google"
%w[gemini-pro].include?(model_name)
end
def dependant_setting_names
%w[ai_gemini_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_gemini_api_key.present?
end
def endpoint_name(model_name)
"Google - #{model_name}"
end
end
def default_options

View File

@ -4,15 +4,31 @@ module DiscourseAi
module Completions
module Endpoints
class HuggingFace < Base
def self.can_contact?(model_name)
%w[
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
].include?(model_name) && SiteSetting.ai_hugging_face_api_url.present?
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "hugging_face"
%w[
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
].include?(model_name)
end
def dependant_setting_names
%w[ai_hugging_face_api_url]
end
def correctly_configured?(_model_name)
SiteSetting.ai_hugging_face_api_url.present?
end
def endpoint_name(model_name)
"Hugging Face - #{model_name}"
end
end
def default_options

View File

@ -4,15 +4,62 @@ module DiscourseAi
module Completions
module Endpoints
class OpenAi < Base
def self.can_contact?(model_name)
%w[
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-0125-preview
gpt-4-turbo
].include?(model_name)
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "open_ai"
%w[
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-0125-preview
gpt-4-turbo
].include?(model_name)
end
def dependant_setting_names
%w[
ai_openai_api_key
ai_openai_gpt4_32k_url
ai_openai_gpt4_turbo_url
ai_openai_gpt4_url
ai_openai_gpt35_16k_url
ai_openai_gpt35_url
]
end
def correctly_configured?(model_name)
SiteSetting.ai_openai_api_key.present? && has_url?(model_name)
end
def has_url?(model)
url =
if model.include?("gpt-4")
if model.include?("32k")
SiteSetting.ai_openai_gpt4_32k_url
else
if model.include?("1106") || model.include?("turbo")
SiteSetting.ai_openai_gpt4_turbo_url
else
SiteSetting.ai_openai_gpt4_url
end
end
else
if model.include?("16k")
SiteSetting.ai_openai_gpt35_16k_url
else
SiteSetting.ai_openai_gpt35_url
end
end
url.present?
end
def endpoint_name(model_name)
"OpenAI - #{model_name}"
end
end
def normalize_model_params(model_params)
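
The URL check resolves each model family to its own URL setting; as a sketch of the mapping implemented above:

has_url?("gpt-4-32k")         # checks SiteSetting.ai_openai_gpt4_32k_url
has_url?("gpt-4-turbo")       # checks SiteSetting.ai_openai_gpt4_turbo_url
has_url?("gpt-4")             # checks SiteSetting.ai_openai_gpt4_url
has_url?("gpt-3.5-turbo-16k") # checks SiteSetting.ai_openai_gpt35_16k_url
has_url?("gpt-3.5-turbo")     # checks SiteSetting.ai_openai_gpt35_url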

View File

@ -4,10 +4,30 @@ module DiscourseAi
module Completions
module Endpoints
class Vllm < Base
def self.can_contact?(model_name)
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
model_name,
)
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "vllm" &&
%w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
].include?(model_name)
end
def dependant_setting_names
%w[ai_vllm_endpoint_srv ai_vllm_endpoint]
end
def correctly_configured?(_model_name)
SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
end
def endpoint_name(model_name)
"vLLM - #{model_name}"
end
end
def normalize_model_params(model_params)

View File

@ -7,7 +7,7 @@
# the target model and routes the completion request through the correct gateway.
#
# Use the .proxy method to instantiate an object.
# It chooses the best dialect and endpoint for the model you want to interact with.
# It chooses the correct dialect and endpoint for the model you want to interact with.
#
# Tests of modules that perform LLM calls can use .with_prepared_responses to return canned responses
# instead of relying on WebMock stubs like we did in the past.
@ -17,27 +17,62 @@ module DiscourseAi
class Llm
UNKNOWN_MODEL = Class.new(StandardError)
def self.with_prepared_responses(responses)
@canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
yield(@canned_response)
ensure
# Don't leak prepared response if there's an exception.
@canned_response = nil
end
def self.proxy(model_name)
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
return new(dialect_klass, @canned_response, model_name) if @canned_response
gateway =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
model_name,
dialect_klass.tokenizer,
)
new(dialect_klass, gateway, model_name)
end
class << self
def models_by_provider
# ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
# However, since they use the same URL/key settings, there's no reason to duplicate them.
{
aws_bedrock: %w[claude-instant-1 claude-2],
anthropic: %w[claude-instant-1 claude-2],
vllm: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
],
hugging_face: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
],
open_ai: %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo],
google: %w[gemini-pro],
}.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
end
def with_prepared_responses(responses)
@canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
yield(@canned_response)
ensure
# Don't leak prepared response if there's an exception.
@canned_response = nil
end
def proxy(model_name)
provider_and_model_name = model_name.split(":")
provider_name = provider_and_model_name.first
model_name_without_prov = provider_and_model_name[1..].join
dialect_klass =
DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
return new(dialect_klass, @canned_response, model_name) if @canned_response
gateway =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(
provider_name,
model_name_without_prov,
).new(model_name_without_prov, dialect_klass.tokenizer)
new(dialect_klass, gateway, model_name_without_prov)
end
end
def initialize(dialect_klass, gateway, model_name)
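
Callers now hand .proxy the provider-qualified name; a short usage sketch based on the specs below:

llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo")
# The provider prefix picks the gateway; the bare model name picks the dialect.

DiscourseAi::Completions::Llm.proxy("unknown:unknown_v2")
# => raises DiscourseAi::Completions::Llm::UNKNOWN_MODEL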

View File

@ -0,0 +1,29 @@
# frozen_string_literal: true
module DiscourseAi
module Configuration
class LlmDependencyValidator
def initialize(opts = {})
@opts = opts
end
def valid_value?(val)
return true if val == "f"
SiteSetting.public_send(llm_dependency_setting_name).present?
end
def error_message
I18n.t("discourse_ai.llm.configuration.set_llm_first", setting: llm_dependency_setting_name)
end
def llm_dependency_setting_name
if @opts[:name] == :ai_embeddings_semantic_search_enabled
:ai_embeddings_semantic_search_hyde_model
else
:ai_helper_model
end
end
end
end
end
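
In effect this validator only guards turning a feature on. A sketch, assuming the boolean arrives as the "t"/"f" strings site setting validators receive:

validator = DiscourseAi::Configuration::LlmDependencyValidator.new(name: :composer_ai_helper_enabled)
validator.valid_value?("f") # => true; disabling is always allowed
validator.valid_value?("t") # => true only if SiteSetting.ai_helper_model is present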

View File

@ -0,0 +1,25 @@
# frozen_string_literal: true
require "enum_site_setting"
module DiscourseAi
module Configuration
class LlmEnumerator < ::EnumSiteSetting
def self.valid_value?(val)
true
end
def self.values
@values ||=
DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
endpoint =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
models.map do |model_name|
{ name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
end
end
end
end
end
end
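
The enumerator then feeds the admin dropdown entries like the following (illustrative output; the "(not configured)" suffix comes from display_name):

DiscourseAi::Configuration::LlmEnumerator.values
# => [
#      { name: "AWS Bedrock - claude-instant-1", value: "aws_bedrock:claude-instant-1" },
#      { name: "OpenAI - gpt-4 (not configured)", value: "open_ai:gpt-4" },
#      ...
#    ]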

View File

@ -0,0 +1,77 @@
# frozen_string_literal: true
module DiscourseAi
module Configuration
class LlmValidator
def initialize(opts = {})
@opts = opts
end
def valid_value?(val)
if val == ""
@parent_enabled = SiteSetting.public_send(parent_module_name)
return !@parent_enabled
end
provider_and_model_name = val.split(":")
provider_name = provider_and_model_name.first
model_name_without_prov = provider_and_model_name[1..].join
endpoint =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(
provider_name,
model_name_without_prov,
)
return false if endpoint.nil?
if !endpoint.correctly_configured?(model_name_without_prov)
@endpoint = endpoint
return false
end
if !can_talk_to_model?(val)
@unreachable = true
return false
end
true
end
def error_message
if @parent_enabled
return(
I18n.t(
"discourse_ai.llm.configuration.disable_module_first",
setting: parent_module_name,
)
)
end
return(I18n.t("discourse_ai.llm.configuration.model_unreachable")) if @unreachable
@endpoint&.configuration_hint
end
def parent_module_name
if @opts[:name] == :ai_embeddings_semantic_search_hyde_model
:ai_embeddings_semantic_search_enabled
else
:composer_ai_helper_enabled
end
end
private
def can_talk_to_model?(model_name)
DiscourseAi::Completions::Llm
.proxy(model_name)
.generate("How much is 1 + 1?", user: nil)
.present?
rescue StandardError
false
end
end
end
end
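
A sketch of the three failure modes this validator reports:

validator = DiscourseAi::Configuration::LlmValidator.new(name: :ai_helper_model)
validator.valid_value?("")              # => false while composer_ai_helper_enabled is on
validator.valid_value?("open_ai:gpt-4") # => false if the OpenAI settings are missing, or if the
                                        #    test completion ("How much is 1 + 1?") gets no response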

View File

@ -5,22 +5,45 @@ module DiscourseAi
class EntryPoint
def inject_into(plugin)
foldable_models = [
Models::OpenAi.new("gpt-4", max_tokens: 8192),
Models::OpenAi.new("gpt-4-32k", max_tokens: 32_768),
Models::OpenAi.new("gpt-4-0125-preview", max_tokens: 100_000),
Models::OpenAi.new("gpt-3.5-turbo", max_tokens: 4096),
Models::OpenAi.new("gpt-3.5-turbo-16k", max_tokens: 16_384),
Models::Anthropic.new("claude-2", max_tokens: 200_000),
Models::Anthropic.new("claude-instant-1", max_tokens: 100_000),
Models::Llama2.new("Llama2-chat-hf", max_tokens: SiteSetting.ai_hugging_face_token_limit),
Models::Llama2FineTunedOrcaStyle.new(
"StableBeluga2",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
Models::Gemini.new("gemini-pro", max_tokens: 32_768),
Models::Mixtral.new("mistralai/Mixtral-8x7B-Instruct-v0.1", max_tokens: 32_000),
Models::OpenAi.new("open_ai:gpt-4", max_tokens: 8192),
Models::OpenAi.new("open_ai:gpt-4-32k", max_tokens: 32_768),
Models::OpenAi.new("open_ai:gpt-4-0125-preview", max_tokens: 100_000),
Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
Models::Llama2.new(
"hugging_face:Llama2-chat-hf",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
Models::Llama2FineTunedOrcaStyle.new(
"hugging_face:StableBeluga2",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
]
claude_prov = "anthropic"
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
claude_prov = "aws_bedrock"
end
foldable_models << Models::Anthropic.new("#{claude_prov}:claude-2", max_tokens: 200_000)
foldable_models << Models::Anthropic.new(
"#{claude_prov}:claude-instant-1",
max_tokens: 100_000,
)
mixtral_prov = "hugging_face"
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
)
mixtral_prov = "vllm"
end
foldable_models << Models::Mixtral.new(
"#{mixtral_prov}:mistralai/Mixtral-8x7B-Instruct-v0.1",
max_tokens: 32_000,
)
foldable_models.each do |model|
plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
end

View File

@ -4,8 +4,8 @@ module DiscourseAi
module Summarization
module Models
class Base
def initialize(model, max_tokens:)
@model = model
def initialize(model_name, max_tokens:)
@model_name = model_name
@max_tokens = max_tokens
end
@ -25,7 +25,11 @@ module DiscourseAi
max_tokens - reserved_tokens
end
attr_reader :model, :max_tokens
def model
model_name.split(":").last
end
attr_reader :model_name, :max_tokens
protected
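
So summarization models keep the qualified name but expose the bare model for prompts; a sketch:

model = DiscourseAi::Summarization::Models::OpenAi.new("open_ai:gpt-4", max_tokens: 8192)
model.model_name # => "open_ai:gpt-4", passed to Llm.proxy by the strategy
model.model      # => "gpt-4"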

View File

@ -19,7 +19,7 @@ module DiscourseAi
def summarize(content, user, &on_partial_blk)
opts = content.except(:contents)
llm = DiscourseAi::Completions::Llm.proxy(completion_model.model)
llm = DiscourseAi::Completions::Llm.proxy(completion_model.model_name)
initial_chunks =
rebalance_chunks(

View File

@ -3,6 +3,8 @@
RSpec.describe Jobs::StreamPostHelper do
subject(:job) { described_class.new }
before { SiteSetting.ai_helper_model = "fake:fake" }
describe "#execute" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) do

View File

@ -5,7 +5,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
described_class.new(
DiscourseAi::Completions::Dialects::OrcaStyle,
canned_response,
"Upstage-Llama-2-*-instruct-v2",
"hugging_face:Upstage-Llama-2-*-instruct-v2",
)
end
@ -13,7 +13,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
describe ".proxy" do
it "raises an exception when we can't proxy the model" do
fake_model = "unknown_v2"
fake_model = "unknown:unknown_v2"
expect { described_class.proxy(fake_model) }.to(
raise_error(DiscourseAi::Completions::Llm::UNKNOWN_MODEL),
@ -27,7 +27,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
end
let(:llm) { described_class.proxy("fake") }
let(:llm) { described_class.proxy("fake:fake") }
let(:prompt) do
DiscourseAi::Completions::Prompt.new(

View File

@ -6,7 +6,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
let(:prompts) { ["a pink cow", "a red cow"] }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -2,7 +2,7 @@
RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do

View File

@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Google do
subject(:search) { described_class.new({ query: "some search term" }) }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -3,7 +3,7 @@
RSpec.describe DiscourseAi::AiBot::Tools::Image do
subject(:tool) { described_class.new({ prompts: prompts, seeds: [99, 32] }) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }

View File

@ -2,7 +2,7 @@
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -2,7 +2,7 @@
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
before do
SiteSetting.ai_bot_enabled = true

View File

@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
subject(:tool) { described_class.new({ topic_id: topic_with_tags.id }) }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:parent_category) { Fabricate(:category, name: "animals") }
fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }

View File

@ -2,7 +2,7 @@
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -4,8 +4,10 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
before { SearchIndexer.enable }
after { SearchIndexer.disable }
before { SiteSetting.ai_openai_api_key = "asd" }
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
fab!(:admin)
@ -65,6 +67,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }
it "supports semantic search when enabled" do
SiteSetting.ai_embeddings_semantic_search_hyde_model = "fake:fake"
SiteSetting.ai_embeddings_semantic_search_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"

View File

@ -10,7 +10,7 @@ end
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -2,7 +2,7 @@
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -2,7 +2,7 @@
RSpec.describe DiscourseAi::AiBot::Tools::Time do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
fab!(:user) { Fabricate(:user) }
let(:prompt) { CompletionPrompt.find_by(id: mode) }
before { SiteSetting.ai_helper_model = "fake:fake" }
let(:english_text) { <<~STRING }
To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer

View File

@ -3,6 +3,8 @@
RSpec.describe DiscourseAi::AiHelper::ChatThreadTitler do
subject(:titler) { described_class.new(thread) }
before { SiteSetting.ai_helper_model = "fake:fake" }
fab!(:thread) { Fabricate(:chat_thread) }
fab!(:user) { Fabricate(:user) }

View File

@ -6,6 +6,7 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
fab!(:user) { Fabricate(:user) }
before do
SiteSetting.ai_helper_model = "fake:fake"
SiteSetting.ai_stability_api_url = "https://api.stability.dev"
SiteSetting.ai_stability_api_key = "abc"
SiteSetting.ai_openai_api_key = "abc"

View File

@ -10,7 +10,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
triage(
post: post,
model: "gpt-4",
model: "fake:fake",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -20,25 +20,11 @@ describe DiscourseAi::Automation::LlmTriage do
expect(post.topic.reload.visible).to eq(true)
end
it "can hide topics on triage with claude" do
it "can hide topics on triage" do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "claude-2",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
)
end
expect(post.topic.reload.visible).to eq(false)
end
it "can hide topics on triage with claude" do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "gpt-4",
model: "fake:fake",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -54,7 +40,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "gpt-4",
model: "fake:fake",
category_id: category.id,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -69,7 +55,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "gpt-4",
model: "fake:fake",
system_prompt: "test %%POST%%",
search_for_text: "bad",
canned_reply: "test canned reply 123",

View File

@ -22,7 +22,7 @@ module DiscourseAi
sender_username: user.username,
receivers: ["fake@discourse.com"],
title: "test report %DATE%",
model: "gpt-4",
model: "fake:fake",
category_ids: nil,
tags: nil,
allow_secure_categories: false,
@ -48,7 +48,7 @@ module DiscourseAi
sender_username: user.username,
receivers: [receiver.username],
title: "test report",
model: "gpt-4",
model: "fake:fake",
category_ids: nil,
tags: nil,
allow_secure_categories: false,

View File

@ -7,6 +7,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
let(:query) { "test_query" }
let(:subject) { described_class.new(Guardian.new(user)) }
before { SiteSetting.ai_embeddings_semantic_search_hyde_model = "fake:fake" }
describe "#search_for_topics" do
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }

View File

@ -12,7 +12,7 @@ RSpec.describe DiscourseAi::Summarization::Strategies::FoldContent do
end
let(:model) do
DiscourseAi::Summarization::Models::OpenAi.new("gpt-4", max_tokens: model_tokens)
DiscourseAi::Summarization::Models::OpenAi.new("fake:fake", max_tokens: model_tokens)
end
let(:content) { { contents: [{ poster: "asd", id: 1, text: summarize_text }] } }

View File

@ -1,6 +1,8 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiHelper::AssistantController do
before { SiteSetting.ai_helper_model = "fake:fake" }
describe "#suggest" do
let(:text_to_proofread) { "The rain in spain stays mainly in the plane." }
let(:proofread_text) { "The rain in Spain, stays mainly in the Plane." }
@ -90,7 +92,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
A user wrote this</input>
TEXT
DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do |spy|
DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do
post "/discourse-ai/ai-helper/suggest",
params: {
mode: CompletionPrompt::CUSTOM_PROMPT,
@ -101,7 +103,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
expect(response.status).to eq(200)
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
expect(response.parsed_body["diff"]).to eq(expected_diff)
expect(spy.prompt.translate.last[:content]).to eq(expected_input)
end
end
end

View File

@ -6,6 +6,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
before do
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
SiteSetting.ai_helper_model = "fake:fake"
SiteSetting.composer_ai_helper_enabled = true
sign_in(user)
end

View File

@ -28,6 +28,7 @@ RSpec.describe "AI Post helper", type: :system, js: true do
before do
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
SiteSetting.ai_helper_model = "fake:fake"
SiteSetting.composer_ai_helper_enabled = true
sign_in(user)
end

View File

@ -38,6 +38,7 @@ RSpec.describe "AI Post helper", type: :system, js: true do
before do
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
SiteSetting.ai_helper_model = "fake:fake"
SiteSetting.composer_ai_helper_enabled = true
sign_in(user)
end