DEV: Remove old code now that features rely on LlmModels. (#729)
* DEV: Remove old code now that features rely on LlmModels. * Hide old settings and migrate persona llm overrides * Remove shadowing special URL + seeding code. Use srv:// prefix instead.
This commit is contained in:
parent
73a2b15e91
commit
bed044448c
|
@ -110,9 +110,7 @@ module DiscourseAi
|
|||
)
|
||||
|
||||
provider = updating ? updating.provider : permitted[:provider]
|
||||
permit_url =
|
||||
(updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) ||
|
||||
provider != LlmModel::BEDROCK_PROVIDER_NAME
|
||||
permit_url = provider != LlmModel::BEDROCK_PROVIDER_NAME
|
||||
|
||||
permitted[:url] = params.dig(:ai_llm, :url) if permit_url
|
||||
|
||||
|
|
|
@ -2,44 +2,10 @@
|
|||
|
||||
class LlmModel < ActiveRecord::Base
|
||||
FIRST_BOT_USER_ID = -1200
|
||||
RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid"
|
||||
BEDROCK_PROVIDER_NAME = "aws_bedrock"
|
||||
|
||||
belongs_to :user
|
||||
|
||||
validates :url, exclusion: { in: [RESERVED_VLLM_SRV_URL] }, if: :url_changed?
|
||||
|
||||
def self.seed_srv_backed_model
|
||||
srv = SiteSetting.ai_vllm_endpoint_srv
|
||||
srv_model = find_by(url: RESERVED_VLLM_SRV_URL)
|
||||
|
||||
if srv.present?
|
||||
if srv_model.present?
|
||||
current_key = SiteSetting.ai_vllm_api_key
|
||||
srv_model.update!(api_key: current_key) if current_key != srv_model.api_key
|
||||
else
|
||||
record =
|
||||
new(
|
||||
display_name: "vLLM SRV LLM",
|
||||
name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
provider: "vllm",
|
||||
tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer",
|
||||
url: RESERVED_VLLM_SRV_URL,
|
||||
max_prompt_tokens: 8000,
|
||||
api_key: SiteSetting.ai_vllm_api_key,
|
||||
)
|
||||
|
||||
record.save(validate: false) # Ignore reserved URL validation
|
||||
end
|
||||
else
|
||||
# Clean up companion users
|
||||
srv_model&.enabled_chat_bot = false
|
||||
srv_model&.toggle_companion_user
|
||||
|
||||
srv_model&.destroy!
|
||||
end
|
||||
end
|
||||
|
||||
def self.provider_params
|
||||
{
|
||||
aws_bedrock: {
|
||||
|
@ -54,7 +20,7 @@ class LlmModel < ActiveRecord::Base
|
|||
end
|
||||
|
||||
def to_llm
|
||||
DiscourseAi::Completions::Llm.proxy_from_obj(self)
|
||||
DiscourseAi::Completions::Llm.proxy("custom:#{id}")
|
||||
end
|
||||
|
||||
def toggle_companion_user
|
||||
|
|
|
@ -19,6 +19,6 @@ class LlmModelSerializer < ApplicationSerializer
|
|||
has_one :user, serializer: BasicUserSerializer, embed: :object
|
||||
|
||||
def shadowed_by_srv
|
||||
object.url == LlmModel::RESERVED_VLLM_SRV_URL
|
||||
object.url.to_s.starts_with?("srv://")
|
||||
end
|
||||
end
|
||||
|
|
|
@ -60,17 +60,9 @@ export default class AiLlmEditorForm extends Component {
|
|||
return this.testRunning || this.testResult !== null;
|
||||
}
|
||||
|
||||
get displaySRVWarning() {
|
||||
return this.args.model.shadowed_by_srv && !this.args.model.isNew;
|
||||
}
|
||||
|
||||
get canEditURL() {
|
||||
// Explicitly false.
|
||||
if (this.metaProviderParams.url_editable === false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !this.args.model.shadowed_by_srv || this.args.model.isNew;
|
||||
return this.metaProviderParams.url_editable !== false;
|
||||
}
|
||||
|
||||
@computed("args.model.provider")
|
||||
|
@ -174,12 +166,6 @@ export default class AiLlmEditorForm extends Component {
|
|||
}
|
||||
|
||||
<template>
|
||||
{{#if this.displaySRVWarning}}
|
||||
<div class="alert alert-info">
|
||||
{{icon "exclamation-circle"}}
|
||||
{{I18n.t "discourse_ai.llms.srv_warning"}}
|
||||
</div>
|
||||
{{/if}}
|
||||
<form class="form-horizontal ai-llm-editor">
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.llms.display_name"}}</label>
|
||||
|
|
|
@ -237,7 +237,6 @@ en:
|
|||
confirm_delete: Are you sure you want to delete this model?
|
||||
delete: Delete
|
||||
|
||||
srv_warning: This LLM points to an SRV record, and its URL is not editable. You have to update the hidden "ai_vllm_endpoint_srv" setting instead.
|
||||
preconfigured_llms: "Select your LLM"
|
||||
preconfigured:
|
||||
none: "Configure manually..."
|
||||
|
|
|
@ -96,21 +96,35 @@ discourse_ai:
|
|||
- opennsfw2
|
||||
- nsfw_detector
|
||||
|
||||
ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4o_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt4_turbo_url: "https://api.openai.com/v1/chat/completions"
|
||||
ai_openai_gpt35_url:
|
||||
default: "https://api.openai.com/v1/chat/completions"
|
||||
hidden: true
|
||||
ai_openai_gpt35_16k_url:
|
||||
default: "https://api.openai.com/v1/chat/completions"
|
||||
hidden: true
|
||||
ai_openai_gpt4o_url:
|
||||
default: "https://api.openai.com/v1/chat/completions"
|
||||
hidden: true
|
||||
ai_openai_gpt4_url:
|
||||
default: "https://api.openai.com/v1/chat/completions"
|
||||
hidden: true
|
||||
ai_openai_gpt4_32k_url:
|
||||
default: "https://api.openai.com/v1/chat/completions"
|
||||
hidden: true
|
||||
ai_openai_gpt4_turbo_url:
|
||||
default: "https://api.openai.com/v1/chat/completions"
|
||||
hidden: true
|
||||
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
|
||||
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
|
||||
ai_openai_organization: ""
|
||||
ai_openai_organization:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_openai_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_anthropic_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_anthropic_native_tool_call_models:
|
||||
type: list
|
||||
list_type: compact
|
||||
|
@ -123,7 +137,7 @@ discourse_ai:
|
|||
- claude-3-5-sonnet
|
||||
ai_cohere_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_stability_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
|
@ -140,13 +154,16 @@ discourse_ai:
|
|||
- "stable-diffusion-v1-5"
|
||||
ai_hugging_face_api_url:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_hugging_face_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_hugging_face_token_limit:
|
||||
default: 4096
|
||||
hidden: true
|
||||
ai_hugging_face_model_display_name:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_hugging_face_tei_endpoint:
|
||||
default: ""
|
||||
ai_hugging_face_tei_endpoint_srv:
|
||||
|
@ -167,11 +184,13 @@ discourse_ai:
|
|||
ai_bedrock_access_key_id:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_bedrock_secret_access_key:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_bedrock_region:
|
||||
default: "us-east-1"
|
||||
hidden: true
|
||||
ai_cloudflare_workers_account_id:
|
||||
default: ""
|
||||
secret: true
|
||||
|
@ -180,13 +199,16 @@ discourse_ai:
|
|||
secret: true
|
||||
ai_gemini_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
hidden: true
|
||||
ai_vllm_endpoint:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_vllm_endpoint_srv:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_vllm_api_key: ""
|
||||
ai_vllm_api_key:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_llava_endpoint:
|
||||
default: ""
|
||||
hidden: true
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
begin
|
||||
LlmModel.seed_srv_backed_model
|
||||
rescue PG::UndefinedColumn => e
|
||||
# If this code runs before migrations, an attribute might be missing.
|
||||
Rails.logger.warn("Failed to seed SRV-Backed LLM: #{e.meesage}")
|
||||
end
|
|
@ -25,8 +25,9 @@ class MigrateVisionLlms < ActiveRecord::Migration[7.1]
|
|||
).first
|
||||
|
||||
if current_value && current_value != "llava"
|
||||
model_name = current_value.split(":").last
|
||||
llm_model =
|
||||
DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: current_value).first
|
||||
DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: model_name).first
|
||||
|
||||
if llm_model
|
||||
DB.exec(<<~SQL, new: "custom:#{llm_model}") if llm_model
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# frozen_string_literal: true
|
||||
class MigratePersonaLlmOverride < ActiveRecord::Migration[7.1]
|
||||
def up
|
||||
fields_to_update = DB.query(<<~SQL)
|
||||
SELECT id, default_llm
|
||||
FROM ai_personas
|
||||
WHERE default_llm IS NOT NULL
|
||||
SQL
|
||||
|
||||
return if fields_to_update.empty?
|
||||
|
||||
updated_fields =
|
||||
fields_to_update
|
||||
.map do |field|
|
||||
llm_model_id = matching_llm_model(field.default_llm)
|
||||
|
||||
"(#{field.id}, 'custom:#{llm_model_id}')" if llm_model_id
|
||||
end
|
||||
.compact
|
||||
|
||||
return if updated_fields.empty?
|
||||
|
||||
DB.exec(<<~SQL)
|
||||
UPDATE ai_personas
|
||||
SET default_llm = new_fields.new_default_llm
|
||||
FROM (VALUES #{updated_fields.join(", ")}) AS new_fields(id, new_default_llm)
|
||||
WHERE new_fields.id::bigint = ai_personas.id
|
||||
SQL
|
||||
end
|
||||
|
||||
def matching_llm_model(model)
|
||||
provider = model.split(":").first
|
||||
model_name = model.split(":").last
|
||||
|
||||
return if provider == "custom"
|
||||
|
||||
DB.query_single(
|
||||
"SELECT id FROM llm_models WHERE name = :name AND provider = :provider",
|
||||
{ name: model_name, provider: provider },
|
||||
).first
|
||||
end
|
||||
|
||||
def down
|
||||
raise ActiveRecord::IrreversibleMigration
|
||||
end
|
||||
end
|
|
@ -5,19 +5,15 @@ module DiscourseAi
|
|||
module Dialects
|
||||
class ChatGpt < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
model_name.starts_with?("gpt-")
|
||||
def can_translate?(model_provider)
|
||||
model_provider == "open_ai" || model_provider == "azure"
|
||||
end
|
||||
end
|
||||
|
||||
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
|
||||
|
||||
def tokenizer
|
||||
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
true
|
||||
llm_model.provider == "open_ai" || llm_model.provider == "azure"
|
||||
end
|
||||
|
||||
def translate
|
||||
|
@ -30,19 +26,17 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
|
||||
|
||||
# provide a buffer of 120 tokens - our function counting is not
|
||||
# 100% accurate and getting numbers to align exactly is very hard
|
||||
buffer = (opts[:max_tokens] || 2500) + 50
|
||||
|
||||
if tools.present?
|
||||
# note this is about 100 tokens over, OpenAI have a more optimal representation
|
||||
@function_size ||= self.tokenizer.size(tools.to_json.to_s)
|
||||
@function_size ||= llm_model.tokenizer_class.size(tools.to_json.to_s)
|
||||
buffer += @function_size
|
||||
end
|
||||
|
||||
model_max_tokens - buffer
|
||||
llm_model.max_prompt_tokens - buffer
|
||||
end
|
||||
|
||||
private
|
||||
|
@ -105,24 +99,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def calculate_message_token(context)
|
||||
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
|
||||
def model_max_tokens
|
||||
case model_name
|
||||
when "gpt-3.5-turbo-16k"
|
||||
16_384
|
||||
when "gpt-4"
|
||||
8192
|
||||
when "gpt-4-32k"
|
||||
32_768
|
||||
when "gpt-4-turbo"
|
||||
131_072
|
||||
when "gpt-4o"
|
||||
131_072
|
||||
else
|
||||
8192
|
||||
end
|
||||
llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -5,8 +5,8 @@ module DiscourseAi
|
|||
module Dialects
|
||||
class Claude < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
model_name.start_with?("claude") || model_name.start_with?("anthropic")
|
||||
def can_translate?(provider_name)
|
||||
provider_name == "anthropic" || provider_name == "aws_bedrock"
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -26,10 +26,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
|
||||
end
|
||||
|
||||
def translate
|
||||
messages = super
|
||||
|
||||
|
@ -61,14 +57,11 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
|
||||
|
||||
# Longer term it will have over 1 million
|
||||
200_000 # Claude-3 has a 200k context window for now
|
||||
llm_model.max_prompt_tokens
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name)
|
||||
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(llm_model.name)
|
||||
end
|
||||
|
||||
private
|
||||
|
|
|
@ -6,18 +6,12 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Dialects
|
||||
class Command < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[command-light command command-r command-r-plus].include?(model_name)
|
||||
end
|
||||
def self.can_translate?(model_provider)
|
||||
model_provider == "cohere"
|
||||
end
|
||||
|
||||
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
|
||||
|
||||
def tokenizer
|
||||
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def translate
|
||||
messages = super
|
||||
|
||||
|
@ -68,20 +62,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
|
||||
|
||||
case model_name
|
||||
when "command-light"
|
||||
4096
|
||||
when "command"
|
||||
8192
|
||||
when "command-r"
|
||||
131_072
|
||||
when "command-r-plus"
|
||||
131_072
|
||||
else
|
||||
8192
|
||||
end
|
||||
llm_model.max_prompt_tokens
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
|
@ -99,7 +80,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def calculate_message_token(context)
|
||||
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||
llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
|
||||
def system_msg(msg)
|
||||
|
|
|
@ -5,7 +5,7 @@ module DiscourseAi
|
|||
module Dialects
|
||||
class Dialect
|
||||
class << self
|
||||
def can_translate?(_model_name)
|
||||
def can_translate?(model_provider)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
|
@ -19,7 +19,7 @@ module DiscourseAi
|
|||
]
|
||||
end
|
||||
|
||||
def dialect_for(model_name)
|
||||
def dialect_for(model_provider)
|
||||
dialects = []
|
||||
|
||||
if Rails.env.test? || Rails.env.development?
|
||||
|
@ -28,26 +28,21 @@ module DiscourseAi
|
|||
|
||||
dialects = dialects.concat(all_dialects)
|
||||
|
||||
dialect = dialects.find { |d| d.can_translate?(model_name) }
|
||||
dialect = dialects.find { |d| d.can_translate?(model_provider) }
|
||||
raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect
|
||||
|
||||
dialect
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(generic_prompt, model_name, opts: {}, llm_model: nil)
|
||||
def initialize(generic_prompt, llm_model, opts: {})
|
||||
@prompt = generic_prompt
|
||||
@model_name = model_name
|
||||
@opts = opts
|
||||
@llm_model = llm_model
|
||||
end
|
||||
|
||||
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
|
||||
|
||||
def tokenizer
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def can_end_with_assistant_msg?
|
||||
false
|
||||
end
|
||||
|
@ -57,7 +52,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vision_support?
|
||||
llm_model&.vision_enabled?
|
||||
llm_model.vision_enabled?
|
||||
end
|
||||
|
||||
def tools
|
||||
|
@ -88,12 +83,12 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
attr_reader :model_name, :opts, :llm_model
|
||||
attr_reader :opts, :llm_model
|
||||
|
||||
def trim_messages(messages)
|
||||
prompt_limit = max_prompt_tokens
|
||||
current_token_count = 0
|
||||
message_step_size = (max_prompt_tokens / 25).to_i * -1
|
||||
message_step_size = (prompt_limit / 25).to_i * -1
|
||||
|
||||
trimmed_messages = []
|
||||
|
||||
|
@ -157,7 +152,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def calculate_message_token(msg)
|
||||
self.tokenizer.size(msg[:content].to_s)
|
||||
llm_model.tokenizer_class.size(msg[:content].to_s)
|
||||
end
|
||||
|
||||
def tools_dialect
|
||||
|
|
|
@ -5,8 +5,8 @@ module DiscourseAi
|
|||
module Dialects
|
||||
class Gemini < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name)
|
||||
def can_translate?(model_provider)
|
||||
model_provider == "google"
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -14,10 +14,6 @@ module DiscourseAi
|
|||
true
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
|
||||
end
|
||||
|
||||
def translate
|
||||
# Gemini complains if we don't alternate model/user roles.
|
||||
noop_model_response = { role: "model", parts: { text: "Ok." } }
|
||||
|
@ -74,24 +70,17 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
|
||||
|
||||
if model_name.start_with?("gemini-1.5")
|
||||
# technically we support 1 million tokens, but we're being conservative
|
||||
800_000
|
||||
else
|
||||
16_384 # 50% of model tokens
|
||||
end
|
||||
llm_model.max_prompt_tokens
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def calculate_message_token(context)
|
||||
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||
llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
|
||||
def beta_api?
|
||||
@beta_api ||= model_name.start_with?("gemini-1.5")
|
||||
@beta_api ||= llm_model.name.start_with?("gemini-1.5")
|
||||
end
|
||||
|
||||
def system_msg(msg)
|
||||
|
|
|
@ -4,22 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class Anthropic < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "anthropic"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_anthropic_api_key]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
SiteSetting.ai_anthropic_api_key.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"Anthropic - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "anthropic"
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
|
@ -29,7 +15,7 @@ module DiscourseAi
|
|||
|
||||
def default_options(dialect)
|
||||
mapped_model =
|
||||
case model
|
||||
case llm_model.name
|
||||
when "claude-2"
|
||||
"claude-2.1"
|
||||
when "claude-instant-1"
|
||||
|
@ -43,7 +29,7 @@ module DiscourseAi
|
|||
when "claude-3-5-sonnet"
|
||||
"claude-3-5-sonnet-20240620"
|
||||
else
|
||||
model
|
||||
llm_model.name
|
||||
end
|
||||
|
||||
options = { model: mapped_model, max_tokens: 3_000 }
|
||||
|
@ -74,9 +60,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def model_uri
|
||||
url = llm_model&.url || "https://api.anthropic.com/v1/messages"
|
||||
|
||||
URI(url)
|
||||
URI(llm_model.url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
|
@ -94,7 +78,7 @@ module DiscourseAi
|
|||
def prepare_request(payload)
|
||||
headers = {
|
||||
"anthropic-version" => "2023-06-01",
|
||||
"x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key,
|
||||
"x-api-key" => llm_model.api_key,
|
||||
"content-type" => "application/json",
|
||||
}
|
||||
|
||||
|
|
|
@ -6,24 +6,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class AwsBedrock < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "aws_bedrock"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model)
|
||||
SiteSetting.ai_bedrock_access_key_id.present? &&
|
||||
SiteSetting.ai_bedrock_secret_access_key.present? &&
|
||||
SiteSetting.ai_bedrock_region.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"AWS Bedrock - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "aws_bedrock"
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
|
@ -62,37 +46,28 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def model_uri
|
||||
if llm_model
|
||||
region = llm_model.lookup_custom_param("region")
|
||||
region = llm_model.lookup_custom_param("region")
|
||||
|
||||
api_url =
|
||||
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{llm_model.name}/invoke"
|
||||
else
|
||||
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
#
|
||||
# FYI there is a 2.0 version of Claude, very little need to support it given
|
||||
# haiku/sonnet are better fits anyway, we map to claude-2.1
|
||||
bedrock_model_id =
|
||||
case model
|
||||
when "claude-2"
|
||||
"anthropic.claude-v2:1"
|
||||
when "claude-3-haiku"
|
||||
"anthropic.claude-3-haiku-20240307-v1:0"
|
||||
when "claude-3-sonnet"
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
when "claude-instant-1"
|
||||
"anthropic.claude-instant-v1"
|
||||
when "claude-3-opus"
|
||||
"anthropic.claude-3-opus-20240229-v1:0"
|
||||
when "claude-3-5-sonnet"
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
else
|
||||
model
|
||||
end
|
||||
bedrock_model_id =
|
||||
case llm_model.name
|
||||
when "claude-2"
|
||||
"anthropic.claude-v2:1"
|
||||
when "claude-3-haiku"
|
||||
"anthropic.claude-3-haiku-20240307-v1:0"
|
||||
when "claude-3-sonnet"
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
when "claude-instant-1"
|
||||
"anthropic.claude-instant-v1"
|
||||
when "claude-3-opus"
|
||||
"anthropic.claude-3-opus-20240229-v1:0"
|
||||
when "claude-3-5-sonnet"
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
else
|
||||
llm_model.name
|
||||
end
|
||||
|
||||
api_url =
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
|
||||
end
|
||||
api_url =
|
||||
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
|
||||
|
||||
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
|
||||
|
||||
|
@ -114,11 +89,9 @@ module DiscourseAi
|
|||
|
||||
signer =
|
||||
Aws::Sigv4::Signer.new(
|
||||
access_key_id:
|
||||
llm_model&.lookup_custom_param("access_key_id") ||
|
||||
SiteSetting.ai_bedrock_access_key_id,
|
||||
region: llm_model&.lookup_custom_param("region") || SiteSetting.ai_bedrock_region,
|
||||
secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
|
||||
access_key_id: llm_model.lookup_custom_param("access_key_id"),
|
||||
region: llm_model.lookup_custom_param("region"),
|
||||
secret_access_key: llm_model.api_key,
|
||||
service: "bedrock",
|
||||
)
|
||||
|
||||
|
|
|
@ -30,39 +30,12 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def configuration_hint
|
||||
settings = dependant_setting_names
|
||||
I18n.t(
|
||||
"discourse_ai.llm.endpoints.configuration_hint",
|
||||
settings: settings.join(", "),
|
||||
count: settings.length,
|
||||
)
|
||||
end
|
||||
|
||||
def display_name(model_name)
|
||||
to_display = endpoint_name(model_name)
|
||||
|
||||
return to_display if correctly_configured?(model_name)
|
||||
|
||||
I18n.t("discourse_ai.llm.endpoints.not_configured", display_name: to_display)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def endpoint_name(_model_name)
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def can_contact?(_endpoint_name)
|
||||
def can_contact?(_model_provider)
|
||||
raise NotImplementedError
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(model_name, tokenizer, llm_model: nil)
|
||||
@model = model_name
|
||||
@tokenizer = tokenizer
|
||||
def initialize(llm_model)
|
||||
@llm_model = llm_model
|
||||
end
|
||||
|
||||
|
@ -136,7 +109,7 @@ module DiscourseAi
|
|||
topic_id: dialect.prompt.topic_id,
|
||||
post_id: dialect.prompt.post_id,
|
||||
feature_name: feature_name,
|
||||
language_model: self.class.endpoint_name(@model),
|
||||
language_model: llm_model.name,
|
||||
)
|
||||
|
||||
if !@streaming_mode
|
||||
|
@ -323,10 +296,14 @@ module DiscourseAi
|
|||
tokenizer.size(extract_prompt_for_tokenizer(prompt))
|
||||
end
|
||||
|
||||
attr_reader :tokenizer, :model, :llm_model
|
||||
attr_reader :llm_model
|
||||
|
||||
protected
|
||||
|
||||
def tokenizer
|
||||
llm_model.tokenizer_class
|
||||
end
|
||||
|
||||
# should normalize temperature, max_tokens, stop_words to endpoint specific values
|
||||
def normalize_model_params(model_params)
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -6,10 +6,6 @@ module DiscourseAi
|
|||
class CannedResponse
|
||||
CANNED_RESPONSE_ERROR = Class.new(StandardError)
|
||||
|
||||
def self.can_contact?(_)
|
||||
Rails.env.test?
|
||||
end
|
||||
|
||||
def initialize(responses)
|
||||
@responses = responses
|
||||
@completions = 0
|
||||
|
|
|
@ -4,22 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class Cohere < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "cohere"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_cohere_api_key]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
SiteSetting.ai_cohere_api_key.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"Cohere - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "cohere"
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
|
@ -39,9 +25,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def model_uri
|
||||
url = llm_model&.url || "https://api.cohere.ai/v1/chat"
|
||||
|
||||
URI(url)
|
||||
URI(llm_model.url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
|
@ -59,7 +43,7 @@ module DiscourseAi
|
|||
def prepare_request(payload)
|
||||
headers = {
|
||||
"Content-Type" => "application/json",
|
||||
"Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}",
|
||||
"Authorization" => "Bearer #{llm_model.api_key}",
|
||||
}
|
||||
|
||||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
|
|
|
@ -4,20 +4,6 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class Fake < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "fake"
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
true
|
||||
end
|
||||
|
||||
def endpoint_name(_model_name)
|
||||
"Test - fake model"
|
||||
end
|
||||
end
|
||||
|
||||
STOCK_CONTENT = <<~TEXT
|
||||
# Discourse Markdown Styles Showcase
|
||||
|
||||
|
@ -75,6 +61,10 @@ module DiscourseAi
|
|||
Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers.
|
||||
TEXT
|
||||
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "fake"
|
||||
end
|
||||
|
||||
def self.with_fake_content(content)
|
||||
@fake_content = content
|
||||
yield
|
||||
|
|
|
@ -4,22 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class Gemini < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "google"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_gemini_api_key]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
SiteSetting.ai_gemini_api_key.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"Google - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "google"
|
||||
end
|
||||
|
||||
def default_options
|
||||
|
@ -59,21 +45,8 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def model_uri
|
||||
if llm_model
|
||||
url = llm_model.url
|
||||
else
|
||||
mapped_model = model
|
||||
if model == "gemini-1.5-pro"
|
||||
mapped_model = "gemini-1.5-pro-latest"
|
||||
elsif model == "gemini-1.5-flash"
|
||||
mapped_model = "gemini-1.5-flash-latest"
|
||||
elsif model == "gemini-1.0-pro"
|
||||
mapped_model = "gemini-pro-latest"
|
||||
end
|
||||
url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
|
||||
end
|
||||
|
||||
key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
|
||||
url = llm_model.url
|
||||
key = llm_model.api_key
|
||||
|
||||
if @streaming_mode
|
||||
url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"
|
||||
|
|
|
@ -4,22 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class HuggingFace < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "hugging_face"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_hugging_face_api_url]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
SiteSetting.ai_hugging_face_api_url.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"Hugging Face - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "hugging_face"
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
|
@ -34,7 +20,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options
|
||||
{ model: model, temperature: 0.7 }
|
||||
{ model: llm_model.name, temperature: 0.7 }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -44,7 +30,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def model_uri
|
||||
URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url)
|
||||
URI(llm_model.url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
|
@ -53,8 +39,7 @@ module DiscourseAi
|
|||
.merge(messages: prompt)
|
||||
.tap do |payload|
|
||||
if !payload[:max_tokens]
|
||||
token_limit =
|
||||
llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit
|
||||
token_limit = llm_model.max_prompt_tokens
|
||||
|
||||
payload[:max_tokens] = token_limit - prompt_size(prompt)
|
||||
end
|
||||
|
@ -64,7 +49,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key
|
||||
api_key = llm_model.api_key
|
||||
|
||||
headers =
|
||||
{ "Content-Type" => "application/json" }.tap do |h|
|
||||
|
|
|
@ -4,22 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class Ollama < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "ollama"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_ollama_endpoint]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
SiteSetting.ai_ollama_endpoint.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"Ollama - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "ollama"
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
|
@ -34,7 +20,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options
|
||||
{ max_tokens: 2000, model: model }
|
||||
{ max_tokens: 2000, model: llm_model.name }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -48,7 +34,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def model_uri
|
||||
URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
|
||||
URI(llm_model.url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
|
|
|
@ -4,56 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class OpenAi < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "open_ai"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[
|
||||
ai_openai_api_key
|
||||
ai_openai_gpt4o_url
|
||||
ai_openai_gpt4_32k_url
|
||||
ai_openai_gpt4_turbo_url
|
||||
ai_openai_gpt4_url
|
||||
ai_openai_gpt4_url
|
||||
ai_openai_gpt35_16k_url
|
||||
ai_openai_gpt35_url
|
||||
]
|
||||
end
|
||||
|
||||
def correctly_configured?(model_name)
|
||||
SiteSetting.ai_openai_api_key.present? && has_url?(model_name)
|
||||
end
|
||||
|
||||
def has_url?(model)
|
||||
url =
|
||||
if model.include?("gpt-4")
|
||||
if model.include?("32k")
|
||||
SiteSetting.ai_openai_gpt4_32k_url
|
||||
else
|
||||
if model.include?("1106") || model.include?("turbo")
|
||||
SiteSetting.ai_openai_gpt4_turbo_url
|
||||
elsif model.include?("gpt-4o")
|
||||
SiteSetting.ai_openai_gpt4o_url
|
||||
else
|
||||
SiteSetting.ai_openai_gpt4_url
|
||||
end
|
||||
end
|
||||
else
|
||||
if model.include?("16k")
|
||||
SiteSetting.ai_openai_gpt35_16k_url
|
||||
else
|
||||
SiteSetting.ai_openai_gpt35_url
|
||||
end
|
||||
end
|
||||
|
||||
url.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"OpenAI - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
%w[open_ai azure].include?(model_provider)
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
|
@ -68,7 +20,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options
|
||||
{ model: model }
|
||||
{ model: llm_model.name }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -78,28 +30,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def model_uri
|
||||
return URI(llm_model.url) if llm_model&.url
|
||||
|
||||
url =
|
||||
if model.include?("gpt-4")
|
||||
if model.include?("32k")
|
||||
SiteSetting.ai_openai_gpt4_32k_url
|
||||
else
|
||||
if model.include?("1106") || model.include?("turbo")
|
||||
SiteSetting.ai_openai_gpt4_turbo_url
|
||||
else
|
||||
SiteSetting.ai_openai_gpt4_url
|
||||
end
|
||||
end
|
||||
else
|
||||
if model.include?("16k")
|
||||
SiteSetting.ai_openai_gpt35_16k_url
|
||||
else
|
||||
SiteSetting.ai_openai_gpt35_url
|
||||
end
|
||||
end
|
||||
|
||||
URI(url)
|
||||
URI(llm_model.url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
|
@ -110,7 +41,7 @@ module DiscourseAi
|
|||
|
||||
# Usage is not available in Azure yet.
|
||||
# We'll fallback to guess this using the tokenizer.
|
||||
payload[:stream_options] = { include_usage: true } if model_uri.host.exclude?("azure")
|
||||
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
|
||||
end
|
||||
|
||||
payload[:tools] = dialect.tools if dialect.tools.present?
|
||||
|
@ -119,19 +50,16 @@ module DiscourseAi
|
|||
|
||||
def prepare_request(payload)
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
api_key = llm_model.api_key
|
||||
|
||||
api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key
|
||||
|
||||
if model_uri.host.include?("azure")
|
||||
if llm_model.provider == "azure"
|
||||
headers["api-key"] = api_key
|
||||
else
|
||||
headers["Authorization"] = "Bearer #{api_key}"
|
||||
org_id = llm_model.lookup_custom_param("organization")
|
||||
headers["OpenAI-Organization"] = org_id if org_id.present?
|
||||
end
|
||||
|
||||
org_id =
|
||||
llm_model&.lookup_custom_param("organization") || SiteSetting.ai_openai_organization
|
||||
headers["OpenAI-Organization"] = org_id if org_id.present?
|
||||
|
||||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
end
|
||||
|
||||
|
|
|
@ -4,22 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class Vllm < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name)
|
||||
endpoint_name == "vllm"
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_vllm_endpoint_srv ai_vllm_endpoint]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"vLLM - #{model_name}"
|
||||
end
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "vllm"
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
|
@ -34,7 +20,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options
|
||||
{ max_tokens: 2000, model: model }
|
||||
{ max_tokens: 2000, model: llm_model.name }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -44,16 +30,13 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def model_uri
|
||||
if llm_model&.url && !llm_model&.url == LlmModel::RESERVED_VLLM_SRV_URL
|
||||
return URI(llm_model.url)
|
||||
end
|
||||
|
||||
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
|
||||
if service.present?
|
||||
if llm_model.url.to_s.starts_with?("srv://")
|
||||
record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
|
||||
api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
|
||||
else
|
||||
api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions"
|
||||
api_endpoint = llm_model.url
|
||||
end
|
||||
|
||||
@uri ||= URI(api_endpoint)
|
||||
end
|
||||
|
||||
|
|
|
@ -64,8 +64,8 @@ module DiscourseAi
|
|||
id: "open_ai",
|
||||
models: [
|
||||
{ name: "gpt-4o", tokens: 131_072, display_name: "GPT-4 Omni" },
|
||||
{ name: "gpt-4o-mini", tokens: 131_072, display_name: "GPT-4 Omni Mini" },
|
||||
{ name: "gpt-4-turbo", tokens: 131_072, display_name: "GPT-4 Turbo" },
|
||||
{ name: "gpt-3.5-turbo", tokens: 16_385, display_name: "GPT-3.5 Turbo" },
|
||||
],
|
||||
tokenizer: DiscourseAi::Tokenizer::OpenAiTokenizer,
|
||||
endpoint: "https://api.openai.com/v1/chat/completions",
|
||||
|
@ -89,41 +89,6 @@ module DiscourseAi
|
|||
DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
|
||||
end
|
||||
|
||||
def models_by_provider
|
||||
# ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
|
||||
# However, since they use the same URL/key settings, there's no reason to duplicate them.
|
||||
@models_by_provider ||=
|
||||
{
|
||||
aws_bedrock: %w[
|
||||
claude-instant-1
|
||||
claude-2
|
||||
claude-3-haiku
|
||||
claude-3-sonnet
|
||||
claude-3-opus
|
||||
],
|
||||
anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
|
||||
vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2],
|
||||
hugging_face: %w[
|
||||
mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
mistralai/Mistral-7B-Instruct-v0.2
|
||||
],
|
||||
cohere: %w[command-light command command-r command-r-plus],
|
||||
open_ai: %w[
|
||||
gpt-3.5-turbo
|
||||
gpt-4
|
||||
gpt-3.5-turbo-16k
|
||||
gpt-4-32k
|
||||
gpt-4-turbo
|
||||
gpt-4-vision-preview
|
||||
gpt-4o
|
||||
],
|
||||
google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
|
||||
}.tap do |h|
|
||||
h[:ollama] = ["mistral"] if Rails.env.development?
|
||||
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
|
||||
end
|
||||
end
|
||||
|
||||
def valid_provider_models
|
||||
return @valid_provider_models if defined?(@valid_provider_models)
|
||||
|
||||
|
@ -151,61 +116,38 @@ module DiscourseAi
|
|||
@prompts << prompt if @prompts
|
||||
end
|
||||
|
||||
def proxy(model_name)
|
||||
provider_and_model_name = model_name.split(":")
|
||||
provider_name = provider_and_model_name.first
|
||||
model_name_without_prov = provider_and_model_name[1..].join
|
||||
def proxy(model)
|
||||
llm_model =
|
||||
if model.is_a?(LlmModel)
|
||||
model
|
||||
else
|
||||
model_name_without_prov = model.split(":").last.to_i
|
||||
|
||||
# We are in the process of transitioning to always use objects here.
|
||||
# We'll live with this hack for a while.
|
||||
if provider_name == "custom"
|
||||
llm_model = LlmModel.find(model_name_without_prov)
|
||||
raise UNKNOWN_MODEL if !llm_model
|
||||
return proxy_from_obj(llm_model)
|
||||
end
|
||||
|
||||
dialect_klass =
|
||||
DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
|
||||
|
||||
if @canned_response
|
||||
if @canned_llm && @canned_llm != model_name
|
||||
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
|
||||
LlmModel.find_by(id: model_name_without_prov)
|
||||
end
|
||||
|
||||
return new(dialect_klass, nil, model_name, gateway: @canned_response)
|
||||
end
|
||||
raise UNKNOWN_MODEL if llm_model.nil?
|
||||
|
||||
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
|
||||
|
||||
new(dialect_klass, gateway_klass, model_name_without_prov)
|
||||
end
|
||||
|
||||
def proxy_from_obj(llm_model)
|
||||
provider_name = llm_model.provider
|
||||
model_name = llm_model.name
|
||||
|
||||
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
|
||||
model_provider = llm_model.provider
|
||||
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_provider)
|
||||
|
||||
if @canned_response
|
||||
if @canned_llm && @canned_llm != [provider_name, model_name].join(":")
|
||||
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
|
||||
if @canned_llm && @canned_llm != model
|
||||
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model}"
|
||||
end
|
||||
|
||||
return(
|
||||
new(dialect_klass, nil, model_name, gateway: @canned_response, llm_model: llm_model)
|
||||
)
|
||||
return new(dialect_klass, nil, llm_model, gateway: @canned_response)
|
||||
end
|
||||
|
||||
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
|
||||
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_provider)
|
||||
|
||||
new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
|
||||
new(dialect_klass, gateway_klass, llm_model)
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil)
|
||||
def initialize(dialect_klass, gateway_klass, llm_model, gateway: nil)
|
||||
@dialect_klass = dialect_klass
|
||||
@gateway_klass = gateway_klass
|
||||
@model_name = model_name
|
||||
@gateway = gateway
|
||||
@llm_model = llm_model
|
||||
end
|
||||
|
@ -264,9 +206,9 @@ module DiscourseAi
|
|||
|
||||
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
|
||||
|
||||
dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model)
|
||||
dialect = dialect_klass.new(prompt, llm_model, opts: model_params)
|
||||
|
||||
gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model)
|
||||
gateway = @gateway || gateway_klass.new(llm_model)
|
||||
gateway.perform_completion!(
|
||||
dialect,
|
||||
user,
|
||||
|
@ -277,16 +219,14 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
llm_model&.max_prompt_tokens ||
|
||||
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
|
||||
llm_model.max_prompt_tokens
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
llm_model&.tokenizer_class ||
|
||||
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
|
||||
llm_model.tokenizer_class
|
||||
end
|
||||
|
||||
attr_reader :model_name, :llm_model
|
||||
attr_reader :llm_model
|
||||
|
||||
private
|
||||
|
||||
|
|
|
@ -15,19 +15,16 @@ module DiscourseAi
|
|||
return !@parent_enabled
|
||||
end
|
||||
|
||||
llm_model_id = val.split(":")&.last
|
||||
llm_model = LlmModel.find_by(id: llm_model_id)
|
||||
return false if llm_model.nil?
|
||||
|
||||
run_test(llm_model).tap { |result| @unreachable = result }
|
||||
rescue StandardError
|
||||
run_test(val).tap { |result| @unreachable = result }
|
||||
rescue StandardError => e
|
||||
raise e if Rails.env.test?
|
||||
@unreachable = true
|
||||
false
|
||||
end
|
||||
|
||||
def run_test(llm_model)
|
||||
def run_test(val)
|
||||
DiscourseAi::Completions::Llm
|
||||
.proxy_from_obj(llm_model)
|
||||
.proxy(val)
|
||||
.generate("How much is 1 + 1?", user: nil, feature_name: "llm_validator")
|
||||
.present?
|
||||
end
|
||||
|
|
|
@ -80,8 +80,4 @@ after_initialize do
|
|||
nil
|
||||
end
|
||||
end
|
||||
|
||||
on(:site_setting_changed) do |name, _old_value, _new_value|
|
||||
LlmModel.seed_srv_backed_model if name == :ai_vllm_endpoint_srv || name == :ai_vllm_api_key
|
||||
end
|
||||
end
|
||||
|
|
|
@ -5,5 +5,67 @@ Fabricator(:llm_model) do
|
|||
name "gpt-4-turbo"
|
||||
provider "open_ai"
|
||||
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||
api_key "123"
|
||||
url "https://api.openai.com/v1/chat/completions"
|
||||
max_prompt_tokens 131_072
|
||||
end
|
||||
|
||||
Fabricator(:anthropic_model, from: :llm_model) do
|
||||
display_name "Claude 3 Opus"
|
||||
name "claude-3-opus"
|
||||
max_prompt_tokens 200_000
|
||||
url "https://api.anthropic.com/v1/messages"
|
||||
tokenizer "DiscourseAi::Tokenizer::AnthropicTokenizer"
|
||||
provider "anthropic"
|
||||
end
|
||||
|
||||
Fabricator(:hf_model, from: :llm_model) do
|
||||
display_name "Llama 3.1"
|
||||
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
|
||||
max_prompt_tokens 64_000
|
||||
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
|
||||
url "https://test.dev/v1/chat/completions"
|
||||
provider "hugging_face"
|
||||
end
|
||||
|
||||
Fabricator(:vllm_model, from: :llm_model) do
|
||||
display_name "Llama 3.1 vLLM"
|
||||
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
|
||||
max_prompt_tokens 64_000
|
||||
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
|
||||
url "https://test.dev/v1/chat/completions"
|
||||
provider "vllm"
|
||||
end
|
||||
|
||||
Fabricator(:fake_model, from: :llm_model) do
|
||||
display_name "Fake model"
|
||||
name "fake"
|
||||
provider "fake"
|
||||
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||
max_prompt_tokens 32_000
|
||||
end
|
||||
|
||||
Fabricator(:gemini_model, from: :llm_model) do
|
||||
display_name "Gemini"
|
||||
name "gemini-1.5-pro"
|
||||
provider "google"
|
||||
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||
max_prompt_tokens 800_000
|
||||
url "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest"
|
||||
end
|
||||
|
||||
Fabricator(:bedrock_model, from: :anthropic_model) do
|
||||
url ""
|
||||
provider "aws_bedrock"
|
||||
api_key "asd-asd-asd"
|
||||
name "claude-3-sonnet"
|
||||
provider_params { { region: "us-east-1", access_key_id: "123456" } }
|
||||
end
|
||||
|
||||
Fabricator(:cohere_model, from: :llm_model) do
|
||||
display_name "Cohere Command R+"
|
||||
name "command-r-plus"
|
||||
provider "cohere"
|
||||
api_key "ABC"
|
||||
url "https://api.cohere.ai/v1/chat"
|
||||
end
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
require_relative "dialect_context"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||
let(:model_name) { "gpt-4" }
|
||||
let(:context) { DialectContext.new(described_class, model_name) }
|
||||
fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
|
||||
let(:context) { DialectContext.new(described_class, llm_model) }
|
||||
|
||||
describe "#translate" do
|
||||
it "translates a prompt written in our generic format to the ChatGPT format" do
|
||||
|
|
|
@ -2,9 +2,11 @@
|
|||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||
let :opus_dialect_klass do
|
||||
DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
|
||||
end
|
||||
|
||||
fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }
|
||||
|
||||
describe "#translate" do
|
||||
it "can insert OKs to make stuff interleve properly" do
|
||||
messages = [
|
||||
|
@ -17,7 +19,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
|
||||
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
|
||||
|
||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
||||
dialect = opus_dialect_klass.new(prompt, llm_model)
|
||||
translated = dialect.translate
|
||||
|
||||
expected_messages = [
|
||||
|
@ -62,7 +64,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
tools: tools,
|
||||
)
|
||||
|
||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
||||
dialect = opus_dialect_klass.new(prompt, llm_model)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||
|
@ -114,7 +116,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
messages: messages,
|
||||
tools: tools,
|
||||
)
|
||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
||||
dialect = opus_dialect_klass.new(prompt, llm_model)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class DialectContext
|
||||
def initialize(dialect_klass, model_name)
|
||||
def initialize(dialect_klass, llm_model)
|
||||
@dialect_klass = dialect_klass
|
||||
@model_name = model_name
|
||||
@llm_model = llm_model
|
||||
end
|
||||
|
||||
def dialect(prompt)
|
||||
@dialect_klass.new(prompt, @model_name)
|
||||
@dialect_klass.new(prompt, @llm_model)
|
||||
end
|
||||
|
||||
def prompt
|
||||
|
|
|
@ -13,6 +13,8 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
|
|||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
|
||||
fab!(:llm_model)
|
||||
|
||||
describe "#trim_messages" do
|
||||
let(:five_token_msg) { "This represents five tokens." }
|
||||
|
||||
|
@ -23,7 +25,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
|
|||
prompt.push(type: :tool, content: five_token_msg, id: 1)
|
||||
prompt.push(type: :user, content: five_token_msg)
|
||||
|
||||
dialect = TestDialect.new(prompt, "test")
|
||||
dialect = TestDialect.new(prompt, llm_model)
|
||||
dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message
|
||||
|
||||
trimmed = dialect.trim(prompt.messages)
|
||||
|
@ -37,7 +39,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
|
|||
prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens")
|
||||
prompt.push(type: :user, content: five_token_msg)
|
||||
|
||||
dialect = TestDialect.new(prompt, "test")
|
||||
dialect = TestDialect.new(prompt, llm_model)
|
||||
dialect.max_prompt_tokens = 15
|
||||
|
||||
trimmed = dialect.trim(prompt.messages)
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
require_relative "dialect_context"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||
let(:model_name) { "gemini-1.5-pro" }
|
||||
let(:context) { DialectContext.new(described_class, model_name) }
|
||||
fab!(:model) { Fabricate(:gemini_model) }
|
||||
let(:context) { DialectContext.new(described_class, model) }
|
||||
|
||||
describe "#translate" do
|
||||
it "translates a prompt written in our generic format to the Gemini format" do
|
||||
|
@ -86,11 +86,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
|||
|
||||
it "trims content if it's getting too long" do
|
||||
# testing truncation on 800k tokens is slow use model with less
|
||||
context = DialectContext.new(described_class, "gemini-pro")
|
||||
model.max_prompt_tokens = 16_384
|
||||
context = DialectContext.new(described_class, model)
|
||||
translated = context.long_user_input_scenario(length: 5_000)
|
||||
|
||||
expect(translated[:messages].last[:role]).to eq("user")
|
||||
expect(translated[:messages].last.dig(:parts, :text).length).to be <
|
||||
expect(translated[:messages].last.dig(:parts, 0, :text).length).to be <
|
||||
context.long_message_text(length: 5_000).length
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,16 +3,7 @@ require_relative "endpoint_compliance"
|
|||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
let(:url) { "https://api.anthropic.com/v1/messages" }
|
||||
fab!(:model) do
|
||||
Fabricate(
|
||||
:llm_model,
|
||||
url: "https://api.anthropic.com/v1/messages",
|
||||
name: "claude-3-opus",
|
||||
provider: "anthropic",
|
||||
api_key: "123",
|
||||
vision_enabled: true,
|
||||
)
|
||||
end
|
||||
fab!(:model) { Fabricate(:anthropic_model, name: "claude-3-opus", vision_enabled: true) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") }
|
||||
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
|
||||
let(:upload100x100) do
|
||||
|
@ -204,6 +195,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
end
|
||||
|
||||
it "supports non streaming tool calls" do
|
||||
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
|
||||
|
||||
tool = {
|
||||
name: "calculate",
|
||||
description: "calculate something",
|
||||
|
@ -224,8 +217,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
tools: [tool],
|
||||
)
|
||||
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
|
||||
|
||||
body = {
|
||||
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
|
||||
type: "message",
|
||||
|
@ -252,7 +243,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
|
||||
stub_request(:post, url).to_return(body: body)
|
||||
|
||||
result = proxy.generate(prompt, user: Discourse.system_user)
|
||||
result = llm.generate(prompt, user: Discourse.system_user)
|
||||
|
||||
expected = <<~TEXT.strip
|
||||
<function_calls>
|
||||
|
@ -370,7 +361,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
},
|
||||
).to_return(status: 200, body: body)
|
||||
|
||||
result = llm.generate(prompt, user: Discourse.system_user)
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
result = proxy.generate(prompt, user: Discourse.system_user)
|
||||
expect(result).to eq("Hello!")
|
||||
|
||||
expected_body = {
|
||||
|
|
|
@ -8,9 +8,10 @@ class BedrockMock < EndpointMock
|
|||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||
subject(:endpoint) { described_class.new(model) }
|
||||
|
||||
fab!(:user)
|
||||
fab!(:model) { Fabricate(:bedrock_model) }
|
||||
|
||||
let(:bedrock_mock) { BedrockMock.new(endpoint) }
|
||||
|
||||
|
@ -25,16 +26,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||
end
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
||||
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
||||
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||
end
|
||||
|
||||
describe "function calling" do
|
||||
it "supports old school xml function calls" do
|
||||
SiteSetting.ai_anthropic_native_tool_call_models = ""
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
|
||||
incomplete_tool_call = <<~XML.strip
|
||||
<thinking>I should be ignored</thinking>
|
||||
|
@ -112,7 +107,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
end
|
||||
|
||||
it "supports streaming function calls" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
|
||||
request = nil
|
||||
|
||||
|
@ -124,7 +119,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: "claude-3-haiku-20240307",
|
||||
model: "claude-3-sonnet-20240307",
|
||||
stop_sequence: nil,
|
||||
usage: {
|
||||
input_tokens: 840,
|
||||
|
@ -281,9 +276,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
end
|
||||
end
|
||||
|
||||
describe "Claude 3 Sonnet support" do
|
||||
it "supports the sonnet model" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
describe "Claude 3 support" do
|
||||
it "supports regular completions" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
|
||||
request = nil
|
||||
|
||||
|
@ -325,8 +320,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
expect(log.response_tokens).to eq(20)
|
||||
end
|
||||
|
||||
it "supports claude 3 sonnet streaming" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
it "supports claude 3 streaming" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
|
||||
request = nil
|
||||
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
require_relative "endpoint_compliance"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") }
|
||||
fab!(:cohere_model)
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{cohere_model.id}") }
|
||||
fab!(:user)
|
||||
|
||||
let(:prompt) do
|
||||
|
@ -57,8 +58,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
|||
prompt
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_cohere_api_key = "ABC" }
|
||||
|
||||
it "is able to trigger a tool" do
|
||||
body = (<<~TEXT).strip
|
||||
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
|
||||
|
@ -184,7 +183,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
|||
expect(audit.request_tokens).to eq(17)
|
||||
expect(audit.response_tokens).to eq(22)
|
||||
|
||||
expect(audit.language_model).to eq("Cohere - command-r-plus")
|
||||
expect(audit.language_model).to eq("command-r-plus")
|
||||
end
|
||||
|
||||
it "is able to perform streaming completions" do
|
||||
|
|
|
@ -158,7 +158,7 @@ class EndpointsCompliance
|
|||
end
|
||||
|
||||
def dialect(prompt: generic_prompt)
|
||||
dialect_klass.new(prompt, endpoint.model)
|
||||
dialect_klass.new(prompt, endpoint.llm_model)
|
||||
end
|
||||
|
||||
def regular_mode_simple_prompt(mock)
|
||||
|
@ -176,7 +176,7 @@ class EndpointsCompliance
|
|||
expect(log.raw_request_payload).to be_present
|
||||
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
|
||||
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
|
||||
expect(log.response_tokens).to eq(endpoint.llm_model.tokenizer_class.size(completion_response))
|
||||
end
|
||||
|
||||
def regular_mode_tools(mock)
|
||||
|
@ -206,7 +206,7 @@ class EndpointsCompliance
|
|||
expect(log.raw_response_payload).to be_present
|
||||
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||
expect(log.response_tokens).to eq(
|
||||
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
|
||||
endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -128,18 +128,9 @@ class GeminiMock < EndpointMock
|
|||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
||||
subject(:endpoint) { described_class.new(model) }
|
||||
|
||||
fab!(:model) do
|
||||
Fabricate(
|
||||
:llm_model,
|
||||
url: "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest",
|
||||
name: "gemini-1.5-pro",
|
||||
provider: "google",
|
||||
api_key: "ABC",
|
||||
vision_enabled: true,
|
||||
)
|
||||
end
|
||||
fab!(:model) { Fabricate(:gemini_model, vision_enabled: true) }
|
||||
|
||||
fab!(:user)
|
||||
|
||||
|
@ -168,7 +159,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
|||
req_body = nil
|
||||
|
||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
url = "#{model.url}:generateContent?key=ABC"
|
||||
url = "#{model.url}:generateContent?key=123"
|
||||
|
||||
stub_request(:post, url).with(
|
||||
body:
|
||||
|
@ -221,7 +212,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
|||
split = data.split("|")
|
||||
|
||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
url = "#{model.url}:streamGenerateContent?alt=sse&key=ABC"
|
||||
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
|
||||
|
||||
output = +""
|
||||
gemini_mock.with_chunk_array_support do
|
||||
|
|
|
@ -22,7 +22,7 @@ class HuggingFaceMock < EndpointMock
|
|||
|
||||
def stub_response(prompt, response_text, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||
.with(body: request_body(prompt))
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||
end
|
||||
|
@ -40,7 +40,7 @@ class HuggingFaceMock < EndpointMock
|
|||
end
|
||||
|
||||
def stub_raw(chunks)
|
||||
WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return(
|
||||
WebMock.stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: chunks,
|
||||
)
|
||||
|
@ -59,7 +59,7 @@ class HuggingFaceMock < EndpointMock
|
|||
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||
.with(body: request_body(prompt, stream: true))
|
||||
.to_return(status: 200, body: chunks)
|
||||
|
||||
|
@ -71,8 +71,7 @@ class HuggingFaceMock < EndpointMock
|
|||
.default_options
|
||||
.merge(messages: prompt)
|
||||
.tap do |b|
|
||||
b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
||||
model.prompt_size(prompt)
|
||||
b[:max_tokens] = 63_991
|
||||
b[:stream] = true if stream
|
||||
end
|
||||
.to_json
|
||||
|
@ -80,15 +79,9 @@ class HuggingFaceMock < EndpointMock
|
|||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||
subject(:endpoint) do
|
||||
described_class.new(
|
||||
"mistralai/Mistral-7B-Instruct-v0.2",
|
||||
DiscourseAi::Tokenizer::MixtralTokenizer,
|
||||
)
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
|
||||
subject(:endpoint) { described_class.new(hf_model) }
|
||||
|
||||
fab!(:hf_model)
|
||||
fab!(:user)
|
||||
|
||||
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
|
||||
|
|
|
@ -146,11 +146,10 @@ class OpenAiMock < EndpointMock
|
|||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||
subject(:endpoint) do
|
||||
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
|
||||
end
|
||||
subject(:endpoint) { described_class.new(model) }
|
||||
|
||||
fab!(:user)
|
||||
fab!(:model) { Fabricate(:llm_model) }
|
||||
|
||||
let(:echo_tool) do
|
||||
{
|
||||
|
@ -175,7 +174,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|||
|
||||
describe "repeat calls" do
|
||||
it "can properly reset context" do
|
||||
llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo")
|
||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
|
||||
tools = [
|
||||
{
|
||||
|
@ -258,7 +257,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|||
|
||||
describe "image support" do
|
||||
it "can handle images" do
|
||||
model = Fabricate(:llm_model, provider: "open_ai", vision_enabled: true)
|
||||
model = Fabricate(:llm_model, vision_enabled: true)
|
||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
|
|
|
@ -22,7 +22,7 @@ class VllmMock < EndpointMock
|
|||
|
||||
def stub_response(prompt, response_text, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
|
||||
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||
.with(body: model.default_options.merge(messages: prompt).to_json)
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||
end
|
||||
|
@ -50,19 +50,16 @@ class VllmMock < EndpointMock
|
|||
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
|
||||
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||
subject(:endpoint) do
|
||||
described_class.new(
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
DiscourseAi::Tokenizer::MixtralTokenizer,
|
||||
)
|
||||
end
|
||||
subject(:endpoint) { described_class.new(llm_model) }
|
||||
|
||||
fab!(:llm_model) { Fabricate(:vllm_model) }
|
||||
|
||||
fab!(:user)
|
||||
|
||||
|
@ -78,15 +75,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
|||
end
|
||||
|
||||
let(:dialect) do
|
||||
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, model_name)
|
||||
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, llm_model)
|
||||
end
|
||||
let(:prompt) { dialect.translate }
|
||||
|
||||
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
|
||||
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
|
||||
|
||||
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
|
|
|
@ -5,12 +5,13 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
|||
described_class.new(
|
||||
DiscourseAi::Completions::Dialects::OpenAiCompatible,
|
||||
canned_response,
|
||||
"hugging_face:Upstage-Llama-2-*-instruct-v2",
|
||||
model,
|
||||
gateway: canned_response,
|
||||
)
|
||||
end
|
||||
|
||||
fab!(:user)
|
||||
fab!(:model) { Fabricate(:llm_model) }
|
||||
|
||||
describe ".proxy" do
|
||||
it "raises an exception when we can't proxy the model" do
|
||||
|
@ -46,7 +47,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
|||
)
|
||||
result = +""
|
||||
described_class
|
||||
.proxy("open_ai:gpt-3.5-turbo")
|
||||
.proxy("custom:#{model.id}")
|
||||
.generate(prompt, user: user) { |partial| result << partial }
|
||||
|
||||
expect(result).to eq("Hello")
|
||||
|
@ -57,12 +58,14 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
|||
end
|
||||
|
||||
describe "#generate with fake model" do
|
||||
fab!(:fake_model)
|
||||
|
||||
before do
|
||||
DiscourseAi::Completions::Endpoints::Fake.delays = []
|
||||
DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
|
||||
end
|
||||
|
||||
let(:llm) { described_class.proxy("fake:fake") }
|
||||
let(:llm) { described_class.proxy("custom:#{fake_model.id}") }
|
||||
|
||||
let(:prompt) do
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
|
|
|
@ -5,6 +5,8 @@ return if !defined?(DiscourseAutomation)
|
|||
describe DiscourseAutomation do
|
||||
let(:automation) { Fabricate(:automation, script: "llm_report", enabled: true) }
|
||||
|
||||
fab!(:llm_model)
|
||||
|
||||
fab!(:user)
|
||||
fab!(:post)
|
||||
|
||||
|
@ -22,7 +24,7 @@ describe DiscourseAutomation do
|
|||
it "can trigger via automation" do
|
||||
add_automation_field("sender", user.username, type: "user")
|
||||
add_automation_field("receivers", [user.username], type: "users")
|
||||
add_automation_field("model", "gpt-4-turbo")
|
||||
add_automation_field("model", "custom:#{llm_model.id}")
|
||||
add_automation_field("title", "Weekly report")
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
|
||||
|
@ -36,7 +38,7 @@ describe DiscourseAutomation do
|
|||
it "can target a topic" do
|
||||
add_automation_field("sender", user.username, type: "user")
|
||||
add_automation_field("topic_id", "#{post.topic_id}")
|
||||
add_automation_field("model", "gpt-4-turbo")
|
||||
add_automation_field("model", "custom:#{llm_model.id}")
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
|
||||
automation.trigger!
|
||||
|
|
|
@ -8,6 +8,8 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
|
||||
let(:automation) { Fabricate(:automation, script: "llm_triage", enabled: true) }
|
||||
|
||||
fab!(:llm_model)
|
||||
|
||||
def add_automation_field(name, value, type: "text")
|
||||
automation.fields.create!(
|
||||
component: type,
|
||||
|
@ -23,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
SiteSetting.tagging_enabled = true
|
||||
add_automation_field("system_prompt", "hello %%POST%%")
|
||||
add_automation_field("search_for_text", "bad")
|
||||
add_automation_field("model", "gpt-4")
|
||||
add_automation_field("model", "custom:#{llm_model.id}")
|
||||
add_automation_field("category", category.id, type: "category")
|
||||
add_automation_field("tags", %w[aaa bbb], type: "tags")
|
||||
add_automation_field("hide_topic", true, type: "boolean")
|
||||
|
|
|
@ -12,7 +12,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
|||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-4") }
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_4.name) }
|
||||
|
||||
let!(:user) { Fabricate(:user) }
|
||||
|
||||
|
@ -38,7 +38,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
|||
toggle_enabled_bots(bots: [fake])
|
||||
Group.refresh_automatic_groups!
|
||||
|
||||
bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model("fake")
|
||||
bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model(fake.name)
|
||||
AiPersona.create!(
|
||||
name: "TestPersona",
|
||||
top_p: 0.5,
|
||||
|
|
|
@ -336,6 +336,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
context "when RAG is running with a question consolidator" do
|
||||
let(:consolidated_question) { "what is the time in france?" }
|
||||
|
||||
fab!(:llm_model) { Fabricate(:fake_model) }
|
||||
|
||||
it "will run the question consolidator" do
|
||||
context_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
|
@ -350,7 +352,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
name: "custom",
|
||||
rag_conversation_chunks: 3,
|
||||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||
question_consolidator_llm: "fake:fake",
|
||||
question_consolidator_llm: "custom:#{llm_model.id}",
|
||||
)
|
||||
|
||||
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])
|
||||
|
|
|
@ -4,6 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
subject(:playground) { described_class.new(bot) }
|
||||
|
||||
fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") }
|
||||
fab!(:opus_model) { Fabricate(:anthropic_model) }
|
||||
|
||||
fab!(:bot_user) do
|
||||
toggle_enabled_bots(bots: [claude_2])
|
||||
|
@ -160,7 +161,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
system_prompt: "You are a helpful bot",
|
||||
vision_enabled: true,
|
||||
vision_max_pixels: 1_000,
|
||||
default_llm: "anthropic:claude-3-opus",
|
||||
default_llm: "custom:#{opus_model.id}",
|
||||
mentionable: true,
|
||||
)
|
||||
end
|
||||
|
@ -211,7 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
)
|
||||
|
||||
persona.create_user!
|
||||
persona.update!(default_llm: "anthropic:claude-2", mentionable: true)
|
||||
persona.update!(default_llm: "custom:#{claude_2.id}", mentionable: true)
|
||||
persona
|
||||
end
|
||||
|
||||
|
@ -228,7 +229,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
SiteSetting.ai_bot_enabled = true
|
||||
SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
|
||||
Group.refresh_automatic_groups!
|
||||
persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus")
|
||||
persona.update!(allow_chat: true, mentionable: true, default_llm: "custom:#{opus_model.id}")
|
||||
end
|
||||
|
||||
it "should behave in a sane way when threading is enabled" do
|
||||
|
@ -342,7 +343,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
persona.update!(
|
||||
allow_chat: true,
|
||||
mentionable: false,
|
||||
default_llm: "anthropic:claude-3-opus",
|
||||
default_llm: "custom:#{opus_model.id}",
|
||||
)
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
|
@ -517,7 +518,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||
["Magic title", "Yes I can"],
|
||||
llm: "anthropic:claude-2",
|
||||
llm: "custom:#{claude_2.id}",
|
||||
) do
|
||||
post =
|
||||
create_post(
|
||||
|
@ -552,7 +553,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
# title is queued first, ensures it uses the llm targeted via target_usernames not claude
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||
["Magic title", "Yes I can"],
|
||||
llm: "open_ai:gpt-3.5-turbo",
|
||||
llm: "custom:#{gpt_35_turbo.id}",
|
||||
) do
|
||||
post =
|
||||
create_post(
|
||||
|
@ -584,7 +585,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
# replies as correct persona if replying direct to persona
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||
["Another reply"],
|
||||
llm: "open_ai:gpt-3.5-turbo",
|
||||
llm: "custom:#{gpt_35_turbo.id}",
|
||||
) do
|
||||
create_post(
|
||||
raw: "Please ignore this bot, I am replying to a user",
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("fake:fake") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{Fabricate(:fake_model).id}") }
|
||||
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
|
||||
|
||||
fab!(:user)
|
||||
|
|
|
@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
|
|||
SiteSetting.ai_openai_api_key = "abc"
|
||||
end
|
||||
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
let(:dall_e) do
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
describe "#process" do
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
# frozen_string_literal: true
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
SiteSetting.ai_openai_api_key = "asd"
|
||||
end
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
require "rails_helper"
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
||||
fab!(:llm_model)
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
let(:tool) do
|
||||
described_class.new(
|
||||
|
|
|
@ -4,7 +4,8 @@ require "rails_helper"
|
|||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
|
||||
let(:bot_user) { Fabricate(:user) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
||||
fab!(:llm_model)
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) }
|
||||
|
||||
context "with #sort_and_shorten_diff" do
|
||||
|
|
|
@ -4,7 +4,8 @@ require "rails_helper"
|
|||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
|
||||
let(:bot_user) { Fabricate(:user) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
||||
fab!(:llm_model)
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) }
|
||||
|
||||
context "with valid search results" do
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
require "rails_helper"
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
||||
fab!(:llm_model)
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
let(:tool) do
|
||||
described_class.new(
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Google do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) }
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
|
|||
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
|
||||
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
|
||||
|
||||
describe "#process" do
|
||||
it "can generate correct info" do
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::JavascriptEvaluator do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) }
|
||||
|
||||
fab!(:parent_category) { Fabricate(:category, name: "animals") }
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
|
||||
fab!(:gpt_35_bot) { Fabricate(:llm_model, name: "gpt-3.5-turbo") }
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
toggle_enabled_bots(bots: [gpt_35_bot])
|
||||
toggle_enabled_bots(bots: [llm_model])
|
||||
end
|
||||
|
||||
def search_settings(query)
|
||||
|
|
|
@ -4,10 +4,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
|||
before { SearchIndexer.enable }
|
||||
after { SearchIndexer.disable }
|
||||
|
||||
before { SiteSetting.ai_openai_api_key = "asd" }
|
||||
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
fab!(:admin)
|
||||
|
|
|
@ -9,8 +9,10 @@ def has_rg?
|
|||
end
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
let(:progress_blk) { Proc.new {} }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::Time do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo") }
|
||||
fab!(:llm_model)
|
||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_openai_api_key = "asd"
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
end
|
||||
before { SiteSetting.ai_bot_enabled = true }
|
||||
|
||||
describe "#invoke" do
|
||||
it "can retrieve the content of a webpage and returns the processed text" do
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
describe DiscourseAi::Automation::LlmTriage do
|
||||
fab!(:post)
|
||||
fab!(:llm_model)
|
||||
|
||||
def triage(**args)
|
||||
DiscourseAi::Automation::LlmTriage.handle(**args)
|
||||
|
@ -10,7 +11,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
|
||||
triage(
|
||||
post: post,
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
hide_topic: true,
|
||||
system_prompt: "test %%POST%%",
|
||||
search_for_text: "bad",
|
||||
|
@ -24,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
|
||||
triage(
|
||||
post: post,
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
hide_topic: true,
|
||||
system_prompt: "test %%POST%%",
|
||||
search_for_text: "bad",
|
||||
|
@ -40,7 +41,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
|
||||
triage(
|
||||
post: post,
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
category_id: category.id,
|
||||
system_prompt: "test %%POST%%",
|
||||
search_for_text: "bad",
|
||||
|
@ -55,7 +56,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
|
||||
triage(
|
||||
post: post,
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
system_prompt: "test %%POST%%",
|
||||
search_for_text: "bad",
|
||||
canned_reply: "test canned reply 123",
|
||||
|
@ -73,7 +74,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
|
||||
triage(
|
||||
post: post,
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
system_prompt: "test %%POST%%",
|
||||
search_for_text: "bad",
|
||||
flag_post: true,
|
||||
|
@ -89,7 +90,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
|||
DiscourseAi::Completions::Llm.with_prepared_responses(["Bad.\n\nYo"]) do
|
||||
triage(
|
||||
post: post,
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
system_prompt: "test %%POST%%",
|
||||
search_for_text: "bad",
|
||||
flag_post: true,
|
||||
|
|
|
@ -32,6 +32,8 @@ module DiscourseAi
|
|||
fab!(:topic_with_tag) { Fabricate(:topic, tags: [tag, hidden_tag]) }
|
||||
fab!(:post_with_tag) { Fabricate(:post, raw: "I am in a tag", topic: topic_with_tag) }
|
||||
|
||||
fab!(:llm_model)
|
||||
|
||||
describe "#run!" do
|
||||
it "is able to generate email reports" do
|
||||
freeze_time
|
||||
|
@ -41,7 +43,7 @@ module DiscourseAi
|
|||
sender_username: user.username,
|
||||
receivers: ["fake@discourse.com"],
|
||||
title: "test report %DATE%",
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
category_ids: nil,
|
||||
tags: nil,
|
||||
allow_secure_categories: false,
|
||||
|
@ -78,7 +80,7 @@ module DiscourseAi
|
|||
sender_username: user.username,
|
||||
receivers: [receiver.username],
|
||||
title: "test report",
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
category_ids: nil,
|
||||
tags: nil,
|
||||
allow_secure_categories: false,
|
||||
|
@ -123,7 +125,7 @@ module DiscourseAi
|
|||
sender_username: user.username,
|
||||
receivers: [receiver.username],
|
||||
title: "test report",
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
category_ids: nil,
|
||||
tags: nil,
|
||||
allow_secure_categories: false,
|
||||
|
@ -166,7 +168,7 @@ module DiscourseAi
|
|||
sender_username: user.username,
|
||||
receivers: [receiver.username],
|
||||
title: "test report",
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
category_ids: nil,
|
||||
tags: nil,
|
||||
allow_secure_categories: false,
|
||||
|
@ -194,7 +196,7 @@ module DiscourseAi
|
|||
sender_username: user.username,
|
||||
receivers: [receiver.username],
|
||||
title: "test report",
|
||||
model: "gpt-4",
|
||||
model: "custom:#{llm_model.id}",
|
||||
category_ids: nil,
|
||||
tags: nil,
|
||||
allow_secure_categories: false,
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
RSpec.describe AiTool do
|
||||
fab!(:llm_model) { Fabricate(:llm_model, name: "claude-2") }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy_from_obj(llm_model) }
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||
|
||||
def create_tool(parameters: nil, script: nil)
|
||||
AiTool.create!(
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe LlmModel do
|
||||
describe ".seed_srv_backed_model" do
|
||||
before do
|
||||
SiteSetting.ai_vllm_endpoint_srv = "srv.llm.service."
|
||||
SiteSetting.ai_vllm_api_key = "123"
|
||||
end
|
||||
|
||||
context "when the model doesn't exist yet" do
|
||||
it "creates it" do
|
||||
described_class.seed_srv_backed_model
|
||||
|
||||
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
|
||||
|
||||
expect(llm_model).to be_present
|
||||
expect(llm_model.name).to eq("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
expect(llm_model.api_key).to eq(SiteSetting.ai_vllm_api_key)
|
||||
end
|
||||
end
|
||||
|
||||
context "when the model already exists" do
|
||||
before { described_class.seed_srv_backed_model }
|
||||
|
||||
context "when the API key setting changes" do
|
||||
it "updates it" do
|
||||
new_key = "456"
|
||||
SiteSetting.ai_vllm_api_key = new_key
|
||||
|
||||
described_class.seed_srv_backed_model
|
||||
|
||||
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
|
||||
|
||||
expect(llm_model.api_key).to eq(new_key)
|
||||
end
|
||||
end
|
||||
|
||||
context "when the SRV is no longer defined" do
|
||||
it "deletes the LlmModel" do
|
||||
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
|
||||
expect(llm_model).to be_present
|
||||
|
||||
SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
|
||||
|
||||
expect { llm_model.reload }.to raise_exception(ActiveRecord::RecordNotFound)
|
||||
end
|
||||
|
||||
it "disabled the bot user" do
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
|
||||
llm_model.update!(enabled_chat_bot: true)
|
||||
llm_model.toggle_companion_user
|
||||
user = llm_model.user
|
||||
|
||||
expect(user).to be_present
|
||||
|
||||
SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
|
||||
|
||||
expect { user.reload }.to raise_exception(ActiveRecord::RecordNotFound)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -8,7 +8,7 @@ module DiscourseAi::ChatBotHelper
|
|||
end
|
||||
|
||||
def assign_fake_provider_to(setting_name)
|
||||
Fabricate(:llm_model, provider: "fake", name: "fake").tap do |fake_llm|
|
||||
Fabricate(:fake_model).tap do |fake_llm|
|
||||
SiteSetting.public_send("#{setting_name}=", "custom:#{fake_llm.id}")
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue