DEV: Remove old code now that features rely on LlmModels. (#729)

* DEV: Remove old code now that features rely on LlmModels.

* Hide old settings and migrate persona LLM overrides.

* Remove the shadowing special URL and seeding code. Use the srv:// prefix instead.
Roman Rizzi 2024-07-30 13:44:57 -03:00 committed by GitHub
parent 73a2b15e91
commit bed044448c
73 changed files with 439 additions and 813 deletions
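
For context, the convention this commit lands on: instead of seeding a model behind a reserved placeholder URL, an SRV-backed vLLM model stores a `srv://`-prefixed URL directly on its `LlmModel` row. A minimal sketch, reusing the attributes from the removed seeding code; the SRV name and API key are illustrative:

```ruby
# Sketch only: an SRV-backed vLLM model under the new convention.
LlmModel.create!(
  display_name: "vLLM SRV LLM",
  name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
  provider: "vllm",
  tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer",
  url: "srv://_vllm._tcp.example.com", # illustrative SRV name
  max_prompt_tokens: 8000,
  api_key: "sk-example",               # illustrative key
)
```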

View File

@ -110,9 +110,7 @@ module DiscourseAi
) )
provider = updating ? updating.provider : permitted[:provider] provider = updating ? updating.provider : permitted[:provider]
permit_url = permit_url = provider != LlmModel::BEDROCK_PROVIDER_NAME
(updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) ||
provider != LlmModel::BEDROCK_PROVIDER_NAME
permitted[:url] = params.dig(:ai_llm, :url) if permit_url permitted[:url] = params.dig(:ai_llm, :url) if permit_url

View File

@ -2,44 +2,10 @@
class LlmModel < ActiveRecord::Base class LlmModel < ActiveRecord::Base
FIRST_BOT_USER_ID = -1200 FIRST_BOT_USER_ID = -1200
RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid"
BEDROCK_PROVIDER_NAME = "aws_bedrock" BEDROCK_PROVIDER_NAME = "aws_bedrock"
belongs_to :user belongs_to :user
validates :url, exclusion: { in: [RESERVED_VLLM_SRV_URL] }, if: :url_changed?
def self.seed_srv_backed_model
srv = SiteSetting.ai_vllm_endpoint_srv
srv_model = find_by(url: RESERVED_VLLM_SRV_URL)
if srv.present?
if srv_model.present?
current_key = SiteSetting.ai_vllm_api_key
srv_model.update!(api_key: current_key) if current_key != srv_model.api_key
else
record =
new(
display_name: "vLLM SRV LLM",
name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
provider: "vllm",
tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer",
url: RESERVED_VLLM_SRV_URL,
max_prompt_tokens: 8000,
api_key: SiteSetting.ai_vllm_api_key,
)
record.save(validate: false) # Ignore reserved URL validation
end
else
# Clean up companion users
srv_model&.enabled_chat_bot = false
srv_model&.toggle_companion_user
srv_model&.destroy!
end
end
def self.provider_params def self.provider_params
{ {
aws_bedrock: { aws_bedrock: {
@ -54,7 +20,7 @@ class LlmModel < ActiveRecord::Base
end end
def to_llm def to_llm
DiscourseAi::Completions::Llm.proxy_from_obj(self) DiscourseAi::Completions::Llm.proxy("custom:#{id}")
end end
def toggle_companion_user def toggle_companion_user
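
With the reserved-URL seeding path gone, features resolve an LLM straight from the persisted record. A minimal usage sketch; the lookup and prompt are illustrative:

```ruby
# Sketch: to_llm is now equivalent to Llm.proxy("custom:#{model.id}").
model = LlmModel.find_by(provider: "vllm") # illustrative lookup
model.to_llm.generate("How much is 1 + 1?", user: Discourse.system_user)
```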

View File

@ -19,6 +19,6 @@ class LlmModelSerializer < ApplicationSerializer
has_one :user, serializer: BasicUserSerializer, embed: :object has_one :user, serializer: BasicUserSerializer, embed: :object
def shadowed_by_srv def shadowed_by_srv
object.url == LlmModel::RESERVED_VLLM_SRV_URL object.url.to_s.starts_with?("srv://")
end end
end end

View File

@ -60,17 +60,9 @@ export default class AiLlmEditorForm extends Component {
return this.testRunning || this.testResult !== null; return this.testRunning || this.testResult !== null;
} }
get displaySRVWarning() {
return this.args.model.shadowed_by_srv && !this.args.model.isNew;
}
get canEditURL() { get canEditURL() {
// Explicitly false. // Explicitly false.
if (this.metaProviderParams.url_editable === false) { return this.metaProviderParams.url_editable !== false;
return false;
}
return !this.args.model.shadowed_by_srv || this.args.model.isNew;
} }
@computed("args.model.provider") @computed("args.model.provider")
@ -174,12 +166,6 @@ export default class AiLlmEditorForm extends Component {
} }
<template> <template>
{{#if this.displaySRVWarning}}
<div class="alert alert-info">
{{icon "exclamation-circle"}}
{{I18n.t "discourse_ai.llms.srv_warning"}}
</div>
{{/if}}
<form class="form-horizontal ai-llm-editor"> <form class="form-horizontal ai-llm-editor">
<div class="control-group"> <div class="control-group">
<label>{{i18n "discourse_ai.llms.display_name"}}</label> <label>{{i18n "discourse_ai.llms.display_name"}}</label>

View File

@ -237,7 +237,6 @@ en:
confirm_delete: Are you sure you want to delete this model? confirm_delete: Are you sure you want to delete this model?
delete: Delete delete: Delete
srv_warning: This LLM points to an SRV record, and its URL is not editable. You have to update the hidden "ai_vllm_endpoint_srv" setting instead.
preconfigured_llms: "Select your LLM" preconfigured_llms: "Select your LLM"
preconfigured: preconfigured:
none: "Configure manually..." none: "Configure manually..."

View File

@ -96,21 +96,35 @@ discourse_ai:
- opennsfw2 - opennsfw2
- nsfw_detector - nsfw_detector
ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions" ai_openai_gpt35_url:
ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions" default: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4o_url: "https://api.openai.com/v1/chat/completions" hidden: true
ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions" ai_openai_gpt35_16k_url:
ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions" default: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4_turbo_url: "https://api.openai.com/v1/chat/completions" hidden: true
ai_openai_gpt4o_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt4_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt4_32k_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt4_turbo_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations" ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings" ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
ai_openai_organization: "" ai_openai_organization:
default: ""
hidden: true
ai_openai_api_key: ai_openai_api_key:
default: "" default: ""
secret: true secret: true
ai_anthropic_api_key: ai_anthropic_api_key:
default: "" default: ""
secret: true hidden: true
ai_anthropic_native_tool_call_models: ai_anthropic_native_tool_call_models:
type: list type: list
list_type: compact list_type: compact
@ -123,7 +137,7 @@ discourse_ai:
- claude-3-5-sonnet - claude-3-5-sonnet
ai_cohere_api_key: ai_cohere_api_key:
default: "" default: ""
secret: true hidden: true
ai_stability_api_key: ai_stability_api_key:
default: "" default: ""
secret: true secret: true
@ -140,13 +154,16 @@ discourse_ai:
- "stable-diffusion-v1-5" - "stable-diffusion-v1-5"
ai_hugging_face_api_url: ai_hugging_face_api_url:
default: "" default: ""
hidden: true
ai_hugging_face_api_key: ai_hugging_face_api_key:
default: "" default: ""
secret: true hidden: true
ai_hugging_face_token_limit: ai_hugging_face_token_limit:
default: 4096 default: 4096
hidden: true
ai_hugging_face_model_display_name: ai_hugging_face_model_display_name:
default: "" default: ""
hidden: true
ai_hugging_face_tei_endpoint: ai_hugging_face_tei_endpoint:
default: "" default: ""
ai_hugging_face_tei_endpoint_srv: ai_hugging_face_tei_endpoint_srv:
@ -167,11 +184,13 @@ discourse_ai:
ai_bedrock_access_key_id: ai_bedrock_access_key_id:
default: "" default: ""
secret: true secret: true
hidden: true
ai_bedrock_secret_access_key: ai_bedrock_secret_access_key:
default: "" default: ""
secret: true hidden: true
ai_bedrock_region: ai_bedrock_region:
default: "us-east-1" default: "us-east-1"
hidden: true
ai_cloudflare_workers_account_id: ai_cloudflare_workers_account_id:
default: "" default: ""
secret: true secret: true
@ -180,13 +199,16 @@ discourse_ai:
secret: true secret: true
ai_gemini_api_key: ai_gemini_api_key:
default: "" default: ""
secret: true hidden: true
ai_vllm_endpoint: ai_vllm_endpoint:
default: "" default: ""
hidden: true
ai_vllm_endpoint_srv: ai_vllm_endpoint_srv:
default: "" default: ""
hidden: true hidden: true
ai_vllm_api_key: "" ai_vllm_api_key:
default: ""
hidden: true
ai_llava_endpoint: ai_llava_endpoint:
default: "" default: ""
hidden: true hidden: true

View File

@ -1,8 +0,0 @@
# frozen_string_literal: true
begin
LlmModel.seed_srv_backed_model
rescue PG::UndefinedColumn => e
# If this code runs before migrations, an attribute might be missing.
Rails.logger.warn("Failed to seed SRV-Backed LLM: #{e.message}")
end

View File

@ -25,8 +25,9 @@ class MigrateVisionLlms < ActiveRecord::Migration[7.1]
).first ).first
if current_value && current_value != "llava" if current_value && current_value != "llava"
model_name = current_value.split(":").last
llm_model = llm_model =
DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: current_value).first DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: model_name).first
if llm_model if llm_model
DB.exec(<<~SQL, new: "custom:#{llm_model}") if llm_model DB.exec(<<~SQL, new: "custom:#{llm_model}") if llm_model
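
The fix strips a provider prefix from the stored value before looking up `llm_models`; a quick illustration with a hypothetical stored value:

```ruby
current_value = "open_ai:gpt-4o" # hypothetical stored setting value
current_value.split(":").last    # => "gpt-4o", the name used in the query
```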

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
class MigratePersonaLlmOverride < ActiveRecord::Migration[7.1]
def up
fields_to_update = DB.query(<<~SQL)
SELECT id, default_llm
FROM ai_personas
WHERE default_llm IS NOT NULL
SQL
return if fields_to_update.empty?
updated_fields =
fields_to_update
.map do |field|
llm_model_id = matching_llm_model(field.default_llm)
"(#{field.id}, 'custom:#{llm_model_id}')" if llm_model_id
end
.compact
return if updated_fields.empty?
DB.exec(<<~SQL)
UPDATE ai_personas
SET default_llm = new_fields.new_default_llm
FROM (VALUES #{updated_fields.join(", ")}) AS new_fields(id, new_default_llm)
WHERE new_fields.id::bigint = ai_personas.id
SQL
end
def matching_llm_model(model)
provider = model.split(":").first
model_name = model.split(":").last
return if provider == "custom"
DB.query_single(
"SELECT id FROM llm_models WHERE name = :name AND provider = :provider",
{ name: model_name, provider: provider },
).first
end
def down
raise ActiveRecord::IrreversibleMigration
end
end
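
The net effect on a persona row is that the legacy `provider:model` string becomes an id-based `custom:<id>` reference; for example (the id is hypothetical):

```ruby
# Before: ai_personas.default_llm = "anthropic:claude-3-opus"
# After:  ai_personas.default_llm = "custom:42" # 42 = matching llm_models.id (hypothetical)
```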

View File

@ -5,19 +5,15 @@ module DiscourseAi
module Dialects module Dialects
class ChatGpt < Dialect class ChatGpt < Dialect
class << self class << self
def can_translate?(model_name) def can_translate?(model_provider)
model_name.starts_with?("gpt-") model_provider == "open_ai" || model_provider == "azure"
end end
end end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
end
def native_tool_support? def native_tool_support?
true llm_model.provider == "open_ai" || llm_model.provider == "azure"
end end
def translate def translate
@ -30,19 +26,17 @@ module DiscourseAi
end end
def max_prompt_tokens def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
# provide a buffer of 120 tokens - our function counting is not # provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard # 100% accurate and getting numbers to align exactly is very hard
buffer = (opts[:max_tokens] || 2500) + 50 buffer = (opts[:max_tokens] || 2500) + 50
if tools.present? if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation # note this is about 100 tokens over, OpenAI have a more optimal representation
@function_size ||= self.tokenizer.size(tools.to_json.to_s) @function_size ||= llm_model.tokenizer_class.size(tools.to_json.to_s)
buffer += @function_size buffer += @function_size
end end
model_max_tokens - buffer llm_model.max_prompt_tokens - buffer
end end
private private
@ -105,24 +99,7 @@ module DiscourseAi
end end
def calculate_message_token(context) def calculate_message_token(context)
self.tokenizer.size(context[:content].to_s + context[:name].to_s) llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
end
def model_max_tokens
case model_name
when "gpt-3.5-turbo-16k"
16_384
when "gpt-4"
8192
when "gpt-4-32k"
32_768
when "gpt-4-turbo"
131_072
when "gpt-4o"
131_072
else
8192
end
end end
end end
end end
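
The prompt budget now comes entirely from the record's `max_prompt_tokens` instead of the removed hardcoded table. A worked example under assumed values (an 8,192-token model, no explicit `opts[:max_tokens]`, no tools):

```ruby
max_prompt_tokens = 8192    # assumed LlmModel#max_prompt_tokens
buffer = (nil || 2500) + 50 # => 2550, safety margin for the response
max_prompt_tokens - buffer  # => 5642 tokens left for the prompt
```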

View File

@ -5,8 +5,8 @@ module DiscourseAi
module Dialects module Dialects
class Claude < Dialect class Claude < Dialect
class << self class << self
def can_translate?(model_name) def can_translate?(provider_name)
model_name.start_with?("claude") || model_name.start_with?("anthropic") provider_name == "anthropic" || provider_name == "aws_bedrock"
end end
end end
@ -26,10 +26,6 @@ module DiscourseAi
end end
end end
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
end
def translate def translate
messages = super messages = super
@ -61,14 +57,11 @@ module DiscourseAi
end end
def max_prompt_tokens def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens llm_model.max_prompt_tokens
# Longer term it will have over 1 million
200_000 # Claude-3 has a 200k context window for now
end end
def native_tool_support? def native_tool_support?
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name) SiteSetting.ai_anthropic_native_tool_call_models_map.include?(llm_model.name)
end end
private private

View File

@ -6,18 +6,12 @@ module DiscourseAi
module Completions module Completions
module Dialects module Dialects
class Command < Dialect class Command < Dialect
class << self def self.can_translate?(model_provider)
def can_translate?(model_name) model_provider == "cohere"
%w[command-light command command-r command-r-plus].include?(model_name)
end
end end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
end
def translate def translate
messages = super messages = super
@ -68,20 +62,7 @@ module DiscourseAi
end end
def max_prompt_tokens def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens llm_model.max_prompt_tokens
case model_name
when "command-light"
4096
when "command"
8192
when "command-r"
131_072
when "command-r-plus"
131_072
else
8192
end
end end
def native_tool_support? def native_tool_support?
@ -99,7 +80,7 @@ module DiscourseAi
end end
def calculate_message_token(context) def calculate_message_token(context)
self.tokenizer.size(context[:content].to_s + context[:name].to_s) llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
end end
def system_msg(msg) def system_msg(msg)

View File

@ -5,7 +5,7 @@ module DiscourseAi
module Dialects module Dialects
class Dialect class Dialect
class << self class << self
def can_translate?(_model_name) def can_translate?(model_provider)
raise NotImplemented raise NotImplemented
end end
@ -19,7 +19,7 @@ module DiscourseAi
] ]
end end
def dialect_for(model_name) def dialect_for(model_provider)
dialects = [] dialects = []
if Rails.env.test? || Rails.env.development? if Rails.env.test? || Rails.env.development?
@ -28,26 +28,21 @@ module DiscourseAi
dialects = dialects.concat(all_dialects) dialects = dialects.concat(all_dialects)
dialect = dialects.find { |d| d.can_translate?(model_name) } dialect = dialects.find { |d| d.can_translate?(model_provider) }
raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect
dialect dialect
end end
end end
def initialize(generic_prompt, model_name, opts: {}, llm_model: nil) def initialize(generic_prompt, llm_model, opts: {})
@prompt = generic_prompt @prompt = generic_prompt
@model_name = model_name
@opts = opts @opts = opts
@llm_model = llm_model @llm_model = llm_model
end end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
raise NotImplemented
end
def can_end_with_assistant_msg? def can_end_with_assistant_msg?
false false
end end
@ -57,7 +52,7 @@ module DiscourseAi
end end
def vision_support? def vision_support?
llm_model&.vision_enabled? llm_model.vision_enabled?
end end
def tools def tools
@ -88,12 +83,12 @@ module DiscourseAi
private private
attr_reader :model_name, :opts, :llm_model attr_reader :opts, :llm_model
def trim_messages(messages) def trim_messages(messages)
prompt_limit = max_prompt_tokens prompt_limit = max_prompt_tokens
current_token_count = 0 current_token_count = 0
message_step_size = (max_prompt_tokens / 25).to_i * -1 message_step_size = (prompt_limit / 25).to_i * -1
trimmed_messages = [] trimmed_messages = []
@ -157,7 +152,7 @@ module DiscourseAi
end end
def calculate_message_token(msg) def calculate_message_token(msg)
self.tokenizer.size(msg[:content].to_s) llm_model.tokenizer_class.size(msg[:content].to_s)
end end
def tools_dialect def tools_dialect
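
Dialect dispatch now keys on the provider string rather than model-name patterns; a minimal sketch of the new lookup, following the `can_translate?` implementations above:

```ruby
DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
# => DiscourseAi::Completions::Dialects::Claude
DiscourseAi::Completions::Dialects::Dialect.dialect_for("open_ai")
# => DiscourseAi::Completions::Dialects::ChatGpt
```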

View File

@ -5,8 +5,8 @@ module DiscourseAi
module Dialects module Dialects
class Gemini < Dialect class Gemini < Dialect
class << self class << self
def can_translate?(model_name) def can_translate?(model_provider)
%w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name) model_provider == "google"
end end
end end
@ -14,10 +14,6 @@ module DiscourseAi
true true
end end
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
end
def translate def translate
# Gemini complains if we don't alternate model/user roles. # Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } } noop_model_response = { role: "model", parts: { text: "Ok." } }
@ -74,24 +70,17 @@ module DiscourseAi
end end
def max_prompt_tokens def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens llm_model.max_prompt_tokens
if model_name.start_with?("gemini-1.5")
# technically we support 1 million tokens, but we're being conservative
800_000
else
16_384 # 50% of model tokens
end
end end
protected protected
def calculate_message_token(context) def calculate_message_token(context)
self.tokenizer.size(context[:content].to_s + context[:name].to_s) llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
end end
def beta_api? def beta_api?
@beta_api ||= model_name.start_with?("gemini-1.5") @beta_api ||= llm_model.name.start_with?("gemini-1.5")
end end
def system_msg(msg) def system_msg(msg)

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class Anthropic < Base class Anthropic < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) model_provider == "anthropic"
endpoint_name == "anthropic"
end
def dependant_setting_names
%w[ai_anthropic_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_anthropic_api_key.present?
end
def endpoint_name(model_name)
"Anthropic - #{model_name}"
end
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
@ -29,7 +15,7 @@ module DiscourseAi
def default_options(dialect) def default_options(dialect)
mapped_model = mapped_model =
case model case llm_model.name
when "claude-2" when "claude-2"
"claude-2.1" "claude-2.1"
when "claude-instant-1" when "claude-instant-1"
@ -43,7 +29,7 @@ module DiscourseAi
when "claude-3-5-sonnet" when "claude-3-5-sonnet"
"claude-3-5-sonnet-20240620" "claude-3-5-sonnet-20240620"
else else
model llm_model.name
end end
options = { model: mapped_model, max_tokens: 3_000 } options = { model: mapped_model, max_tokens: 3_000 }
@ -74,9 +60,7 @@ module DiscourseAi
end end
def model_uri def model_uri
url = llm_model&.url || "https://api.anthropic.com/v1/messages" URI(llm_model.url)
URI(url)
end end
def prepare_payload(prompt, model_params, dialect) def prepare_payload(prompt, model_params, dialect)
@ -94,7 +78,7 @@ module DiscourseAi
def prepare_request(payload) def prepare_request(payload)
headers = { headers = {
"anthropic-version" => "2023-06-01", "anthropic-version" => "2023-06-01",
"x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key, "x-api-key" => llm_model.api_key,
"content-type" => "application/json", "content-type" => "application/json",
} }

View File

@ -6,24 +6,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class AwsBedrock < Base class AwsBedrock < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) model_provider == "aws_bedrock"
endpoint_name == "aws_bedrock"
end
def dependant_setting_names
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
end
def correctly_configured?(_model)
SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present?
end
def endpoint_name(model_name)
"AWS Bedrock - #{model_name}"
end
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
@ -62,37 +46,28 @@ module DiscourseAi
end end
def model_uri def model_uri
if llm_model region = llm_model.lookup_custom_param("region")
region = llm_model.lookup_custom_param("region")
api_url = bedrock_model_id =
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{llm_model.name}/invoke" case llm_model.name
else when "claude-2"
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html "anthropic.claude-v2:1"
# when "claude-3-haiku"
# FYI there is a 2.0 version of Claude, very little need to support it given "anthropic.claude-3-haiku-20240307-v1:0"
# haiku/sonnet are better fits anyway, we map to claude-2.1 when "claude-3-sonnet"
bedrock_model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
case model when "claude-instant-1"
when "claude-2" "anthropic.claude-instant-v1"
"anthropic.claude-v2:1" when "claude-3-opus"
when "claude-3-haiku" "anthropic.claude-3-opus-20240229-v1:0"
"anthropic.claude-3-haiku-20240307-v1:0" when "claude-3-5-sonnet"
when "claude-3-sonnet" "anthropic.claude-3-5-sonnet-20240620-v1:0"
"anthropic.claude-3-sonnet-20240229-v1:0" else
when "claude-instant-1" llm_model.name
"anthropic.claude-instant-v1" end
when "claude-3-opus"
"anthropic.claude-3-opus-20240229-v1:0"
when "claude-3-5-sonnet"
"anthropic.claude-3-5-sonnet-20240620-v1:0"
else
model
end
api_url = api_url =
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke" "https://bedrock-runtime.#{region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
end
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
@ -114,11 +89,9 @@ module DiscourseAi
signer = signer =
Aws::Sigv4::Signer.new( Aws::Sigv4::Signer.new(
access_key_id: access_key_id: llm_model.lookup_custom_param("access_key_id"),
llm_model&.lookup_custom_param("access_key_id") || region: llm_model.lookup_custom_param("region"),
SiteSetting.ai_bedrock_access_key_id, secret_access_key: llm_model.api_key,
region: llm_model&.lookup_custom_param("region") || SiteSetting.ai_bedrock_region,
secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
service: "bedrock", service: "bedrock",
) )

View File

@ -30,39 +30,12 @@ module DiscourseAi
end end
end end
def configuration_hint def can_contact?(_model_provider)
settings = dependant_setting_names
I18n.t(
"discourse_ai.llm.endpoints.configuration_hint",
settings: settings.join(", "),
count: settings.length,
)
end
def display_name(model_name)
to_display = endpoint_name(model_name)
return to_display if correctly_configured?(model_name)
I18n.t("discourse_ai.llm.endpoints.not_configured", display_name: to_display)
end
def dependant_setting_names
raise NotImplementedError
end
def endpoint_name(_model_name)
raise NotImplementedError
end
def can_contact?(_endpoint_name)
raise NotImplementedError raise NotImplementedError
end end
end end
def initialize(model_name, tokenizer, llm_model: nil) def initialize(llm_model)
@model = model_name
@tokenizer = tokenizer
@llm_model = llm_model @llm_model = llm_model
end end
@ -136,7 +109,7 @@ module DiscourseAi
topic_id: dialect.prompt.topic_id, topic_id: dialect.prompt.topic_id,
post_id: dialect.prompt.post_id, post_id: dialect.prompt.post_id,
feature_name: feature_name, feature_name: feature_name,
language_model: self.class.endpoint_name(@model), language_model: llm_model.name,
) )
if !@streaming_mode if !@streaming_mode
@ -323,10 +296,14 @@ module DiscourseAi
tokenizer.size(extract_prompt_for_tokenizer(prompt)) tokenizer.size(extract_prompt_for_tokenizer(prompt))
end end
attr_reader :tokenizer, :model, :llm_model attr_reader :llm_model
protected protected
def tokenizer
llm_model.tokenizer_class
end
# should normalize temperature, max_tokens, stop_words to endpoint specific values # should normalize temperature, max_tokens, stop_words to endpoint specific values
def normalize_model_params(model_params) def normalize_model_params(model_params)
raise NotImplementedError raise NotImplementedError

View File

@ -6,10 +6,6 @@ module DiscourseAi
class CannedResponse class CannedResponse
CANNED_RESPONSE_ERROR = Class.new(StandardError) CANNED_RESPONSE_ERROR = Class.new(StandardError)
def self.can_contact?(_)
Rails.env.test?
end
def initialize(responses) def initialize(responses)
@responses = responses @responses = responses
@completions = 0 @completions = 0

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class Cohere < Base class Cohere < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) model_provider == "cohere"
endpoint_name == "cohere"
end
def dependant_setting_names
%w[ai_cohere_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_cohere_api_key.present?
end
def endpoint_name(model_name)
"Cohere - #{model_name}"
end
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
@ -39,9 +25,7 @@ module DiscourseAi
private private
def model_uri def model_uri
url = llm_model&.url || "https://api.cohere.ai/v1/chat" URI(llm_model.url)
URI(url)
end end
def prepare_payload(prompt, model_params, dialect) def prepare_payload(prompt, model_params, dialect)
@ -59,7 +43,7 @@ module DiscourseAi
def prepare_request(payload) def prepare_request(payload)
headers = { headers = {
"Content-Type" => "application/json", "Content-Type" => "application/json",
"Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}", "Authorization" => "Bearer #{llm_model.api_key}",
} }
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }

View File

@ -4,20 +4,6 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class Fake < Base class Fake < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "fake"
end
def correctly_configured?(_model_name)
true
end
def endpoint_name(_model_name)
"Test - fake model"
end
end
STOCK_CONTENT = <<~TEXT STOCK_CONTENT = <<~TEXT
# Discourse Markdown Styles Showcase # Discourse Markdown Styles Showcase
@ -75,6 +61,10 @@ module DiscourseAi
Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers. Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers.
TEXT TEXT
def self.can_contact?(model_provider)
model_provider == "fake"
end
def self.with_fake_content(content) def self.with_fake_content(content)
@fake_content = content @fake_content = content
yield yield

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class Gemini < Base class Gemini < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) model_provider == "google"
endpoint_name == "google"
end
def dependant_setting_names
%w[ai_gemini_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_gemini_api_key.present?
end
def endpoint_name(model_name)
"Google - #{model_name}"
end
end end
def default_options def default_options
@ -59,21 +45,8 @@ module DiscourseAi
private private
def model_uri def model_uri
if llm_model url = llm_model.url
url = llm_model.url key = llm_model.api_key
else
mapped_model = model
if model == "gemini-1.5-pro"
mapped_model = "gemini-1.5-pro-latest"
elsif model == "gemini-1.5-flash"
mapped_model = "gemini-1.5-flash-latest"
elsif model == "gemini-1.0-pro"
mapped_model = "gemini-pro-latest"
end
url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
end
key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
if @streaming_mode if @streaming_mode
url = "#{url}:streamGenerateContent?key=#{key}&alt=sse" url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class HuggingFace < Base class HuggingFace < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) model_provider == "hugging_face"
endpoint_name == "hugging_face"
end
def dependant_setting_names
%w[ai_hugging_face_api_url]
end
def correctly_configured?(_model_name)
SiteSetting.ai_hugging_face_api_url.present?
end
def endpoint_name(model_name)
"Hugging Face - #{model_name}"
end
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
@ -34,7 +20,7 @@ module DiscourseAi
end end
def default_options def default_options
{ model: model, temperature: 0.7 } { model: llm_model.name, temperature: 0.7 }
end end
def provider_id def provider_id
@ -44,7 +30,7 @@ module DiscourseAi
private private
def model_uri def model_uri
URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url) URI(llm_model.url)
end end
def prepare_payload(prompt, model_params, _dialect) def prepare_payload(prompt, model_params, _dialect)
@ -53,8 +39,7 @@ module DiscourseAi
.merge(messages: prompt) .merge(messages: prompt)
.tap do |payload| .tap do |payload|
if !payload[:max_tokens] if !payload[:max_tokens]
token_limit = token_limit = llm_model.max_prompt_tokens
llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit
payload[:max_tokens] = token_limit - prompt_size(prompt) payload[:max_tokens] = token_limit - prompt_size(prompt)
end end
@ -64,7 +49,7 @@ module DiscourseAi
end end
def prepare_request(payload) def prepare_request(payload)
api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key api_key = llm_model.api_key
headers = headers =
{ "Content-Type" => "application/json" }.tap do |h| { "Content-Type" => "application/json" }.tap do |h|

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class Ollama < Base class Ollama < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) model_provider == "ollama"
endpoint_name == "ollama"
end
def dependant_setting_names
%w[ai_ollama_endpoint]
end
def correctly_configured?(_model_name)
SiteSetting.ai_ollama_endpoint.present?
end
def endpoint_name(model_name)
"Ollama - #{model_name}"
end
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
@ -34,7 +20,7 @@ module DiscourseAi
end end
def default_options def default_options
{ max_tokens: 2000, model: model } { max_tokens: 2000, model: llm_model.name }
end end
def provider_id def provider_id
@ -48,7 +34,7 @@ module DiscourseAi
private private
def model_uri def model_uri
URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions") URI(llm_model.url)
end end
def prepare_payload(prompt, model_params, _dialect) def prepare_payload(prompt, model_params, _dialect)

View File

@ -4,56 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class OpenAi < Base class OpenAi < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) %w[open_ai azure].include?(model_provider)
endpoint_name == "open_ai"
end
def dependant_setting_names
%w[
ai_openai_api_key
ai_openai_gpt4o_url
ai_openai_gpt4_32k_url
ai_openai_gpt4_turbo_url
ai_openai_gpt4_url
ai_openai_gpt4_url
ai_openai_gpt35_16k_url
ai_openai_gpt35_url
]
end
def correctly_configured?(model_name)
SiteSetting.ai_openai_api_key.present? && has_url?(model_name)
end
def has_url?(model)
url =
if model.include?("gpt-4")
if model.include?("32k")
SiteSetting.ai_openai_gpt4_32k_url
else
if model.include?("1106") || model.include?("turbo")
SiteSetting.ai_openai_gpt4_turbo_url
elsif model.include?("gpt-4o")
SiteSetting.ai_openai_gpt4o_url
else
SiteSetting.ai_openai_gpt4_url
end
end
else
if model.include?("16k")
SiteSetting.ai_openai_gpt35_16k_url
else
SiteSetting.ai_openai_gpt35_url
end
end
url.present?
end
def endpoint_name(model_name)
"OpenAI - #{model_name}"
end
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
@ -68,7 +20,7 @@ module DiscourseAi
end end
def default_options def default_options
{ model: model } { model: llm_model.name }
end end
def provider_id def provider_id
@ -78,28 +30,7 @@ module DiscourseAi
private private
def model_uri def model_uri
return URI(llm_model.url) if llm_model&.url URI(llm_model.url)
url =
if model.include?("gpt-4")
if model.include?("32k")
SiteSetting.ai_openai_gpt4_32k_url
else
if model.include?("1106") || model.include?("turbo")
SiteSetting.ai_openai_gpt4_turbo_url
else
SiteSetting.ai_openai_gpt4_url
end
end
else
if model.include?("16k")
SiteSetting.ai_openai_gpt35_16k_url
else
SiteSetting.ai_openai_gpt35_url
end
end
URI(url)
end end
def prepare_payload(prompt, model_params, dialect) def prepare_payload(prompt, model_params, dialect)
@ -110,7 +41,7 @@ module DiscourseAi
# Usage is not available in Azure yet. # Usage is not available in Azure yet.
# We'll fallback to guess this using the tokenizer. # We'll fallback to guess this using the tokenizer.
payload[:stream_options] = { include_usage: true } if model_uri.host.exclude?("azure") payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
end end
payload[:tools] = dialect.tools if dialect.tools.present? payload[:tools] = dialect.tools if dialect.tools.present?
@ -119,19 +50,16 @@ module DiscourseAi
def prepare_request(payload) def prepare_request(payload)
headers = { "Content-Type" => "application/json" } headers = { "Content-Type" => "application/json" }
api_key = llm_model.api_key
api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key if llm_model.provider == "azure"
if model_uri.host.include?("azure")
headers["api-key"] = api_key headers["api-key"] = api_key
else else
headers["Authorization"] = "Bearer #{api_key}" headers["Authorization"] = "Bearer #{api_key}"
org_id = llm_model.lookup_custom_param("organization")
headers["OpenAI-Organization"] = org_id if org_id.present?
end end
org_id =
llm_model&.lookup_custom_param("organization") || SiteSetting.ai_openai_organization
headers["OpenAI-Organization"] = org_id if org_id.present?
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end end

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class Vllm < Base class Vllm < Base
class << self def self.can_contact?(model_provider)
def can_contact?(endpoint_name) model_provider == "vllm"
endpoint_name == "vllm"
end
def dependant_setting_names
%w[ai_vllm_endpoint_srv ai_vllm_endpoint]
end
def correctly_configured?(_model_name)
SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
end
def endpoint_name(model_name)
"vLLM - #{model_name}"
end
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
@ -34,7 +20,7 @@ module DiscourseAi
end end
def default_options def default_options
{ max_tokens: 2000, model: model } { max_tokens: 2000, model: llm_model.name }
end end
def provider_id def provider_id
@ -44,16 +30,13 @@ module DiscourseAi
private private
def model_uri def model_uri
if llm_model&.url && !llm_model&.url == LlmModel::RESERVED_VLLM_SRV_URL if llm_model.url.to_s.starts_with?("srv://")
return URI(llm_model.url) record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
end
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
if service.present?
api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions" api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
else else
api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions" api_endpoint = llm_model.url
end end
@uri ||= URI(api_endpoint) @uri ||= URI(api_endpoint)
end end

View File

@ -64,8 +64,8 @@ module DiscourseAi
id: "open_ai", id: "open_ai",
models: [ models: [
{ name: "gpt-4o", tokens: 131_072, display_name: "GPT-4 Omni" }, { name: "gpt-4o", tokens: 131_072, display_name: "GPT-4 Omni" },
{ name: "gpt-4o-mini", tokens: 131_072, display_name: "GPT-4 Omni Mini" },
{ name: "gpt-4-turbo", tokens: 131_072, display_name: "GPT-4 Turbo" }, { name: "gpt-4-turbo", tokens: 131_072, display_name: "GPT-4 Turbo" },
{ name: "gpt-3.5-turbo", tokens: 16_385, display_name: "GPT-3.5 Turbo" },
], ],
tokenizer: DiscourseAi::Tokenizer::OpenAiTokenizer, tokenizer: DiscourseAi::Tokenizer::OpenAiTokenizer,
endpoint: "https://api.openai.com/v1/chat/completions", endpoint: "https://api.openai.com/v1/chat/completions",
@ -89,41 +89,6 @@ module DiscourseAi
DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name) DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
end end
def models_by_provider
# ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
# However, since they use the same URL/key settings, there's no reason to duplicate them.
@models_by_provider ||=
{
aws_bedrock: %w[
claude-instant-1
claude-2
claude-3-haiku
claude-3-sonnet
claude-3-opus
],
anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2],
hugging_face: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
],
cohere: %w[command-light command command-r command-r-plus],
open_ai: %w[
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-turbo
gpt-4-vision-preview
gpt-4o
],
google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
}.tap do |h|
h[:ollama] = ["mistral"] if Rails.env.development?
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
end
end
def valid_provider_models def valid_provider_models
return @valid_provider_models if defined?(@valid_provider_models) return @valid_provider_models if defined?(@valid_provider_models)
@ -151,61 +116,38 @@ module DiscourseAi
@prompts << prompt if @prompts @prompts << prompt if @prompts
end end
def proxy(model_name) def proxy(model)
provider_and_model_name = model_name.split(":") llm_model =
provider_name = provider_and_model_name.first if model.is_a?(LlmModel)
model_name_without_prov = provider_and_model_name[1..].join model
else
model_name_without_prov = model.split(":").last.to_i
# We are in the process of transitioning to always use objects here. LlmModel.find_by(id: model_name_without_prov)
# We'll live with this hack for a while.
if provider_name == "custom"
llm_model = LlmModel.find(model_name_without_prov)
raise UNKNOWN_MODEL if !llm_model
return proxy_from_obj(llm_model)
end
dialect_klass =
DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
if @canned_response
if @canned_llm && @canned_llm != model_name
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
end end
return new(dialect_klass, nil, model_name, gateway: @canned_response) raise UNKNOWN_MODEL if llm_model.nil?
end
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name) model_provider = llm_model.provider
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_provider)
new(dialect_klass, gateway_klass, model_name_without_prov)
end
def proxy_from_obj(llm_model)
provider_name = llm_model.provider
model_name = llm_model.name
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
if @canned_response if @canned_response
if @canned_llm && @canned_llm != [provider_name, model_name].join(":") if @canned_llm && @canned_llm != model
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}" raise "Invalid call LLM call, expected #{@canned_llm} but got #{model}"
end end
return( return new(dialect_klass, nil, llm_model, gateway: @canned_response)
new(dialect_klass, nil, model_name, gateway: @canned_response, llm_model: llm_model)
)
end end
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name) gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_provider)
new(dialect_klass, gateway_klass, model_name, llm_model: llm_model) new(dialect_klass, gateway_klass, llm_model)
end end
end end
def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil) def initialize(dialect_klass, gateway_klass, llm_model, gateway: nil)
@dialect_klass = dialect_klass @dialect_klass = dialect_klass
@gateway_klass = gateway_klass @gateway_klass = gateway_klass
@model_name = model_name
@gateway = gateway @gateway = gateway
@llm_model = llm_model @llm_model = llm_model
end end
@ -264,9 +206,9 @@ module DiscourseAi
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? } model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model) dialect = dialect_klass.new(prompt, llm_model, opts: model_params)
gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model) gateway = @gateway || gateway_klass.new(llm_model)
gateway.perform_completion!( gateway.perform_completion!(
dialect, dialect,
user, user,
@ -277,16 +219,14 @@ module DiscourseAi
end end
def max_prompt_tokens def max_prompt_tokens
llm_model&.max_prompt_tokens || llm_model.max_prompt_tokens
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
end end
def tokenizer def tokenizer
llm_model&.tokenizer_class || llm_model.tokenizer_class
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
end end
attr_reader :model_name, :llm_model attr_reader :llm_model
private private
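
The rewritten `proxy` accepts either an `LlmModel` record or the id-based string; a minimal sketch of both call styles (the lookup is illustrative):

```ruby
model = LlmModel.find_by(provider: "open_ai")             # illustrative lookup
DiscourseAi::Completions::Llm.proxy(model)                # pass the record directly
DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") # or the custom:<id> string
```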

View File

@ -15,19 +15,16 @@ module DiscourseAi
return !@parent_enabled return !@parent_enabled
end end
llm_model_id = val.split(":")&.last run_test(val).tap { |result| @unreachable = result }
llm_model = LlmModel.find_by(id: llm_model_id) rescue StandardError => e
return false if llm_model.nil? raise e if Rails.env.test?
run_test(llm_model).tap { |result| @unreachable = result }
rescue StandardError
@unreachable = true @unreachable = true
false false
end end
def run_test(llm_model) def run_test(val)
DiscourseAi::Completions::Llm DiscourseAi::Completions::Llm
.proxy_from_obj(llm_model) .proxy(val)
.generate("How much is 1 + 1?", user: nil, feature_name: "llm_validator") .generate("How much is 1 + 1?", user: nil, feature_name: "llm_validator")
.present? .present?
end end

View File

@ -80,8 +80,4 @@ after_initialize do
nil nil
end end
end end
on(:site_setting_changed) do |name, _old_value, _new_value|
LlmModel.seed_srv_backed_model if name == :ai_vllm_endpoint_srv || name == :ai_vllm_api_key
end
end end

View File

@ -5,5 +5,67 @@ Fabricator(:llm_model) do
name "gpt-4-turbo" name "gpt-4-turbo"
provider "open_ai" provider "open_ai"
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer" tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
api_key "123"
url "https://api.openai.com/v1/chat/completions"
max_prompt_tokens 131_072
end
Fabricator(:anthropic_model, from: :llm_model) do
display_name "Claude 3 Opus"
name "claude-3-opus"
max_prompt_tokens 200_000
url "https://api.anthropic.com/v1/messages"
tokenizer "DiscourseAi::Tokenizer::AnthropicTokenizer"
provider "anthropic"
end
Fabricator(:hf_model, from: :llm_model) do
display_name "Llama 3.1"
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
max_prompt_tokens 64_000
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
url "https://test.dev/v1/chat/completions"
provider "hugging_face"
end
Fabricator(:vllm_model, from: :llm_model) do
display_name "Llama 3.1 vLLM"
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
max_prompt_tokens 64_000
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
url "https://test.dev/v1/chat/completions"
provider "vllm"
end
Fabricator(:fake_model, from: :llm_model) do
display_name "Fake model"
name "fake"
provider "fake"
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
max_prompt_tokens 32_000 max_prompt_tokens 32_000
end end
Fabricator(:gemini_model, from: :llm_model) do
display_name "Gemini"
name "gemini-1.5-pro"
provider "google"
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
max_prompt_tokens 800_000
url "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest"
end
Fabricator(:bedrock_model, from: :anthropic_model) do
url ""
provider "aws_bedrock"
api_key "asd-asd-asd"
name "claude-3-sonnet"
provider_params { { region: "us-east-1", access_key_id: "123456" } }
end
Fabricator(:cohere_model, from: :llm_model) do
display_name "Cohere Command R+"
name "command-r-plus"
provider "cohere"
api_key "ABC"
url "https://api.cohere.ai/v1/chat"
end

View File

@ -3,8 +3,8 @@
require_relative "dialect_context" require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
let(:model_name) { "gpt-4" } fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
let(:context) { DialectContext.new(described_class, model_name) } let(:context) { DialectContext.new(described_class, llm_model) }
describe "#translate" do describe "#translate" do
it "translates a prompt written in our generic format to the ChatGPT format" do it "translates a prompt written in our generic format to the ChatGPT format" do

View File

@ -2,9 +2,11 @@
RSpec.describe DiscourseAi::Completions::Dialects::Claude do RSpec.describe DiscourseAi::Completions::Dialects::Claude do
let :opus_dialect_klass do let :opus_dialect_klass do
DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus") DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
end end
fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }
describe "#translate" do describe "#translate" do
it "can insert OKs to make stuff interleve properly" do it "can insert OKs to make stuff interleve properly" do
messages = [ messages = [
@ -17,7 +19,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages) prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
dialect = opus_dialect_klass.new(prompt, "claude-3-opus") dialect = opus_dialect_klass.new(prompt, llm_model)
translated = dialect.translate translated = dialect.translate
expected_messages = [ expected_messages = [
@ -62,7 +64,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
tools: tools, tools: tools,
) )
dialect = opus_dialect_klass.new(prompt, "claude-3-opus") dialect = opus_dialect_klass.new(prompt, llm_model)
translated = dialect.translate translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot") expect(translated.system_prompt).to start_with("You are a helpful bot")
@ -114,7 +116,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
messages: messages, messages: messages,
tools: tools, tools: tools,
) )
dialect = opus_dialect_klass.new(prompt, "claude-3-opus") dialect = opus_dialect_klass.new(prompt, llm_model)
translated = dialect.translate translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot") expect(translated.system_prompt).to start_with("You are a helpful bot")

View File

@ -1,13 +1,13 @@
# frozen_string_literal: true # frozen_string_literal: true
class DialectContext class DialectContext
def initialize(dialect_klass, model_name) def initialize(dialect_klass, llm_model)
@dialect_klass = dialect_klass @dialect_klass = dialect_klass
@model_name = model_name @llm_model = llm_model
end end
def dialect(prompt) def dialect(prompt)
@dialect_klass.new(prompt, @model_name) @dialect_klass.new(prompt, @llm_model)
end end
def prompt def prompt

View File

@ -13,6 +13,8 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
end end
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
fab!(:llm_model)
describe "#trim_messages" do describe "#trim_messages" do
let(:five_token_msg) { "This represents five tokens." } let(:five_token_msg) { "This represents five tokens." }
@ -23,7 +25,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
prompt.push(type: :tool, content: five_token_msg, id: 1) prompt.push(type: :tool, content: five_token_msg, id: 1)
prompt.push(type: :user, content: five_token_msg) prompt.push(type: :user, content: five_token_msg)
dialect = TestDialect.new(prompt, "test") dialect = TestDialect.new(prompt, llm_model)
dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message
trimmed = dialect.trim(prompt.messages) trimmed = dialect.trim(prompt.messages)
@ -37,7 +39,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens") prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens")
prompt.push(type: :user, content: five_token_msg) prompt.push(type: :user, content: five_token_msg)
dialect = TestDialect.new(prompt, "test") dialect = TestDialect.new(prompt, llm_model)
dialect.max_prompt_tokens = 15 dialect.max_prompt_tokens = 15
trimmed = dialect.trim(prompt.messages) trimmed = dialect.trim(prompt.messages)

View File

@ -3,8 +3,8 @@
require_relative "dialect_context" require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
let(:model_name) { "gemini-1.5-pro" } fab!(:model) { Fabricate(:gemini_model) }
let(:context) { DialectContext.new(described_class, model_name) } let(:context) { DialectContext.new(described_class, model) }
describe "#translate" do describe "#translate" do
it "translates a prompt written in our generic format to the Gemini format" do it "translates a prompt written in our generic format to the Gemini format" do
@ -86,11 +86,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
it "trims content if it's getting too long" do it "trims content if it's getting too long" do
# testing truncation on 800k tokens is slow use model with less # testing truncation on 800k tokens is slow use model with less
context = DialectContext.new(described_class, "gemini-pro") model.max_prompt_tokens = 16_384
context = DialectContext.new(described_class, model)
translated = context.long_user_input_scenario(length: 5_000) translated = context.long_user_input_scenario(length: 5_000)
expect(translated[:messages].last[:role]).to eq("user") expect(translated[:messages].last[:role]).to eq("user")
expect(translated[:messages].last.dig(:parts, :text).length).to be < expect(translated[:messages].last.dig(:parts, 0, :text).length).to be <
context.long_message_text(length: 5_000).length context.long_message_text(length: 5_000).length
end end
end end

View File

@ -3,16 +3,7 @@ require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
let(:url) { "https://api.anthropic.com/v1/messages" } let(:url) { "https://api.anthropic.com/v1/messages" }
fab!(:model) do fab!(:model) { Fabricate(:anthropic_model, name: "claude-3-opus", vision_enabled: true) }
Fabricate(
:llm_model,
url: "https://api.anthropic.com/v1/messages",
name: "claude-3-opus",
provider: "anthropic",
api_key: "123",
vision_enabled: true,
)
end
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") } let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
let(:upload100x100) do let(:upload100x100) do
@ -204,6 +195,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end end
it "supports non streaming tool calls" do it "supports non streaming tool calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
tool = { tool = {
name: "calculate", name: "calculate",
description: "calculate something", description: "calculate something",
@ -224,8 +217,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
tools: [tool], tools: [tool],
) )
proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
body = { body = {
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S", id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
type: "message", type: "message",
@ -252,7 +243,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
stub_request(:post, url).to_return(body: body) stub_request(:post, url).to_return(body: body)
result = proxy.generate(prompt, user: Discourse.system_user) result = llm.generate(prompt, user: Discourse.system_user)
expected = <<~TEXT.strip expected = <<~TEXT.strip
<function_calls> <function_calls>
@ -370,7 +361,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}, },
).to_return(status: 200, body: body) ).to_return(status: 200, body: body)
result = llm.generate(prompt, user: Discourse.system_user) proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
result = proxy.generate(prompt, user: Discourse.system_user)
expect(result).to eq("Hello!") expect(result).to eq("Hello!")
expected_body = { expected_body = {

View File

@ -8,9 +8,10 @@ class BedrockMock < EndpointMock
end end
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) } subject(:endpoint) { described_class.new(model) }
fab!(:user) fab!(:user)
fab!(:model) { Fabricate(:bedrock_model) }
let(:bedrock_mock) { BedrockMock.new(endpoint) } let(:bedrock_mock) { BedrockMock.new(endpoint) }
@ -25,16 +26,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
Aws::EventStream::Encoder.new.encode(aws_message) Aws::EventStream::Encoder.new.encode(aws_message)
end end
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
SiteSetting.ai_bedrock_region = "us-east-1"
end
describe "function calling" do describe "function calling" do
it "supports old school xml function calls" do it "supports old school xml function calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "" SiteSetting.ai_anthropic_native_tool_call_models = ""
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
incomplete_tool_call = <<~XML.strip incomplete_tool_call = <<~XML.strip
<thinking>I should be ignored</thinking> <thinking>I should be ignored</thinking>
@ -112,7 +107,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end end
it "supports streaming function calls" do it "supports streaming function calls" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil request = nil
@ -124,7 +119,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB", id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
type: "message", type: "message",
role: "assistant", role: "assistant",
model: "claude-3-haiku-20240307", model: "claude-3-sonnet-20240307",
stop_sequence: nil, stop_sequence: nil,
usage: { usage: {
input_tokens: 840, input_tokens: 840,
@ -281,9 +276,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end end
end end
describe "Claude 3 Sonnet support" do describe "Claude 3 support" do
it "supports the sonnet model" do it "supports regular completions" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil request = nil
@ -325,8 +320,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(log.response_tokens).to eq(20) expect(log.response_tokens).to eq(20)
end end
it "supports claude 3 sonnet streaming" do it "supports claude 3 streaming" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil request = nil

View File

@ -2,7 +2,8 @@
require_relative "endpoint_compliance" require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") } fab!(:cohere_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{cohere_model.id}") }
fab!(:user) fab!(:user)
let(:prompt) do let(:prompt) do
@ -57,8 +58,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
prompt prompt
end end
before { SiteSetting.ai_cohere_api_key = "ABC" }
it "is able to trigger a tool" do it "is able to trigger a tool" do
body = (<<~TEXT).strip body = (<<~TEXT).strip
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"} {"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
@ -184,7 +183,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
expect(audit.request_tokens).to eq(17) expect(audit.request_tokens).to eq(17)
expect(audit.response_tokens).to eq(22) expect(audit.response_tokens).to eq(22)
expect(audit.language_model).to eq("Cohere - command-r-plus") expect(audit.language_model).to eq("command-r-plus")
end end
it "is able to perform streaming completions" do it "is able to perform streaming completions" do

View File

@ -158,7 +158,7 @@ class EndpointsCompliance
end end
def dialect(prompt: generic_prompt) def dialect(prompt: generic_prompt)
dialect_klass.new(prompt, endpoint.model) dialect_klass.new(prompt, endpoint.llm_model)
end end
def regular_mode_simple_prompt(mock) def regular_mode_simple_prompt(mock)
@ -176,7 +176,7 @@ class EndpointsCompliance
expect(log.raw_request_payload).to be_present expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json) expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate)) expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response)) expect(log.response_tokens).to eq(endpoint.llm_model.tokenizer_class.size(completion_response))
end end
def regular_mode_tools(mock) def regular_mode_tools(mock)
@ -206,7 +206,7 @@ class EndpointsCompliance
expect(log.raw_response_payload).to be_present expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate)) expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq( expect(log.response_tokens).to eq(
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join), endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
) )
end end
end end

View File

@ -128,18 +128,9 @@ class GeminiMock < EndpointMock
end end
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) } subject(:endpoint) { described_class.new(model) }
fab!(:model) do fab!(:model) { Fabricate(:gemini_model, vision_enabled: true) }
Fabricate(
:llm_model,
url: "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest",
name: "gemini-1.5-pro",
provider: "google",
api_key: "ABC",
vision_enabled: true,
)
end
fab!(:user) fab!(:user)
@ -168,7 +159,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
req_body = nil req_body = nil
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:generateContent?key=ABC" url = "#{model.url}:generateContent?key=123"
stub_request(:post, url).with( stub_request(:post, url).with(
body: body:
@ -221,7 +212,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
split = data.split("|") split = data.split("|")
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=ABC" url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
output = +"" output = +""
gemini_mock.with_chunk_array_support do gemini_mock.with_chunk_array_support do

View File

@ -22,7 +22,7 @@ class HuggingFaceMock < EndpointMock
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}") .stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: request_body(prompt)) .with(body: request_body(prompt))
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
@ -40,7 +40,7 @@ class HuggingFaceMock < EndpointMock
end end
def stub_raw(chunks) def stub_raw(chunks)
WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return( WebMock.stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
status: 200, status: 200,
body: chunks, body: chunks,
) )
@ -59,7 +59,7 @@ class HuggingFaceMock < EndpointMock
chunks = (chunks.join("\n\n") << "data: [DONE]").split("") chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}") .stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: request_body(prompt, stream: true)) .with(body: request_body(prompt, stream: true))
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
@ -71,8 +71,7 @@ class HuggingFaceMock < EndpointMock
.default_options .default_options
.merge(messages: prompt) .merge(messages: prompt)
.tap do |b| .tap do |b|
b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) - b[:max_tokens] = 63_991
model.prompt_size(prompt)
b[:stream] = true if stream b[:stream] = true if stream
end end
.to_json .to_json
@ -80,15 +79,9 @@ class HuggingFaceMock < EndpointMock
end end
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:endpoint) do subject(:endpoint) { described_class.new(hf_model) }
described_class.new(
"mistralai/Mistral-7B-Instruct-v0.2",
DiscourseAi::Tokenizer::MixtralTokenizer,
)
end
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
fab!(:hf_model)
fab!(:user) fab!(:user)
let(:hf_mock) { HuggingFaceMock.new(endpoint) } let(:hf_mock) { HuggingFaceMock.new(endpoint) }

View File

@ -146,11 +146,10 @@ class OpenAiMock < EndpointMock
end end
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:endpoint) do subject(:endpoint) { described_class.new(model) }
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
end
fab!(:user) fab!(:user)
fab!(:model) { Fabricate(:llm_model) }
let(:echo_tool) do let(:echo_tool) do
{ {
@ -175,7 +174,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
describe "repeat calls" do describe "repeat calls" do
it "can properly reset context" do it "can properly reset context" do
llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo") llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
tools = [ tools = [
{ {
@ -258,7 +257,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
describe "image support" do describe "image support" do
it "can handle images" do it "can handle images" do
model = Fabricate(:llm_model, provider: "open_ai", vision_enabled: true) model = Fabricate(:llm_model, vision_enabled: true)
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
prompt = prompt =
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(

View File

@ -22,7 +22,7 @@ class VllmMock < EndpointMock
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions") .stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt).to_json) .with(body: model.default_options.merge(messages: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
@ -50,19 +50,16 @@ class VllmMock < EndpointMock
chunks = (chunks.join("\n\n") << "data: [DONE]").split("") chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions") .stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json) .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
end end
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
subject(:endpoint) do subject(:endpoint) { described_class.new(llm_model) }
described_class.new(
"mistralai/Mixtral-8x7B-Instruct-v0.1", fab!(:llm_model) { Fabricate(:vllm_model) }
DiscourseAi::Tokenizer::MixtralTokenizer,
)
end
fab!(:user) fab!(:user)
@ -78,15 +75,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
end end
let(:dialect) do let(:dialect) do
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, model_name) DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, llm_model)
end end
let(:prompt) { dialect.translate } let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(messages: prompt).to_json } let(:request_body) { model.default_options.merge(messages: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json } let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
describe "#perform_completion!" do describe "#perform_completion!" do
context "when using regular mode" do context "when using regular mode" do
context "with simple prompts" do context "with simple prompts" do

View File

@ -5,12 +5,13 @@ RSpec.describe DiscourseAi::Completions::Llm do
described_class.new( described_class.new(
DiscourseAi::Completions::Dialects::OpenAiCompatible, DiscourseAi::Completions::Dialects::OpenAiCompatible,
canned_response, canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2", model,
gateway: canned_response, gateway: canned_response,
) )
end end
fab!(:user) fab!(:user)
fab!(:model) { Fabricate(:llm_model) }
describe ".proxy" do describe ".proxy" do
it "raises an exception when we can't proxy the model" do it "raises an exception when we can't proxy the model" do
@ -46,7 +47,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
) )
result = +"" result = +""
described_class described_class
.proxy("open_ai:gpt-3.5-turbo") .proxy("custom:#{model.id}")
.generate(prompt, user: user) { |partial| result << partial } .generate(prompt, user: user) { |partial| result << partial }
expect(result).to eq("Hello") expect(result).to eq("Hello")
@ -57,12 +58,14 @@ RSpec.describe DiscourseAi::Completions::Llm do
end end
describe "#generate with fake model" do describe "#generate with fake model" do
fab!(:fake_model)
before do before do
DiscourseAi::Completions::Endpoints::Fake.delays = [] DiscourseAi::Completions::Endpoints::Fake.delays = []
DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10 DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
end end
let(:llm) { described_class.proxy("fake:fake") } let(:llm) { described_class.proxy("custom:#{fake_model.id}") }
let(:prompt) do let(:prompt) do
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(

View File

@ -5,6 +5,8 @@ return if !defined?(DiscourseAutomation)
describe DiscourseAutomation do describe DiscourseAutomation do
let(:automation) { Fabricate(:automation, script: "llm_report", enabled: true) } let(:automation) { Fabricate(:automation, script: "llm_report", enabled: true) }
fab!(:llm_model)
fab!(:user) fab!(:user)
fab!(:post) fab!(:post)
@ -22,7 +24,7 @@ describe DiscourseAutomation do
it "can trigger via automation" do it "can trigger via automation" do
add_automation_field("sender", user.username, type: "user") add_automation_field("sender", user.username, type: "user")
add_automation_field("receivers", [user.username], type: "users") add_automation_field("receivers", [user.username], type: "users")
add_automation_field("model", "gpt-4-turbo") add_automation_field("model", "custom:#{llm_model.id}")
add_automation_field("title", "Weekly report") add_automation_field("title", "Weekly report")
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
@ -36,7 +38,7 @@ describe DiscourseAutomation do
it "can target a topic" do it "can target a topic" do
add_automation_field("sender", user.username, type: "user") add_automation_field("sender", user.username, type: "user")
add_automation_field("topic_id", "#{post.topic_id}") add_automation_field("topic_id", "#{post.topic_id}")
add_automation_field("model", "gpt-4-turbo") add_automation_field("model", "custom:#{llm_model.id}")
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
automation.trigger! automation.trigger!

View File

@ -8,6 +8,8 @@ describe DiscourseAi::Automation::LlmTriage do
let(:automation) { Fabricate(:automation, script: "llm_triage", enabled: true) } let(:automation) { Fabricate(:automation, script: "llm_triage", enabled: true) }
fab!(:llm_model)
def add_automation_field(name, value, type: "text") def add_automation_field(name, value, type: "text")
automation.fields.create!( automation.fields.create!(
component: type, component: type,
@ -23,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
SiteSetting.tagging_enabled = true SiteSetting.tagging_enabled = true
add_automation_field("system_prompt", "hello %%POST%%") add_automation_field("system_prompt", "hello %%POST%%")
add_automation_field("search_for_text", "bad") add_automation_field("search_for_text", "bad")
add_automation_field("model", "gpt-4") add_automation_field("model", "custom:#{llm_model.id}")
add_automation_field("category", category.id, type: "category") add_automation_field("category", category.id, type: "category")
add_automation_field("tags", %w[aaa bbb], type: "tags") add_automation_field("tags", %w[aaa bbb], type: "tags")
add_automation_field("hide_topic", true, type: "boolean") add_automation_field("hide_topic", true, type: "boolean")

View File

@ -12,7 +12,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true
end end
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-4") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_4.name) }
let!(:user) { Fabricate(:user) } let!(:user) { Fabricate(:user) }
@ -38,7 +38,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
toggle_enabled_bots(bots: [fake]) toggle_enabled_bots(bots: [fake])
Group.refresh_automatic_groups! Group.refresh_automatic_groups!
bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model("fake") bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model(fake.name)
AiPersona.create!( AiPersona.create!(
name: "TestPersona", name: "TestPersona",
top_p: 0.5, top_p: 0.5,

View File

@ -336,6 +336,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
context "when RAG is running with a question consolidator" do context "when RAG is running with a question consolidator" do
let(:consolidated_question) { "what is the time in france?" } let(:consolidated_question) { "what is the time in france?" }
fab!(:llm_model) { Fabricate(:fake_model) }
it "will run the question consolidator" do it "will run the question consolidator" do
context_embedding = [0.049382, 0.9999] context_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service( EmbeddingsGenerationStubs.discourse_service(
@ -350,7 +352,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
name: "custom", name: "custom",
rag_conversation_chunks: 3, rag_conversation_chunks: 3,
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
question_consolidator_llm: "fake:fake", question_consolidator_llm: "custom:#{llm_model.id}",
) )
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id]) UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])

View File

@ -4,6 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
subject(:playground) { described_class.new(bot) } subject(:playground) { described_class.new(bot) }
fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") } fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") }
fab!(:opus_model) { Fabricate(:anthropic_model) }
fab!(:bot_user) do fab!(:bot_user) do
toggle_enabled_bots(bots: [claude_2]) toggle_enabled_bots(bots: [claude_2])
@ -160,7 +161,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
system_prompt: "You are a helpful bot", system_prompt: "You are a helpful bot",
vision_enabled: true, vision_enabled: true,
vision_max_pixels: 1_000, vision_max_pixels: 1_000,
default_llm: "anthropic:claude-3-opus", default_llm: "custom:#{opus_model.id}",
mentionable: true, mentionable: true,
) )
end end
@ -211,7 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
) )
persona.create_user! persona.create_user!
persona.update!(default_llm: "anthropic:claude-2", mentionable: true) persona.update!(default_llm: "custom:#{claude_2.id}", mentionable: true)
persona persona
end end
@ -228,7 +229,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true
SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}" SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
Group.refresh_automatic_groups! Group.refresh_automatic_groups!
persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus") persona.update!(allow_chat: true, mentionable: true, default_llm: "custom:#{opus_model.id}")
end end
it "should behave in a sane way when threading is enabled" do it "should behave in a sane way when threading is enabled" do
@ -342,7 +343,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
persona.update!( persona.update!(
allow_chat: true, allow_chat: true,
mentionable: false, mentionable: false,
default_llm: "anthropic:claude-3-opus", default_llm: "custom:#{opus_model.id}",
) )
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true
end end
@ -517,7 +518,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
DiscourseAi::Completions::Llm.with_prepared_responses( DiscourseAi::Completions::Llm.with_prepared_responses(
["Magic title", "Yes I can"], ["Magic title", "Yes I can"],
llm: "anthropic:claude-2", llm: "custom:#{claude_2.id}",
) do ) do
post = post =
create_post( create_post(
@ -552,7 +553,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# title is queued first, ensures it uses the llm targeted via target_usernames not claude # title is queued first, ensures it uses the llm targeted via target_usernames not claude
DiscourseAi::Completions::Llm.with_prepared_responses( DiscourseAi::Completions::Llm.with_prepared_responses(
["Magic title", "Yes I can"], ["Magic title", "Yes I can"],
llm: "open_ai:gpt-3.5-turbo", llm: "custom:#{gpt_35_turbo.id}",
) do ) do
post = post =
create_post( create_post(
@ -584,7 +585,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# replies as correct persona if replying direct to persona # replies as correct persona if replying direct to persona
DiscourseAi::Completions::Llm.with_prepared_responses( DiscourseAi::Completions::Llm.with_prepared_responses(
["Another reply"], ["Another reply"],
llm: "open_ai:gpt-3.5-turbo", llm: "custom:#{gpt_35_turbo.id}",
) do ) do
create_post( create_post(
raw: "Please ignore this bot, I am replying to a user", raw: "Please ignore this bot, I am replying to a user",

View File

@ -1,7 +1,7 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do
let(:llm) { DiscourseAi::Completions::Llm.proxy("fake:fake") } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{Fabricate(:fake_model).id}") }
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake } let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
fab!(:user) fab!(:user)

View File

@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
SiteSetting.ai_openai_api_key = "abc" SiteSetting.ai_openai_api_key = "abc"
end end
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
let(:progress_blk) { Proc.new {} } let(:progress_blk) { Proc.new {} }
let(:dall_e) do let(:dall_e) do

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }
describe "#process" do describe "#process" do

View File

@ -1,12 +1,10 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
before do before { SiteSetting.ai_bot_enabled = true }
SiteSetting.ai_bot_enabled = true
SiteSetting.ai_openai_api_key = "asd"
end
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} } let(:progress_blk) { Proc.new {} }
let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read } let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }

View File

@ -3,7 +3,8 @@
require "rails_helper" require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) do let(:tool) do
described_class.new( described_class.new(

View File

@ -4,7 +4,8 @@ require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
let(:bot_user) { Fabricate(:user) } let(:bot_user) { Fabricate(:user) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) } let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) }
context "with #sort_and_shorten_diff" do context "with #sort_and_shorten_diff" do

View File

@ -4,7 +4,8 @@ require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
let(:bot_user) { Fabricate(:user) } let(:bot_user) { Fabricate(:user) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) } let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) }
context "with valid search results" do context "with valid search results" do

View File

@ -3,7 +3,8 @@
require "rails_helper" require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) do let(:tool) do
described_class.new( described_class.new(

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Google do RSpec.describe DiscourseAi::AiBot::Tools::Google do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} } let(:progress_blk) { Proc.new {} }
let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) } let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) }

View File

@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
describe "#process" do describe "#process" do
it "can generate correct info" do it "can generate correct info" do

View File

@ -1,8 +1,9 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::JavascriptEvaluator do RSpec.describe DiscourseAi::AiBot::Tools::JavascriptEvaluator do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} } let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before do before do
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Read do RSpec.describe DiscourseAi::AiBot::Tools::Read do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) } let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) }
fab!(:parent_category) { Fabricate(:category, name: "animals") } fab!(:parent_category) { Fabricate(:category, name: "animals") }

View File

@ -1,13 +1,13 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
fab!(:gpt_35_bot) { Fabricate(:llm_model, name: "gpt-3.5-turbo") } fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before do before do
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true
toggle_enabled_bots(bots: [gpt_35_bot]) toggle_enabled_bots(bots: [llm_model])
end end
def search_settings(query) def search_settings(query)

View File

@ -4,10 +4,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
before { SearchIndexer.enable } before { SearchIndexer.enable }
after { SearchIndexer.disable } after { SearchIndexer.disable }
before { SiteSetting.ai_openai_api_key = "asd" } fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:progress_blk) { Proc.new {} } let(:progress_blk) { Proc.new {} }
fab!(:admin) fab!(:admin)

View File

@ -9,8 +9,10 @@ def has_rg?
end end
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} } let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true #frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Time do RSpec.describe DiscourseAi::AiBot::Tools::Time do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true } before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,13 +1,11 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") } fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo") } let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before do before { SiteSetting.ai_bot_enabled = true }
SiteSetting.ai_openai_api_key = "asd"
SiteSetting.ai_bot_enabled = true
end
describe "#invoke" do describe "#invoke" do
it "can retrieve the content of a webpage and returns the processed text" do it "can retrieve the content of a webpage and returns the processed text" do

View File

@ -1,6 +1,7 @@
# frozen_string_literal: true # frozen_string_literal: true
describe DiscourseAi::Automation::LlmTriage do describe DiscourseAi::Automation::LlmTriage do
fab!(:post) fab!(:post)
fab!(:llm_model)
def triage(**args) def triage(**args)
DiscourseAi::Automation::LlmTriage.handle(**args) DiscourseAi::Automation::LlmTriage.handle(**args)
@ -10,7 +11,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
triage( triage(
post: post, post: post,
model: "gpt-4", model: "custom:#{llm_model.id}",
hide_topic: true, hide_topic: true,
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
@ -24,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage( triage(
post: post, post: post,
model: "gpt-4", model: "custom:#{llm_model.id}",
hide_topic: true, hide_topic: true,
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
@ -40,7 +41,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage( triage(
post: post, post: post,
model: "gpt-4", model: "custom:#{llm_model.id}",
category_id: category.id, category_id: category.id,
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
@ -55,7 +56,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage( triage(
post: post, post: post,
model: "gpt-4", model: "custom:#{llm_model.id}",
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
canned_reply: "test canned reply 123", canned_reply: "test canned reply 123",
@ -73,7 +74,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage( triage(
post: post, post: post,
model: "gpt-4", model: "custom:#{llm_model.id}",
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
flag_post: true, flag_post: true,
@ -89,7 +90,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["Bad.\n\nYo"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["Bad.\n\nYo"]) do
triage( triage(
post: post, post: post,
model: "gpt-4", model: "custom:#{llm_model.id}",
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
flag_post: true, flag_post: true,

View File

@ -32,6 +32,8 @@ module DiscourseAi
fab!(:topic_with_tag) { Fabricate(:topic, tags: [tag, hidden_tag]) } fab!(:topic_with_tag) { Fabricate(:topic, tags: [tag, hidden_tag]) }
fab!(:post_with_tag) { Fabricate(:post, raw: "I am in a tag", topic: topic_with_tag) } fab!(:post_with_tag) { Fabricate(:post, raw: "I am in a tag", topic: topic_with_tag) }
fab!(:llm_model)
describe "#run!" do describe "#run!" do
it "is able to generate email reports" do it "is able to generate email reports" do
freeze_time freeze_time
@ -41,7 +43,7 @@ module DiscourseAi
sender_username: user.username, sender_username: user.username,
receivers: ["fake@discourse.com"], receivers: ["fake@discourse.com"],
title: "test report %DATE%", title: "test report %DATE%",
model: "gpt-4", model: "custom:#{llm_model.id}",
category_ids: nil, category_ids: nil,
tags: nil, tags: nil,
allow_secure_categories: false, allow_secure_categories: false,
@ -78,7 +80,7 @@ module DiscourseAi
sender_username: user.username, sender_username: user.username,
receivers: [receiver.username], receivers: [receiver.username],
title: "test report", title: "test report",
model: "gpt-4", model: "custom:#{llm_model.id}",
category_ids: nil, category_ids: nil,
tags: nil, tags: nil,
allow_secure_categories: false, allow_secure_categories: false,
@ -123,7 +125,7 @@ module DiscourseAi
sender_username: user.username, sender_username: user.username,
receivers: [receiver.username], receivers: [receiver.username],
title: "test report", title: "test report",
model: "gpt-4", model: "custom:#{llm_model.id}",
category_ids: nil, category_ids: nil,
tags: nil, tags: nil,
allow_secure_categories: false, allow_secure_categories: false,
@ -166,7 +168,7 @@ module DiscourseAi
sender_username: user.username, sender_username: user.username,
receivers: [receiver.username], receivers: [receiver.username],
title: "test report", title: "test report",
model: "gpt-4", model: "custom:#{llm_model.id}",
category_ids: nil, category_ids: nil,
tags: nil, tags: nil,
allow_secure_categories: false, allow_secure_categories: false,
@ -194,7 +196,7 @@ module DiscourseAi
sender_username: user.username, sender_username: user.username,
receivers: [receiver.username], receivers: [receiver.username],
title: "test report", title: "test report",
model: "gpt-4", model: "custom:#{llm_model.id}",
category_ids: nil, category_ids: nil,
tags: nil, tags: nil,
allow_secure_categories: false, allow_secure_categories: false,

View File

@ -2,7 +2,7 @@
RSpec.describe AiTool do RSpec.describe AiTool do
fab!(:llm_model) { Fabricate(:llm_model, name: "claude-2") } fab!(:llm_model) { Fabricate(:llm_model, name: "claude-2") }
let(:llm) { DiscourseAi::Completions::Llm.proxy_from_obj(llm_model) } let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
def create_tool(parameters: nil, script: nil) def create_tool(parameters: nil, script: nil)
AiTool.create!( AiTool.create!(

View File

@ -1,64 +0,0 @@
# frozen_string_literal: true
RSpec.describe LlmModel do
describe ".seed_srv_backed_model" do
before do
SiteSetting.ai_vllm_endpoint_srv = "srv.llm.service."
SiteSetting.ai_vllm_api_key = "123"
end
context "when the model doesn't exist yet" do
it "creates it" do
described_class.seed_srv_backed_model
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
expect(llm_model).to be_present
expect(llm_model.name).to eq("mistralai/Mixtral-8x7B-Instruct-v0.1")
expect(llm_model.api_key).to eq(SiteSetting.ai_vllm_api_key)
end
end
context "when the model already exists" do
before { described_class.seed_srv_backed_model }
context "when the API key setting changes" do
it "updates it" do
new_key = "456"
SiteSetting.ai_vllm_api_key = new_key
described_class.seed_srv_backed_model
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
expect(llm_model.api_key).to eq(new_key)
end
end
context "when the SRV is no longer defined" do
it "deletes the LlmModel" do
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
expect(llm_model).to be_present
SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
expect { llm_model.reload }.to raise_exception(ActiveRecord::RecordNotFound)
end
it "disabled the bot user" do
SiteSetting.ai_bot_enabled = true
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
llm_model.update!(enabled_chat_bot: true)
llm_model.toggle_companion_user
user = llm_model.user
expect(user).to be_present
SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
expect { user.reload }.to raise_exception(ActiveRecord::RecordNotFound)
end
end
end
end
end

View File

@ -8,7 +8,7 @@ module DiscourseAi::ChatBotHelper
end end
def assign_fake_provider_to(setting_name) def assign_fake_provider_to(setting_name)
Fabricate(:llm_model, provider: "fake", name: "fake").tap do |fake_llm| Fabricate(:fake_model).tap do |fake_llm|
SiteSetting.public_send("#{setting_name}=", "custom:#{fake_llm.id}") SiteSetting.public_send("#{setting_name}=", "custom:#{fake_llm.id}")
end end
end end