UX: Validations to LLM-backed features (except AI Bot) (#436)

This change is part of an ongoing effort to prevent enabling a broken feature due to a lack of configuration. We also want to make explicit which provider we are going to use. For example, Claude models are available through both AWS Bedrock and Anthropic, but the configuration differs. (A short sketch of the new provider-qualified identifier format appears just before the diff.) The validations are:

* You must choose a model before enabling the feature.
* You must turn off the feature before setting the model to blank.
* You must configure each model's settings before being able to select it.

Additional changes:

* Add the provider name to the summarization options.
* vLLM can technically support the same models as Hugging Face.
* Check that we can talk to the selected model.
* Check for Bedrock instead of Anthropic, since a site could have both sets of credentials configured.
This commit is contained in:
parent b2b01185f2
commit 0634b85a81
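For orientation, a minimal sketch of the provider-qualified model identifiers this commit introduces (values are illustrative; the proxy splits the setting on ":"):

    # Model settings now store "<provider>:<model>"; Llm.proxy splits the value
    # to pick the gateway, so a bare "gpt-3.5-turbo" no longer resolves.
    llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo")

    # Claude models must name a provider explicitly, because AWS Bedrock and
    # Anthropic serve the same models with different credentials:
    DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-2")
    DiscourseAi::Completions::Llm.proxy("anthropic:claude-2")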
@@ -240,3 +240,14 @@ en:
       anger: "Anger"
       joy: "Joy"
       disgust: "Disgust"
+
+    llm:
+      configuration:
+        disable_module_first: "You have to disable %{setting} first."
+        set_llm_first: "Set %{setting} first."
+        model_unreachable: "We couldn't get a response from this model. Check your settings first."
+      endpoints:
+        not_configured: "%{display_name} (not configured)"
+        configuration_hint:
+          one: "Make sure the `%{settings}` setting was configured."
+          other: "Make sure these settings were configured: %{settings}"
@@ -170,6 +170,7 @@ discourse_ai:
   composer_ai_helper_enabled:
     default: false
     client: true
+    validator: "DiscourseAi::Configuration::LlmDependencyValidator"
  ai_helper_allowed_groups:
     client: true
     type: group_list
@@ -181,17 +182,11 @@ discourse_ai:
     default: false
     client: true
   ai_helper_model:
-    default: gpt-3.5-turbo
+    default: ""
+    allow_any: false
     type: enum
-    choices:
-      - gpt-3.5-turbo
-      - gpt-4
-      - claude-2
-      - stable-beluga-2
-      - Llama2-chat-hf
-      - gemini-pro
-      - mistralai/Mixtral-8x7B-Instruct-v0.1
-      - mistralai/Mistral-7B-Instruct-v0.2
+    enum: "DiscourseAi::Configuration::LlmEnumerator"
+    validator: "DiscourseAi::Configuration::LlmValidator"
   ai_helper_custom_prompts_allowed_groups:
     client: true
     type: group_list
@@ -257,21 +252,13 @@ discourse_ai:
   ai_embeddings_semantic_search_enabled:
     default: false
     client: true
+    validator: "DiscourseAi::Configuration::LlmDependencyValidator"
   ai_embeddings_semantic_search_hyde_model:
-    default: "gpt-3.5-turbo"
+    default: ""
     type: enum
     allow_any: false
-    choices:
-      - Llama2-*-chat-hf
-      - claude-instant-1
-      - claude-2
-      - gpt-3.5-turbo
-      - gpt-4
-      - StableBeluga2
-      - Upstage-Llama-2-*-instruct-v2
-      - gemini-pro
-      - mistralai/Mixtral-8x7B-Instruct-v0.1
-      - mistralai/Mistral-7B-Instruct-v0.2
+    enum: "DiscourseAi::Configuration::LlmEnumerator"
+    validator: "DiscourseAi::Configuration::LlmValidator"

   ai_summarization_discourse_service_api_endpoint: ""
   ai_summarization_discourse_service_api_endpoint_srv:
@@ -0,0 +1,91 @@
+# frozen_string_literal: true
+
+class ExplicitProviderBackwardsCompat < ActiveRecord::Migration[7.0]
+  def up
+    backfill_settings("composer_ai_helper_enabled", "ai_helper_model")
+    backfill_settings(
+      "ai_embeddings_semantic_search_enabled",
+      "ai_embeddings_semantic_search_hyde_model",
+    )
+  end
+
+  def down
+    raise ActiveRecord::IrreversibleMigration
+  end
+
+  def backfill_settings(feature_setting_name, llm_setting_name)
+    feature_enabled =
+      DB.query_single(
+        "SELECT value FROM site_settings WHERE name = :setting_name",
+        setting_name: feature_setting_name,
+      ).first == "t"
+
+    setting_value =
+      DB
+        .query_single(
+          "SELECT value FROM site_settings WHERE name = :llm_setting",
+          llm_setting: llm_setting_name,
+        )
+        .first
+        .to_s
+
+    providers = %w[aws_bedrock anthropic open_ai hugging_face vllm google]
+    # Sanity check to make sure we won't add the provider twice.
+    return if providers.include?(setting_value.split(":").first)
+
+    if setting_value.blank? && feature_enabled
+      # Enabled and using the old default (gpt-3.5-turbo).
+      DB.exec(
+        "UPDATE site_settings SET value='open_ai:gpt-3.5-turbo' WHERE name=:llm_setting",
+        llm_setting: llm_setting_name,
+      )
+    elsif setting_value.present? && !feature_enabled
+      # They'll have to choose an LLM model again before enabling the feature.
+      DB.exec("DELETE FROM site_settings WHERE name=:llm_setting", llm_setting: llm_setting_name)
+    elsif setting_value.present? && feature_enabled
+      DB.exec(
+        "UPDATE site_settings SET value=:new_value WHERE name=:llm_setting",
+        llm_setting: llm_setting_name,
+        new_value: append_provider(setting_value),
+      )
+    end
+  end
+
+  def append_provider(value)
+    open_ai_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo]
+    return "open_ai:#{value}" if open_ai_models.include?(value)
+    return "google:#{value}" if value == "gemini-pro"
+
+    hf_models = %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf]
+    return "hugging_face:#{value}" if hf_models.include?(value)
+
+    # Models available through multiple providers.
+    claude_models = %w[claude-instant-1 claude-2]
+    if claude_models.include?(value)
+      has_bedrock_creds =
+        DB.query_single(
+          "SELECT value FROM site_settings WHERE name = 'ai_bedrock_secret_access_key' OR name = 'ai_bedrock_access_key_id'",
+        ).length > 0
+
+      if has_bedrock_creds
+        return "aws_bedrock:#{value}"
+      else
+        return "anthropic:#{value}"
+      end
+    end
+
+    mixtral_models = %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2]
+    if mixtral_models.include?(value)
+      vllm_configured =
+        DB.query_single(
+          "SELECT value FROM site_settings WHERE name = 'ai_vllm_endpoint_srv' OR name = 'ai_vllm_endpoint'",
+        ).length > 0
+
+      if vllm_configured
+        "vllm:#{value}"
+      else
+        "hugging_face:#{value}"
+      end
+    end
+  end
+end
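For illustration, a sketch of how this backfill maps a few old values (the Claude and Mixtral cases depend on which credentials or endpoints the site has configured):

    append_provider("gpt-4")      # => "open_ai:gpt-4"
    append_provider("gemini-pro") # => "google:gemini-pro"
    append_provider("claude-2")   # => "aws_bedrock:claude-2" with Bedrock creds, else "anthropic:claude-2"
    append_provider("mistralai/Mistral-7B-Instruct-v0.2")
                                  # => "vllm:..." with a vLLM endpoint set, else "hugging_face:..."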
@@ -125,28 +125,39 @@ module DiscourseAi
     end

     def model(prefer_low_cost: false)
+      # HACK(roman): We'll do this until we define how we represent different providers in the bot settings.
       default_model =
         case bot_user.id
         when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
-          "claude-2"
+          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
+            "aws_bedrock:claude-2"
+          else
+            "anthropic:claude-2"
+          end
         when DiscourseAi::AiBot::EntryPoint::GPT4_ID
-          "gpt-4"
+          "open_ai:gpt-4"
         when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
-          "gpt-4-turbo"
+          "open_ai:gpt-4-turbo"
         when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
-          "gpt-3.5-turbo-16k"
+          "open_ai:gpt-3.5-turbo-16k"
         when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
-          "mistralai/Mixtral-8x7B-Instruct-v0.1"
+          if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
+               "mistralai/Mixtral-8x7B-Instruct-v0.1",
+             )
+            "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
+          else
+            "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
+          end
         when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
-          "gemini-pro"
+          "google:gemini-pro"
         when DiscourseAi::AiBot::EntryPoint::FAKE_ID
-          "fake"
+          "fake:fake"
         else
           nil
         end

-      if %w[gpt-4 gpt-4-turbo].include?(default_model) && prefer_low_cost
-        return "gpt-3.5-turbo-16k"
+      if %w[open_ai:gpt-4 open_ai:gpt-4-turbo].include?(default_model) && prefer_low_cost
+        return "open_ai:gpt-3.5-turbo-16k"
       end

       default_model
@@ -9,6 +9,10 @@ module DiscourseAi
       model_name == "fake"
     end

+    def translate
+      ""
+    end
+
     def tokenizer
       DiscourseAi::Tokenizer::OpenAiTokenizer
     end
@@ -4,8 +4,22 @@ module DiscourseAi
   module Completions
     module Endpoints
       class Anthropic < Base
-        def self.can_contact?(model_name)
-          %w[claude-instant-1 claude-2].include?(model_name)
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            endpoint_name == "anthropic" && %w[claude-instant-1 claude-2].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_anthropic_api_key]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_anthropic_api_key.present?
+          end
+
+          def endpoint_name(model_name)
+            "Anthropic - #{model_name}"
+          end
         end

         def normalize_model_params(model_params)
@@ -6,11 +6,24 @@ module DiscourseAi
   module Completions
     module Endpoints
       class AwsBedrock < Base
-        def self.can_contact?(model_name)
-          %w[claude-instant-1 claude-2].include?(model_name) &&
-            SiteSetting.ai_bedrock_access_key_id.present? &&
-            SiteSetting.ai_bedrock_secret_access_key.present? &&
-            SiteSetting.ai_bedrock_region.present?
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            endpoint_name == "aws_bedrock" && %w[claude-instant-1 claude-2].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_bedrock_access_key_id.present? &&
+              SiteSetting.ai_bedrock_secret_access_key.present? &&
+              SiteSetting.ai_bedrock_region.present?
+          end
+
+          def endpoint_name(model_name)
+            "AWS Bedrock - #{model_name}"
+          end
         end

         def normalize_model_params(model_params)
@@ -7,29 +7,54 @@ module DiscourseAi
       CompletionFailed = Class.new(StandardError)
       TIMEOUT = 60

-      def self.endpoint_for(model_name)
-        # Order is important.
-        # Bedrock has priority over Anthropic if credentials are present.
-        endpoints = [
-          DiscourseAi::Completions::Endpoints::AwsBedrock,
-          DiscourseAi::Completions::Endpoints::Anthropic,
-          DiscourseAi::Completions::Endpoints::OpenAi,
-          DiscourseAi::Completions::Endpoints::HuggingFace,
-          DiscourseAi::Completions::Endpoints::Gemini,
-          DiscourseAi::Completions::Endpoints::Vllm,
-        ]
+      class << self
+        def endpoint_for(provider_name, model_name)
+          endpoints = [
+            DiscourseAi::Completions::Endpoints::AwsBedrock,
+            DiscourseAi::Completions::Endpoints::Anthropic,
+            DiscourseAi::Completions::Endpoints::OpenAi,
+            DiscourseAi::Completions::Endpoints::HuggingFace,
+            DiscourseAi::Completions::Endpoints::Gemini,
+            DiscourseAi::Completions::Endpoints::Vllm,
+          ]

-        if Rails.env.test? || Rails.env.development?
-          endpoints << DiscourseAi::Completions::Endpoints::Fake
-        end
+          if Rails.env.test? || Rails.env.development?
+            endpoints << DiscourseAi::Completions::Endpoints::Fake
+          end

-        endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
-          ek.can_contact?(model_name)
-        end
-      end
+          endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
+            ek.can_contact?(provider_name, model_name)
+          end
+        end

-      def self.can_contact?(_model_name)
-        raise NotImplementedError
-      end
+        def configuration_hint
+          settings = dependant_setting_names
+          I18n.t(
+            "discourse_ai.llm.endpoints.configuration_hint",
+            settings: settings.join(", "),
+            count: settings.length,
+          )
+        end
+
+        def display_name(model_name)
+          to_display = endpoint_name(model_name)
+
+          return to_display if correctly_configured?(model_name)
+
+          I18n.t("discourse_ai.llm.endpoints.not_configured", display_name: to_display)
+        end
+
+        def dependant_setting_names
+          raise NotImplementedError
+        end
+
+        def endpoint_name(_model_name)
+          raise NotImplementedError
+        end
+
+        def can_contact?(_endpoint_name, _model_name)
+          raise NotImplementedError
+        end
+      end

       def initialize(model_name, tokenizer)
@@ -6,7 +6,7 @@ module DiscourseAi
     class CannedResponse
       CANNED_RESPONSE_ERROR = Class.new(StandardError)

-      def self.can_contact?(_)
+      def self.can_contact?(_, _)
         Rails.env.test?
       end
@@ -4,8 +4,18 @@ module DiscourseAi
   module Completions
     module Endpoints
       class Fake < Base
-        def self.can_contact?(model_name)
-          model_name == "fake"
+        class << self
+          def can_contact?(_endpoint_name, model_name)
+            model_name == "fake"
+          end
+
+          def correctly_configured?(_model_name)
+            true
+          end
+
+          def endpoint_name(_model_name)
+            "Test - fake model"
+          end
         end

         STOCK_CONTENT = <<~TEXT
@@ -4,8 +4,23 @@ module DiscourseAi
   module Completions
     module Endpoints
       class Gemini < Base
-        def self.can_contact?(model_name)
-          %w[gemini-pro].include?(model_name)
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            return false unless endpoint_name == "google"
+
+            %w[gemini-pro].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_gemini_api_key]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_gemini_api_key.present?
+          end
+
+          def endpoint_name(model_name)
+            "Google - #{model_name}"
+          end
         end

         def default_options
@@ -4,15 +4,31 @@ module DiscourseAi
   module Completions
     module Endpoints
       class HuggingFace < Base
-        def self.can_contact?(model_name)
-          %w[
-            StableBeluga2
-            Upstage-Llama-2-*-instruct-v2
-            Llama2-*-chat-hf
-            Llama2-chat-hf
-            mistralai/Mixtral-8x7B-Instruct-v0.1
-            mistralai/Mistral-7B-Instruct-v0.2
-          ].include?(model_name) && SiteSetting.ai_hugging_face_api_url.present?
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            return false unless endpoint_name == "hugging_face"
+
+            %w[
+              StableBeluga2
+              Upstage-Llama-2-*-instruct-v2
+              Llama2-*-chat-hf
+              Llama2-chat-hf
+              mistralai/Mixtral-8x7B-Instruct-v0.1
+              mistralai/Mistral-7B-Instruct-v0.2
+            ].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_hugging_face_api_url]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_hugging_face_api_url.present?
+          end
+
+          def endpoint_name(model_name)
+            "Hugging Face - #{model_name}"
+          end
         end

         def default_options
@@ -4,15 +4,62 @@ module DiscourseAi
   module Completions
     module Endpoints
       class OpenAi < Base
-        def self.can_contact?(model_name)
-          %w[
-            gpt-3.5-turbo
-            gpt-4
-            gpt-3.5-turbo-16k
-            gpt-4-32k
-            gpt-4-0125-preview
-            gpt-4-turbo
-          ].include?(model_name)
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            return false unless endpoint_name == "open_ai"
+
+            %w[
+              gpt-3.5-turbo
+              gpt-4
+              gpt-3.5-turbo-16k
+              gpt-4-32k
+              gpt-4-0125-preview
+              gpt-4-turbo
+            ].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[
+              ai_openai_api_key
+              ai_openai_gpt4_32k_url
+              ai_openai_gpt4_turbo_url
+              ai_openai_gpt4_url
+              ai_openai_gpt35_16k_url
+              ai_openai_gpt35_url
+            ]
+          end
+
+          def correctly_configured?(model_name)
+            SiteSetting.ai_openai_api_key.present? && has_url?(model_name)
+          end
+
+          def has_url?(model)
+            url =
+              if model.include?("gpt-4")
+                if model.include?("32k")
+                  SiteSetting.ai_openai_gpt4_32k_url
+                else
+                  if model.include?("1106") || model.include?("turbo")
+                    SiteSetting.ai_openai_gpt4_turbo_url
+                  else
+                    SiteSetting.ai_openai_gpt4_url
+                  end
+                end
+              else
+                if model.include?("16k")
+                  SiteSetting.ai_openai_gpt35_16k_url
+                else
+                  SiteSetting.ai_openai_gpt35_url
+                end
+              end

+            url.present?
+          end
+
+          def endpoint_name(model_name)
+            "OpenAI - #{model_name}"
+          end
         end

         def normalize_model_params(model_params)
@@ -4,10 +4,30 @@ module DiscourseAi
   module Completions
     module Endpoints
       class Vllm < Base
-        def self.can_contact?(model_name)
-          %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
-            model_name,
-          )
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            endpoint_name == "vllm" &&
+              %w[
+                mistralai/Mixtral-8x7B-Instruct-v0.1
+                mistralai/Mistral-7B-Instruct-v0.2
+                StableBeluga2
+                Upstage-Llama-2-*-instruct-v2
+                Llama2-*-chat-hf
+                Llama2-chat-hf
+              ].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_vllm_endpoint_srv ai_vllm_endpoint]
+          end
+
+          def correctly_configured?(_model_name)
+            SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
+          end
+
+          def endpoint_name(model_name)
+            "vLLM - #{model_name}"
+          end
         end

         def normalize_model_params(model_params)
@@ -7,7 +7,7 @@
 # the target model and routes the completion request through the correct gateway.
 #
 # Use the .proxy method to instantiate an object.
-# It chooses the best dialect and endpoint for the model you want to interact with.
+# It chooses the correct dialect and endpoint for the model you want to interact with.
 #
 # Tests of modules that perform LLM calls can use .with_prepared_responses to return canned responses
 # instead of relying on WebMock stubs like we did in the past.
@@ -17,27 +17,62 @@ module DiscourseAi
     class Llm
       UNKNOWN_MODEL = Class.new(StandardError)

-      def self.with_prepared_responses(responses)
-        @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)
+      class << self
+        def models_by_provider
+          # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
+          # However, since they use the same URL/key settings, there's no reason to duplicate them.
+          {
+            aws_bedrock: %w[claude-instant-1 claude-2],
+            anthropic: %w[claude-instant-1 claude-2],
+            vllm: %w[
+              mistralai/Mixtral-8x7B-Instruct-v0.1
+              mistralai/Mistral-7B-Instruct-v0.2
+              StableBeluga2
+              Upstage-Llama-2-*-instruct-v2
+              Llama2-*-chat-hf
+              Llama2-chat-hf
+            ],
+            hugging_face: %w[
+              mistralai/Mixtral-8x7B-Instruct-v0.1
+              mistralai/Mistral-7B-Instruct-v0.2
+              StableBeluga2
+              Upstage-Llama-2-*-instruct-v2
+              Llama2-*-chat-hf
+              Llama2-chat-hf
+            ],
+            open_ai: %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k gpt-4-turbo],
+            google: %w[gemini-pro],
+          }.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
+        end

-        yield(@canned_response)
-      ensure
-        # Don't leak prepared response if there's an exception.
-        @canned_response = nil
-      end
+        def with_prepared_responses(responses)
+          @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)

-      def self.proxy(model_name)
-        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
+          yield(@canned_response)
+        ensure
+          # Don't leak prepared response if there's an exception.
+          @canned_response = nil
+        end

-        return new(dialect_klass, @canned_response, model_name) if @canned_response
+        def proxy(model_name)
+          provider_and_model_name = model_name.split(":")

-        gateway =
-          DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
-            model_name,
-            dialect_klass.tokenizer,
-          )
+          provider_name = provider_and_model_name.first
+          model_name_without_prov = provider_and_model_name[1..].join

-        new(dialect_klass, gateway, model_name)
+          dialect_klass =
+            DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
+
+          return new(dialect_klass, @canned_response, model_name) if @canned_response
+
+          gateway =
+            DiscourseAi::Completions::Endpoints::Base.endpoint_for(
+              provider_name,
+              model_name_without_prov,
+            ).new(model_name_without_prov, dialect_klass.tokenizer)
+
+          new(dialect_klass, gateway, model_name_without_prov)
+        end
+      end

       def initialize(dialect_klass, gateway, model_name)
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Configuration
+    class LlmDependencyValidator
+      def initialize(opts = {})
+        @opts = opts
+      end
+
+      def valid_value?(val)
+        return true if val == "f"
+
+        SiteSetting.public_send(llm_dependency_setting_name).present?
+      end
+
+      def error_message
+        I18n.t("discourse_ai.llm.configuration.set_llm_first", setting: llm_dependency_setting_name)
+      end
+
+      def llm_dependency_setting_name
+        if @opts[:name] == :ai_embeddings_semantic_search_enabled
+          :ai_embeddings_semantic_search_hyde_model
+        else
+          :ai_helper_model
+        end
+      end
+    end
+  end
+end
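A sketch of the resulting behavior (illustrative; boolean site settings reach validators as "t"/"f" strings):

    validator = DiscourseAi::Configuration::LlmDependencyValidator.new(name: :composer_ai_helper_enabled)
    validator.valid_value?("f") # => true (disabling is always allowed)
    validator.valid_value?("t") # => true only once ai_helper_model is set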
@@ -0,0 +1,25 @@
+# frozen_string_literal: true
+
+require "enum_site_setting"
+
+module DiscourseAi
+  module Configuration
+    class LlmEnumerator < ::EnumSiteSetting
+      def self.valid_value?(val)
+        true
+      end
+
+      def self.values
+        @values ||=
+          DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
+            endpoint =
+              DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
+
+            models.map do |model_name|
+              { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
+            end
+          end
+      end
+    end
+  end
+end
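The enumerator backs the setting's dropdown; each entry pairs a human-readable name with the provider-qualified value. Roughly (illustrative output):

    DiscourseAi::Configuration::LlmEnumerator.values
    # => [
    #      { name: "AWS Bedrock - claude-instant-1 (not configured)", value: "aws_bedrock:claude-instant-1" },
    #      { name: "OpenAI - gpt-4", value: "open_ai:gpt-4" },
    #      ...
    #    ]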
@@ -0,0 +1,77 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Configuration
+    class LlmValidator
+      def initialize(opts = {})
+        @opts = opts
+      end
+
+      def valid_value?(val)
+        if val == ""
+          @parent_enabled = SiteSetting.public_send(parent_module_name)
+          return !@parent_enabled
+        end
+
+        provider_and_model_name = val.split(":")
+
+        provider_name = provider_and_model_name.first
+        model_name_without_prov = provider_and_model_name[1..].join
+
+        endpoint =
+          DiscourseAi::Completions::Endpoints::Base.endpoint_for(
+            provider_name,
+            model_name_without_prov,
+          )
+
+        return false if endpoint.nil?
+
+        if !endpoint.correctly_configured?(model_name_without_prov)
+          @endpoint = endpoint
+          return false
+        end
+
+        if !can_talk_to_model?(val)
+          @unreachable = true
+          return false
+        end
+
+        true
+      end
+
+      def error_message
+        if @parent_enabled
+          return(
+            I18n.t(
+              "discourse_ai.llm.configuration.disable_module_first",
+              setting: parent_module_name,
+            )
+          )
+        end
+
+        return(I18n.t("discourse_ai.llm.configuration.model_unreachable")) if @unreachable
+
+        @endpoint&.configuration_hint
+      end
+
+      def parent_module_name
+        if @opts[:name] == :ai_embeddings_semantic_search_hyde_model
+          :ai_embeddings_semantic_search_enabled
+        else
+          :composer_ai_helper_enabled
+        end
+      end
+
+      private
+
+      def can_talk_to_model?(model_name)
+        DiscourseAi::Completions::Llm
+          .proxy(model_name)
+          .generate("How much is 1 + 1?", user: nil)
+          .present?
+      rescue StandardError
+        false
+      end
+    end
+  end
+end
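Putting it together, a sketch of how this validator reacts when an admin saves the setting (illustrative):

    validator = DiscourseAi::Configuration::LlmValidator.new(name: :ai_helper_model)
    validator.valid_value?("")              # => false while composer_ai_helper_enabled is on
    validator.valid_value?("open_ai:gpt-4") # => false if the OpenAI settings are missing, or if the
                                            #    test completion ("How much is 1 + 1?") gets no response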
@@ -5,22 +5,45 @@ module DiscourseAi
     class EntryPoint
       def inject_into(plugin)
         foldable_models = [
-          Models::OpenAi.new("gpt-4", max_tokens: 8192),
-          Models::OpenAi.new("gpt-4-32k", max_tokens: 32_768),
-          Models::OpenAi.new("gpt-4-0125-preview", max_tokens: 100_000),
-          Models::OpenAi.new("gpt-3.5-turbo", max_tokens: 4096),
-          Models::OpenAi.new("gpt-3.5-turbo-16k", max_tokens: 16_384),
-          Models::Anthropic.new("claude-2", max_tokens: 200_000),
-          Models::Anthropic.new("claude-instant-1", max_tokens: 100_000),
-          Models::Llama2.new("Llama2-chat-hf", max_tokens: SiteSetting.ai_hugging_face_token_limit),
-          Models::Llama2FineTunedOrcaStyle.new(
-            "StableBeluga2",
+          Models::OpenAi.new("open_ai:gpt-4", max_tokens: 8192),
+          Models::OpenAi.new("open_ai:gpt-4-32k", max_tokens: 32_768),
+          Models::OpenAi.new("open_ai:gpt-4-0125-preview", max_tokens: 100_000),
+          Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
+          Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
+          Models::Llama2.new(
+            "hugging_face:Llama2-chat-hf",
             max_tokens: SiteSetting.ai_hugging_face_token_limit,
           ),
-          Models::Gemini.new("gemini-pro", max_tokens: 32_768),
-          Models::Mixtral.new("mistralai/Mixtral-8x7B-Instruct-v0.1", max_tokens: 32_000),
+          Models::Llama2FineTunedOrcaStyle.new(
+            "hugging_face:StableBeluga2",
+            max_tokens: SiteSetting.ai_hugging_face_token_limit,
+          ),
+          Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
         ]

+        claude_prov = "anthropic"
+        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
+          claude_prov = "aws_bedrock"
+        end
+
+        foldable_models << Models::Anthropic.new("#{claude_prov}:claude-2", max_tokens: 200_000)
+        foldable_models << Models::Anthropic.new(
+          "#{claude_prov}:claude-instant-1",
+          max_tokens: 100_000,
+        )
+
+        mixtral_prov = "hugging_face"
+        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
+             "mistralai/Mixtral-8x7B-Instruct-v0.1",
+           )
+          mixtral_prov = "vllm"
+        end
+
+        foldable_models << Models::Mixtral.new(
+          "#{mixtral_prov}:mistralai/Mixtral-8x7B-Instruct-v0.1",
+          max_tokens: 32_000,
+        )
+
         foldable_models.each do |model|
           plugin.register_summarization_strategy(Strategies::FoldContent.new(model))
         end
@@ -4,8 +4,8 @@ module DiscourseAi
   module Summarization
     module Models
       class Base
-        def initialize(model, max_tokens:)
-          @model = model
+        def initialize(model_name, max_tokens:)
+          @model_name = model_name
           @max_tokens = max_tokens
         end
@@ -25,7 +25,11 @@ module DiscourseAi
           max_tokens - reserved_tokens
         end

-        attr_reader :model, :max_tokens
+        def model
+          model_name.split(":").last
+        end
+
+        attr_reader :model_name, :max_tokens

         protected
@@ -19,7 +19,7 @@ module DiscourseAi
         def summarize(content, user, &on_partial_blk)
           opts = content.except(:contents)

-          llm = DiscourseAi::Completions::Llm.proxy(completion_model.model)
+          llm = DiscourseAi::Completions::Llm.proxy(completion_model.model_name)

           initial_chunks =
             rebalance_chunks(
@@ -3,6 +3,8 @@
 RSpec.describe Jobs::StreamPostHelper do
   subject(:job) { described_class.new }

+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
   describe "#execute" do
     fab!(:topic) { Fabricate(:topic) }
     fab!(:post) do
@@ -5,7 +5,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
     described_class.new(
       DiscourseAi::Completions::Dialects::OrcaStyle,
       canned_response,
-      "Upstage-Llama-2-*-instruct-v2",
+      "hugging_face:Upstage-Llama-2-*-instruct-v2",
     )
   end
@@ -13,7 +13,7 @@ RSpec.describe DiscourseAi::Completions::Llm do

   describe ".proxy" do
     it "raises an exception when we can't proxy the model" do
-      fake_model = "unknown_v2"
+      fake_model = "unknown:unknown_v2"

       expect { described_class.proxy(fake_model) }.to(
         raise_error(DiscourseAi::Completions::Llm::UNKNOWN_MODEL),
@@ -27,7 +27,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
       DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
     end

-    let(:llm) { described_class.proxy("fake") }
+    let(:llm) { described_class.proxy("fake:fake") }

     let(:prompt) do
       DiscourseAi::Completions::Prompt.new(
@@ -6,7 +6,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
   let(:prompts) { ["a pink cow", "a red cow"] }

   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }

   before { SiteSetting.ai_bot_enabled = true }
@@ -2,7 +2,7 @@

 RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

   before { SiteSetting.ai_bot_enabled = true }
   describe "#process" do
@@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Google do
   subject(:search) { described_class.new({ query: "some search term" }) }

   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }

   before { SiteSetting.ai_bot_enabled = true }
@@ -3,7 +3,7 @@
 RSpec.describe DiscourseAi::AiBot::Tools::Image do
   subject(:tool) { described_class.new({ prompts: prompts, seeds: [99, 32] }) }

-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }

   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
@@ -2,7 +2,7 @@

 RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

   before { SiteSetting.ai_bot_enabled = true }
@@ -2,7 +2,7 @@

 RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

   before do
     SiteSetting.ai_bot_enabled = true
@@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Read do
   subject(:tool) { described_class.new({ topic_id: topic_with_tags.id }) }

   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

   fab!(:parent_category) { Fabricate(:category, name: "animals") }
   fab!(:category) { Fabricate(:category, parent_category: parent_category, name: "amazing-cat") }
@@ -2,7 +2,7 @@

 RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

   before { SiteSetting.ai_bot_enabled = true }
@@ -4,8 +4,10 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
   before { SearchIndexer.enable }
   after { SearchIndexer.disable }

+  before { SiteSetting.ai_openai_api_key = "asd" }
+
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }

   fab!(:admin)
@@ -65,6 +67,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
     after { DiscourseAi::Embeddings::SemanticSearch.clear_cache_for(query) }

     it "supports semantic search when enabled" do
+      SiteSetting.ai_embeddings_semantic_search_hyde_model = "fake:fake"
       SiteSetting.ai_embeddings_semantic_search_enabled = true
       SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
@@ -10,7 +10,7 @@ end

 RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

   before { SiteSetting.ai_bot_enabled = true }
@@ -2,7 +2,7 @@

 RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
   let(:progress_blk) { Proc.new {} }

   before { SiteSetting.ai_bot_enabled = true }
@@ -2,7 +2,7 @@

 RSpec.describe DiscourseAi::AiBot::Tools::Time do
   let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("gpt-3.5-turbo") }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }

   before { SiteSetting.ai_bot_enabled = true }
@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
   fab!(:user) { Fabricate(:user) }
   let(:prompt) { CompletionPrompt.find_by(id: mode) }

+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
   let(:english_text) { <<~STRING }
     To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
     discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
@@ -3,6 +3,8 @@
 RSpec.describe DiscourseAi::AiHelper::ChatThreadTitler do
   subject(:titler) { described_class.new(thread) }

+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
   fab!(:thread) { Fabricate(:chat_thread) }
   fab!(:user) { Fabricate(:user) }
@@ -6,6 +6,7 @@ RSpec.describe DiscourseAi::AiHelper::Painter do
   fab!(:user) { Fabricate(:user) }

   before do
+    SiteSetting.ai_helper_model = "fake:fake"
     SiteSetting.ai_stability_api_url = "https://api.stability.dev"
     SiteSetting.ai_stability_api_key = "abc"
     SiteSetting.ai_openai_api_key = "abc"
@@ -10,7 +10,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "fake:fake",
         hide_topic: true,
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
@@ -20,25 +20,11 @@ describe DiscourseAi::Automation::LlmTriage do
     expect(post.topic.reload.visible).to eq(true)
   end

-  it "can hide topics on triage with claude" do
+  it "can hide topics on triage" do
     DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
       triage(
         post: post,
-        model: "claude-2",
-        hide_topic: true,
-        system_prompt: "test %%POST%%",
-        search_for_text: "bad",
-      )
-    end
-
-    expect(post.topic.reload.visible).to eq(false)
-  end
-
-  it "can hide topics on triage with claude" do
-    DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
-      triage(
-        post: post,
-        model: "gpt-4",
+        model: "fake:fake",
         hide_topic: true,
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
@@ -54,7 +40,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "fake:fake",
         category_id: category.id,
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
@@ -69,7 +55,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "fake:fake",
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
         canned_reply: "test canned reply 123",
@@ -22,7 +22,7 @@ module DiscourseAi
           sender_username: user.username,
           receivers: ["fake@discourse.com"],
           title: "test report %DATE%",
-          model: "gpt-4",
+          model: "fake:fake",
           category_ids: nil,
           tags: nil,
           allow_secure_categories: false,
@@ -48,7 +48,7 @@ module DiscourseAi
           sender_username: user.username,
           receivers: [receiver.username],
           title: "test report",
-          model: "gpt-4",
+          model: "fake:fake",
           category_ids: nil,
           tags: nil,
           allow_secure_categories: false,
@@ -7,6 +7,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
   let(:query) { "test_query" }
   let(:subject) { described_class.new(Guardian.new(user)) }

+  before { SiteSetting.ai_embeddings_semantic_search_hyde_model = "fake:fake" }
+
   describe "#search_for_topics" do
     let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
@@ -12,7 +12,7 @@ RSpec.describe DiscourseAi::Summarization::Strategies::FoldContent do
   end

   let(:model) do
-    DiscourseAi::Summarization::Models::OpenAi.new("gpt-4", max_tokens: model_tokens)
+    DiscourseAi::Summarization::Models::OpenAi.new("fake:fake", max_tokens: model_tokens)
   end

   let(:content) { { contents: [{ poster: "asd", id: 1, text: summarize_text }] } }
@@ -1,6 +1,8 @@
 # frozen_string_literal: true

 RSpec.describe DiscourseAi::AiHelper::AssistantController do
+  before { SiteSetting.ai_helper_model = "fake:fake" }
+
   describe "#suggest" do
     let(:text_to_proofread) { "The rain in spain stays mainly in the plane." }
     let(:proofread_text) { "The rain in Spain, stays mainly in the Plane." }
@@ -90,7 +92,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
         A user wrote this</input>
       TEXT

-      DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do |spy|
+      DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do
         post "/discourse-ai/ai-helper/suggest",
              params: {
               mode: CompletionPrompt::CUSTOM_PROMPT,
@@ -101,7 +103,6 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
         expect(response.status).to eq(200)
         expect(response.parsed_body["suggestions"].first).to eq(translated_text)
         expect(response.parsed_body["diff"]).to eq(expected_diff)
-        expect(spy.prompt.translate.last[:content]).to eq(expected_input)
       end
     end
   end
@ -6,6 +6,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
|
||||||
|
|
||||||
before do
|
before do
|
||||||
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
|
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
|
||||||
|
SiteSetting.ai_helper_model = "fake:fake"
|
||||||
SiteSetting.composer_ai_helper_enabled = true
|
SiteSetting.composer_ai_helper_enabled = true
|
||||||
sign_in(user)
|
sign_in(user)
|
||||||
end
|
end
|
||||||
|
|
|
@ -28,6 +28,7 @@ RSpec.describe "AI Post helper", type: :system, js: true do
|
||||||
|
|
||||||
before do
|
before do
|
||||||
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
|
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
|
||||||
|
SiteSetting.ai_helper_model = "fake:fake"
|
||||||
SiteSetting.composer_ai_helper_enabled = true
|
SiteSetting.composer_ai_helper_enabled = true
|
||||||
sign_in(user)
|
sign_in(user)
|
||||||
end
|
end
|
||||||
|
|
|
@ -38,6 +38,7 @@ RSpec.describe "AI Post helper", type: :system, js: true do
|
||||||
|
|
||||||
before do
|
before do
|
||||||
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
|
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
|
||||||
|
SiteSetting.ai_helper_model = "fake:fake"
|
||||||
SiteSetting.composer_ai_helper_enabled = true
|
SiteSetting.composer_ai_helper_enabled = true
|
||||||
sign_in(user)
|
sign_in(user)
|
||||||
end
|
end
|
||||||
|
|