DEV: Remove old code now that features rely on LlmModels. (#729)

* DEV: Remove old code now that features rely on LlmModels.

* Hide old settings and migrate persona llm overrides

* Remove the special shadowing URL and seeding code. Use the srv:// prefix instead (see the sketch below).
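
A minimal sketch of the new lookup, mirroring the vLLM endpoint change further down in this diff; the behavior described in the comments is an assumption drawn from the changed files, not separate documentation:

  # Old: a reserved placeholder URL plus SiteSetting.ai_vllm_endpoint_srv signalled SRV lookup.
  # New: the LlmModel's own URL carries an srv:// prefix and is resolved at request time.
  if llm_model.url.to_s.starts_with?("srv://")
    service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
    api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
  else
    api_endpoint = llm_model.url
  end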
Roman Rizzi 2024-07-30 13:44:57 -03:00 committed by GitHub
parent 73a2b15e91
commit bed044448c
73 changed files with 439 additions and 813 deletions

View File

@ -110,9 +110,7 @@ module DiscourseAi
)
provider = updating ? updating.provider : permitted[:provider]
permit_url =
(updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) ||
provider != LlmModel::BEDROCK_PROVIDER_NAME
permit_url = provider != LlmModel::BEDROCK_PROVIDER_NAME
permitted[:url] = params.dig(:ai_llm, :url) if permit_url

View File

@ -2,44 +2,10 @@
class LlmModel < ActiveRecord::Base
FIRST_BOT_USER_ID = -1200
RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid"
BEDROCK_PROVIDER_NAME = "aws_bedrock"
belongs_to :user
validates :url, exclusion: { in: [RESERVED_VLLM_SRV_URL] }, if: :url_changed?
def self.seed_srv_backed_model
srv = SiteSetting.ai_vllm_endpoint_srv
srv_model = find_by(url: RESERVED_VLLM_SRV_URL)
if srv.present?
if srv_model.present?
current_key = SiteSetting.ai_vllm_api_key
srv_model.update!(api_key: current_key) if current_key != srv_model.api_key
else
record =
new(
display_name: "vLLM SRV LLM",
name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
provider: "vllm",
tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer",
url: RESERVED_VLLM_SRV_URL,
max_prompt_tokens: 8000,
api_key: SiteSetting.ai_vllm_api_key,
)
record.save(validate: false) # Ignore reserved URL validation
end
else
# Clean up companion users
srv_model&.enabled_chat_bot = false
srv_model&.toggle_companion_user
srv_model&.destroy!
end
end
def self.provider_params
{
aws_bedrock: {
@ -54,7 +20,7 @@ class LlmModel < ActiveRecord::Base
end
def to_llm
DiscourseAi::Completions::Llm.proxy_from_obj(self)
DiscourseAi::Completions::Llm.proxy("custom:#{id}")
end
def toggle_companion_user

View File

@ -19,6 +19,6 @@ class LlmModelSerializer < ApplicationSerializer
has_one :user, serializer: BasicUserSerializer, embed: :object
def shadowed_by_srv
object.url == LlmModel::RESERVED_VLLM_SRV_URL
object.url.to_s.starts_with?("srv://")
end
end

View File

@ -60,17 +60,9 @@ export default class AiLlmEditorForm extends Component {
return this.testRunning || this.testResult !== null;
}
get displaySRVWarning() {
return this.args.model.shadowed_by_srv && !this.args.model.isNew;
}
get canEditURL() {
// Explicitly false.
if (this.metaProviderParams.url_editable === false) {
return false;
}
return !this.args.model.shadowed_by_srv || this.args.model.isNew;
return this.metaProviderParams.url_editable !== false;
}
@computed("args.model.provider")
@ -174,12 +166,6 @@ export default class AiLlmEditorForm extends Component {
}
<template>
{{#if this.displaySRVWarning}}
<div class="alert alert-info">
{{icon "exclamation-circle"}}
{{I18n.t "discourse_ai.llms.srv_warning"}}
</div>
{{/if}}
<form class="form-horizontal ai-llm-editor">
<div class="control-group">
<label>{{i18n "discourse_ai.llms.display_name"}}</label>

View File

@ -237,7 +237,6 @@ en:
confirm_delete: Are you sure you want to delete this model?
delete: Delete
srv_warning: This LLM points to an SRV record, and its URL is not editable. You have to update the hidden "ai_vllm_endpoint_srv" setting instead.
preconfigured_llms: "Select your LLM"
preconfigured:
none: "Configure manually..."

View File

@ -96,21 +96,35 @@ discourse_ai:
- opennsfw2
- nsfw_detector
ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4o_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt4_turbo_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt35_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt35_16k_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt4o_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt4_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt4_32k_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_gpt4_turbo_url:
default: "https://api.openai.com/v1/chat/completions"
hidden: true
ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
ai_openai_organization: ""
ai_openai_organization:
default: ""
hidden: true
ai_openai_api_key:
default: ""
secret: true
ai_anthropic_api_key:
default: ""
secret: true
hidden: true
ai_anthropic_native_tool_call_models:
type: list
list_type: compact
@ -123,7 +137,7 @@ discourse_ai:
- claude-3-5-sonnet
ai_cohere_api_key:
default: ""
secret: true
hidden: true
ai_stability_api_key:
default: ""
secret: true
@ -140,13 +154,16 @@ discourse_ai:
- "stable-diffusion-v1-5"
ai_hugging_face_api_url:
default: ""
hidden: true
ai_hugging_face_api_key:
default: ""
secret: true
hidden: true
ai_hugging_face_token_limit:
default: 4096
hidden: true
ai_hugging_face_model_display_name:
default: ""
hidden: true
ai_hugging_face_tei_endpoint:
default: ""
ai_hugging_face_tei_endpoint_srv:
@ -167,11 +184,13 @@ discourse_ai:
ai_bedrock_access_key_id:
default: ""
secret: true
hidden: true
ai_bedrock_secret_access_key:
default: ""
secret: true
hidden: true
ai_bedrock_region:
default: "us-east-1"
hidden: true
ai_cloudflare_workers_account_id:
default: ""
secret: true
@ -180,13 +199,16 @@ discourse_ai:
secret: true
ai_gemini_api_key:
default: ""
secret: true
hidden: true
ai_vllm_endpoint:
default: ""
hidden: true
ai_vllm_endpoint_srv:
default: ""
hidden: true
ai_vllm_api_key: ""
ai_vllm_api_key:
default: ""
hidden: true
ai_llava_endpoint:
default: ""
hidden: true

View File

@ -1,8 +0,0 @@
# frozen_string_literal: true
begin
LlmModel.seed_srv_backed_model
rescue PG::UndefinedColumn => e
# If this code runs before migrations, an attribute might be missing.
Rails.logger.warn("Failed to seed SRV-Backed LLM: #{e.message}")
end

View File

@ -25,8 +25,9 @@ class MigrateVisionLlms < ActiveRecord::Migration[7.1]
).first
if current_value && current_value != "llava"
model_name = current_value.split(":").last
llm_model =
DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: current_value).first
DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: model_name).first
if llm_model
DB.exec(<<~SQL, new: "custom:#{llm_model}") if llm_model

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
class MigratePersonaLlmOverride < ActiveRecord::Migration[7.1]
def up
fields_to_update = DB.query(<<~SQL)
SELECT id, default_llm
FROM ai_personas
WHERE default_llm IS NOT NULL
SQL
return if fields_to_update.empty?
updated_fields =
fields_to_update
.map do |field|
llm_model_id = matching_llm_model(field.default_llm)
"(#{field.id}, 'custom:#{llm_model_id}')" if llm_model_id
end
.compact
return if updated_fields.empty?
DB.exec(<<~SQL)
UPDATE ai_personas
SET default_llm = new_fields.new_default_llm
FROM (VALUES #{updated_fields.join(", ")}) AS new_fields(id, new_default_llm)
WHERE new_fields.id::bigint = ai_personas.id
SQL
end
def matching_llm_model(model)
provider = model.split(":").first
model_name = model.split(":").last
return if provider == "custom"
DB.query_single(
"SELECT id FROM llm_models WHERE name = :name AND provider = :provider",
{ name: model_name, provider: provider },
).first
end
def down
raise ActiveRecord::IrreversibleMigration
end
end

View File

@ -5,19 +5,15 @@ module DiscourseAi
module Dialects
class ChatGpt < Dialect
class << self
def can_translate?(model_name)
model_name.starts_with?("gpt-")
def can_translate?(model_provider)
model_provider == "open_ai" || model_provider == "azure"
end
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
end
def native_tool_support?
true
llm_model.provider == "open_ai" || llm_model.provider == "azure"
end
def translate
@ -30,19 +26,17 @@ module DiscourseAi
end
def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
buffer = (opts[:max_tokens] || 2500) + 50
if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation
@function_size ||= self.tokenizer.size(tools.to_json.to_s)
@function_size ||= llm_model.tokenizer_class.size(tools.to_json.to_s)
buffer += @function_size
end
model_max_tokens - buffer
llm_model.max_prompt_tokens - buffer
end
private
@ -105,24 +99,7 @@ module DiscourseAi
end
def calculate_message_token(context)
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def model_max_tokens
case model_name
when "gpt-3.5-turbo-16k"
16_384
when "gpt-4"
8192
when "gpt-4-32k"
32_768
when "gpt-4-turbo"
131_072
when "gpt-4o"
131_072
else
8192
end
llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
end
end
end

View File

@ -5,8 +5,8 @@ module DiscourseAi
module Dialects
class Claude < Dialect
class << self
def can_translate?(model_name)
model_name.start_with?("claude") || model_name.start_with?("anthropic")
def can_translate?(provider_name)
provider_name == "anthropic" || provider_name == "aws_bedrock"
end
end
@ -26,10 +26,6 @@ module DiscourseAi
end
end
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
end
def translate
messages = super
@ -61,14 +57,11 @@ module DiscourseAi
end
def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
# Longer term it will have over 1 million
200_000 # Claude-3 has a 200k context window for now
llm_model.max_prompt_tokens
end
def native_tool_support?
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name)
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(llm_model.name)
end
private

View File

@ -6,18 +6,12 @@ module DiscourseAi
module Completions
module Dialects
class Command < Dialect
class << self
def can_translate?(model_name)
%w[command-light command command-r command-r-plus].include?(model_name)
end
def self.can_translate?(model_provider)
model_provider == "cohere"
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
end
def translate
messages = super
@ -68,20 +62,7 @@ module DiscourseAi
end
def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
case model_name
when "command-light"
4096
when "command"
8192
when "command-r"
131_072
when "command-r-plus"
131_072
else
8192
end
llm_model.max_prompt_tokens
end
def native_tool_support?
@ -99,7 +80,7 @@ module DiscourseAi
end
def calculate_message_token(context)
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
end
def system_msg(msg)

View File

@ -5,7 +5,7 @@ module DiscourseAi
module Dialects
class Dialect
class << self
def can_translate?(_model_name)
def can_translate?(model_provider)
raise NotImplemented
end
@ -19,7 +19,7 @@ module DiscourseAi
]
end
def dialect_for(model_name)
def dialect_for(model_provider)
dialects = []
if Rails.env.test? || Rails.env.development?
@ -28,26 +28,21 @@ module DiscourseAi
dialects = dialects.concat(all_dialects)
dialect = dialects.find { |d| d.can_translate?(model_name) }
dialect = dialects.find { |d| d.can_translate?(model_provider) }
raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect
dialect
end
end
def initialize(generic_prompt, model_name, opts: {}, llm_model: nil)
def initialize(generic_prompt, llm_model, opts: {})
@prompt = generic_prompt
@model_name = model_name
@opts = opts
@llm_model = llm_model
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
raise NotImplemented
end
def can_end_with_assistant_msg?
false
end
@ -57,7 +52,7 @@ module DiscourseAi
end
def vision_support?
llm_model&.vision_enabled?
llm_model.vision_enabled?
end
def tools
@ -88,12 +83,12 @@ module DiscourseAi
private
attr_reader :model_name, :opts, :llm_model
attr_reader :opts, :llm_model
def trim_messages(messages)
prompt_limit = max_prompt_tokens
current_token_count = 0
message_step_size = (max_prompt_tokens / 25).to_i * -1
message_step_size = (prompt_limit / 25).to_i * -1
trimmed_messages = []
@ -157,7 +152,7 @@ module DiscourseAi
end
def calculate_message_token(msg)
self.tokenizer.size(msg[:content].to_s)
llm_model.tokenizer_class.size(msg[:content].to_s)
end
def tools_dialect

View File

@ -5,8 +5,8 @@ module DiscourseAi
module Dialects
class Gemini < Dialect
class << self
def can_translate?(model_name)
%w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name)
def can_translate?(model_provider)
model_provider == "google"
end
end
@ -14,10 +14,6 @@ module DiscourseAi
true
end
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
end
def translate
# Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } }
@ -74,24 +70,17 @@ module DiscourseAi
end
def max_prompt_tokens
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
if model_name.start_with?("gemini-1.5")
# technically we support 1 million tokens, but we're being conservative
800_000
else
16_384 # 50% of model tokens
end
llm_model.max_prompt_tokens
end
protected
def calculate_message_token(context)
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
end
def beta_api?
@beta_api ||= model_name.start_with?("gemini-1.5")
@beta_api ||= llm_model.name.start_with?("gemini-1.5")
end
def system_msg(msg)

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class Anthropic < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "anthropic"
end
def dependant_setting_names
%w[ai_anthropic_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_anthropic_api_key.present?
end
def endpoint_name(model_name)
"Anthropic - #{model_name}"
end
def self.can_contact?(model_provider)
model_provider == "anthropic"
end
def normalize_model_params(model_params)
@ -29,7 +15,7 @@ module DiscourseAi
def default_options(dialect)
mapped_model =
case model
case llm_model.name
when "claude-2"
"claude-2.1"
when "claude-instant-1"
@ -43,7 +29,7 @@ module DiscourseAi
when "claude-3-5-sonnet"
"claude-3-5-sonnet-20240620"
else
model
llm_model.name
end
options = { model: mapped_model, max_tokens: 3_000 }
@ -74,9 +60,7 @@ module DiscourseAi
end
def model_uri
url = llm_model&.url || "https://api.anthropic.com/v1/messages"
URI(url)
URI(llm_model.url)
end
def prepare_payload(prompt, model_params, dialect)
@ -94,7 +78,7 @@ module DiscourseAi
def prepare_request(payload)
headers = {
"anthropic-version" => "2023-06-01",
"x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key,
"x-api-key" => llm_model.api_key,
"content-type" => "application/json",
}

View File

@ -6,24 +6,8 @@ module DiscourseAi
module Completions
module Endpoints
class AwsBedrock < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "aws_bedrock"
end
def dependant_setting_names
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
end
def correctly_configured?(_model)
SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present?
end
def endpoint_name(model_name)
"AWS Bedrock - #{model_name}"
end
def self.can_contact?(model_provider)
model_provider == "aws_bedrock"
end
def normalize_model_params(model_params)
@ -62,18 +46,10 @@ module DiscourseAi
end
def model_uri
if llm_model
region = llm_model.lookup_custom_param("region")
api_url =
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{llm_model.name}/invoke"
else
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
#
# FYI there is a 2.0 version of Claude, very little need to support it given
# haiku/sonnet are better fits anyway, we map to claude-2.1
bedrock_model_id =
case model
case llm_model.name
when "claude-2"
"anthropic.claude-v2:1"
when "claude-3-haiku"
@ -87,12 +63,11 @@ module DiscourseAi
when "claude-3-5-sonnet"
"anthropic.claude-3-5-sonnet-20240620-v1:0"
else
model
llm_model.name
end
api_url =
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
end
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
@ -114,11 +89,9 @@ module DiscourseAi
signer =
Aws::Sigv4::Signer.new(
access_key_id:
llm_model&.lookup_custom_param("access_key_id") ||
SiteSetting.ai_bedrock_access_key_id,
region: llm_model&.lookup_custom_param("region") || SiteSetting.ai_bedrock_region,
secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
access_key_id: llm_model.lookup_custom_param("access_key_id"),
region: llm_model.lookup_custom_param("region"),
secret_access_key: llm_model.api_key,
service: "bedrock",
)

View File

@ -30,39 +30,12 @@ module DiscourseAi
end
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.llm.endpoints.configuration_hint",
settings: settings.join(", "),
count: settings.length,
)
end
def display_name(model_name)
to_display = endpoint_name(model_name)
return to_display if correctly_configured?(model_name)
I18n.t("discourse_ai.llm.endpoints.not_configured", display_name: to_display)
end
def dependant_setting_names
raise NotImplementedError
end
def endpoint_name(_model_name)
raise NotImplementedError
end
def can_contact?(_endpoint_name)
def can_contact?(_model_provider)
raise NotImplementedError
end
end
def initialize(model_name, tokenizer, llm_model: nil)
@model = model_name
@tokenizer = tokenizer
def initialize(llm_model)
@llm_model = llm_model
end
@ -136,7 +109,7 @@ module DiscourseAi
topic_id: dialect.prompt.topic_id,
post_id: dialect.prompt.post_id,
feature_name: feature_name,
language_model: self.class.endpoint_name(@model),
language_model: llm_model.name,
)
if !@streaming_mode
@ -323,10 +296,14 @@ module DiscourseAi
tokenizer.size(extract_prompt_for_tokenizer(prompt))
end
attr_reader :tokenizer, :model, :llm_model
attr_reader :llm_model
protected
def tokenizer
llm_model.tokenizer_class
end
# should normalize temperature, max_tokens, stop_words to endpoint specific values
def normalize_model_params(model_params)
raise NotImplementedError

View File

@ -6,10 +6,6 @@ module DiscourseAi
class CannedResponse
CANNED_RESPONSE_ERROR = Class.new(StandardError)
def self.can_contact?(_)
Rails.env.test?
end
def initialize(responses)
@responses = responses
@completions = 0

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class Cohere < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "cohere"
end
def dependant_setting_names
%w[ai_cohere_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_cohere_api_key.present?
end
def endpoint_name(model_name)
"Cohere - #{model_name}"
end
def self.can_contact?(model_provider)
model_provider == "cohere"
end
def normalize_model_params(model_params)
@ -39,9 +25,7 @@ module DiscourseAi
private
def model_uri
url = llm_model&.url || "https://api.cohere.ai/v1/chat"
URI(url)
URI(llm_model.url)
end
def prepare_payload(prompt, model_params, dialect)
@ -59,7 +43,7 @@ module DiscourseAi
def prepare_request(payload)
headers = {
"Content-Type" => "application/json",
"Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}",
"Authorization" => "Bearer #{llm_model.api_key}",
}
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }

View File

@ -4,20 +4,6 @@ module DiscourseAi
module Completions
module Endpoints
class Fake < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "fake"
end
def correctly_configured?(_model_name)
true
end
def endpoint_name(_model_name)
"Test - fake model"
end
end
STOCK_CONTENT = <<~TEXT
# Discourse Markdown Styles Showcase
@ -75,6 +61,10 @@ module DiscourseAi
Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers.
TEXT
def self.can_contact?(model_provider)
model_provider == "fake"
end
def self.with_fake_content(content)
@fake_content = content
yield

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class Gemini < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "google"
end
def dependant_setting_names
%w[ai_gemini_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_gemini_api_key.present?
end
def endpoint_name(model_name)
"Google - #{model_name}"
end
def self.can_contact?(model_provider)
model_provider == "google"
end
def default_options
@ -59,21 +45,8 @@ module DiscourseAi
private
def model_uri
if llm_model
url = llm_model.url
else
mapped_model = model
if model == "gemini-1.5-pro"
mapped_model = "gemini-1.5-pro-latest"
elsif model == "gemini-1.5-flash"
mapped_model = "gemini-1.5-flash-latest"
elsif model == "gemini-1.0-pro"
mapped_model = "gemini-pro-latest"
end
url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
end
key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
key = llm_model.api_key
if @streaming_mode
url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class HuggingFace < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "hugging_face"
end
def dependant_setting_names
%w[ai_hugging_face_api_url]
end
def correctly_configured?(_model_name)
SiteSetting.ai_hugging_face_api_url.present?
end
def endpoint_name(model_name)
"Hugging Face - #{model_name}"
end
def self.can_contact?(model_provider)
model_provider == "hugging_face"
end
def normalize_model_params(model_params)
@ -34,7 +20,7 @@ module DiscourseAi
end
def default_options
{ model: model, temperature: 0.7 }
{ model: llm_model.name, temperature: 0.7 }
end
def provider_id
@ -44,7 +30,7 @@ module DiscourseAi
private
def model_uri
URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url)
URI(llm_model.url)
end
def prepare_payload(prompt, model_params, _dialect)
@ -53,8 +39,7 @@ module DiscourseAi
.merge(messages: prompt)
.tap do |payload|
if !payload[:max_tokens]
token_limit =
llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit
token_limit = llm_model.max_prompt_tokens
payload[:max_tokens] = token_limit - prompt_size(prompt)
end
@ -64,7 +49,7 @@ module DiscourseAi
end
def prepare_request(payload)
api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key
api_key = llm_model.api_key
headers =
{ "Content-Type" => "application/json" }.tap do |h|

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class Ollama < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "ollama"
end
def dependant_setting_names
%w[ai_ollama_endpoint]
end
def correctly_configured?(_model_name)
SiteSetting.ai_ollama_endpoint.present?
end
def endpoint_name(model_name)
"Ollama - #{model_name}"
end
def self.can_contact?(model_provider)
model_provider == "ollama"
end
def normalize_model_params(model_params)
@ -34,7 +20,7 @@ module DiscourseAi
end
def default_options
{ max_tokens: 2000, model: model }
{ max_tokens: 2000, model: llm_model.name }
end
def provider_id
@ -48,7 +34,7 @@ module DiscourseAi
private
def model_uri
URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
URI(llm_model.url)
end
def prepare_payload(prompt, model_params, _dialect)

View File

@ -4,56 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class OpenAi < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "open_ai"
end
def dependant_setting_names
%w[
ai_openai_api_key
ai_openai_gpt4o_url
ai_openai_gpt4_32k_url
ai_openai_gpt4_turbo_url
ai_openai_gpt4_url
ai_openai_gpt4_url
ai_openai_gpt35_16k_url
ai_openai_gpt35_url
]
end
def correctly_configured?(model_name)
SiteSetting.ai_openai_api_key.present? && has_url?(model_name)
end
def has_url?(model)
url =
if model.include?("gpt-4")
if model.include?("32k")
SiteSetting.ai_openai_gpt4_32k_url
else
if model.include?("1106") || model.include?("turbo")
SiteSetting.ai_openai_gpt4_turbo_url
elsif model.include?("gpt-4o")
SiteSetting.ai_openai_gpt4o_url
else
SiteSetting.ai_openai_gpt4_url
end
end
else
if model.include?("16k")
SiteSetting.ai_openai_gpt35_16k_url
else
SiteSetting.ai_openai_gpt35_url
end
end
url.present?
end
def endpoint_name(model_name)
"OpenAI - #{model_name}"
end
def self.can_contact?(model_provider)
%w[open_ai azure].include?(model_provider)
end
def normalize_model_params(model_params)
@ -68,7 +20,7 @@ module DiscourseAi
end
def default_options
{ model: model }
{ model: llm_model.name }
end
def provider_id
@ -78,28 +30,7 @@ module DiscourseAi
private
def model_uri
return URI(llm_model.url) if llm_model&.url
url =
if model.include?("gpt-4")
if model.include?("32k")
SiteSetting.ai_openai_gpt4_32k_url
else
if model.include?("1106") || model.include?("turbo")
SiteSetting.ai_openai_gpt4_turbo_url
else
SiteSetting.ai_openai_gpt4_url
end
end
else
if model.include?("16k")
SiteSetting.ai_openai_gpt35_16k_url
else
SiteSetting.ai_openai_gpt35_url
end
end
URI(url)
URI(llm_model.url)
end
def prepare_payload(prompt, model_params, dialect)
@ -110,7 +41,7 @@ module DiscourseAi
# Usage is not available in Azure yet.
# We'll fallback to guess this using the tokenizer.
payload[:stream_options] = { include_usage: true } if model_uri.host.exclude?("azure")
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
end
payload[:tools] = dialect.tools if dialect.tools.present?
@ -119,18 +50,15 @@ module DiscourseAi
def prepare_request(payload)
headers = { "Content-Type" => "application/json" }
api_key = llm_model.api_key
api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key
if model_uri.host.include?("azure")
if llm_model.provider == "azure"
headers["api-key"] = api_key
else
headers["Authorization"] = "Bearer #{api_key}"
end
org_id =
llm_model&.lookup_custom_param("organization") || SiteSetting.ai_openai_organization
org_id = llm_model.lookup_custom_param("organization")
headers["OpenAI-Organization"] = org_id if org_id.present?
end
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end

View File

@ -4,22 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class Vllm < Base
class << self
def can_contact?(endpoint_name)
endpoint_name == "vllm"
end
def dependant_setting_names
%w[ai_vllm_endpoint_srv ai_vllm_endpoint]
end
def correctly_configured?(_model_name)
SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
end
def endpoint_name(model_name)
"vLLM - #{model_name}"
end
def self.can_contact?(model_provider)
model_provider == "vllm"
end
def normalize_model_params(model_params)
@ -34,7 +20,7 @@ module DiscourseAi
end
def default_options
{ max_tokens: 2000, model: model }
{ max_tokens: 2000, model: llm_model.name }
end
def provider_id
@ -44,16 +30,13 @@ module DiscourseAi
private
def model_uri
if llm_model&.url && !llm_model&.url == LlmModel::RESERVED_VLLM_SRV_URL
return URI(llm_model.url)
end
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
if service.present?
if llm_model.url.to_s.starts_with?("srv://")
record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
else
api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions"
api_endpoint = llm_model.url
end
@uri ||= URI(api_endpoint)
end

View File

@ -64,8 +64,8 @@ module DiscourseAi
id: "open_ai",
models: [
{ name: "gpt-4o", tokens: 131_072, display_name: "GPT-4 Omni" },
{ name: "gpt-4o-mini", tokens: 131_072, display_name: "GPT-4 Omni Mini" },
{ name: "gpt-4-turbo", tokens: 131_072, display_name: "GPT-4 Turbo" },
{ name: "gpt-3.5-turbo", tokens: 16_385, display_name: "GPT-3.5 Turbo" },
],
tokenizer: DiscourseAi::Tokenizer::OpenAiTokenizer,
endpoint: "https://api.openai.com/v1/chat/completions",
@ -89,41 +89,6 @@ module DiscourseAi
DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
end
def models_by_provider
# ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
# However, since they use the same URL/key settings, there's no reason to duplicate them.
@models_by_provider ||=
{
aws_bedrock: %w[
claude-instant-1
claude-2
claude-3-haiku
claude-3-sonnet
claude-3-opus
],
anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2],
hugging_face: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
],
cohere: %w[command-light command command-r command-r-plus],
open_ai: %w[
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-turbo
gpt-4-vision-preview
gpt-4o
],
google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
}.tap do |h|
h[:ollama] = ["mistral"] if Rails.env.development?
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
end
end
def valid_provider_models
return @valid_provider_models if defined?(@valid_provider_models)
@ -151,61 +116,38 @@ module DiscourseAi
@prompts << prompt if @prompts
end
def proxy(model_name)
provider_and_model_name = model_name.split(":")
provider_name = provider_and_model_name.first
model_name_without_prov = provider_and_model_name[1..].join
def proxy(model)
llm_model =
if model.is_a?(LlmModel)
model
else
model_name_without_prov = model.split(":").last.to_i
# We are in the process of transitioning to always use objects here.
# We'll live with this hack for a while.
if provider_name == "custom"
llm_model = LlmModel.find(model_name_without_prov)
raise UNKNOWN_MODEL if !llm_model
return proxy_from_obj(llm_model)
LlmModel.find_by(id: model_name_without_prov)
end
dialect_klass =
DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
raise UNKNOWN_MODEL if llm_model.nil?
model_provider = llm_model.provider
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_provider)
if @canned_response
if @canned_llm && @canned_llm != model_name
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
if @canned_llm && @canned_llm != model
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model}"
end
return new(dialect_klass, nil, model_name, gateway: @canned_response)
return new(dialect_klass, nil, llm_model, gateway: @canned_response)
end
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_provider)
new(dialect_klass, gateway_klass, model_name_without_prov)
end
def proxy_from_obj(llm_model)
provider_name = llm_model.provider
model_name = llm_model.name
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
if @canned_response
if @canned_llm && @canned_llm != [provider_name, model_name].join(":")
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
end
return(
new(dialect_klass, nil, model_name, gateway: @canned_response, llm_model: llm_model)
)
end
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
new(dialect_klass, gateway_klass, llm_model)
end
end
def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil)
def initialize(dialect_klass, gateway_klass, llm_model, gateway: nil)
@dialect_klass = dialect_klass
@gateway_klass = gateway_klass
@model_name = model_name
@gateway = gateway
@llm_model = llm_model
end
@ -264,9 +206,9 @@ module DiscourseAi
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model)
dialect = dialect_klass.new(prompt, llm_model, opts: model_params)
gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model)
gateway = @gateway || gateway_klass.new(llm_model)
gateway.perform_completion!(
dialect,
user,
@ -277,16 +219,14 @@ module DiscourseAi
end
def max_prompt_tokens
llm_model&.max_prompt_tokens ||
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
llm_model.max_prompt_tokens
end
def tokenizer
llm_model&.tokenizer_class ||
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
llm_model.tokenizer_class
end
attr_reader :model_name, :llm_model
attr_reader :llm_model
private

View File

@ -15,19 +15,16 @@ module DiscourseAi
return !@parent_enabled
end
llm_model_id = val.split(":")&.last
llm_model = LlmModel.find_by(id: llm_model_id)
return false if llm_model.nil?
run_test(llm_model).tap { |result| @unreachable = result }
rescue StandardError
run_test(val).tap { |result| @unreachable = result }
rescue StandardError => e
raise e if Rails.env.test?
@unreachable = true
false
end
def run_test(llm_model)
def run_test(val)
DiscourseAi::Completions::Llm
.proxy_from_obj(llm_model)
.proxy(val)
.generate("How much is 1 + 1?", user: nil, feature_name: "llm_validator")
.present?
end

View File

@ -80,8 +80,4 @@ after_initialize do
nil
end
end
on(:site_setting_changed) do |name, _old_value, _new_value|
LlmModel.seed_srv_backed_model if name == :ai_vllm_endpoint_srv || name == :ai_vllm_api_key
end
end

View File

@ -5,5 +5,67 @@ Fabricator(:llm_model) do
name "gpt-4-turbo"
provider "open_ai"
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
api_key "123"
url "https://api.openai.com/v1/chat/completions"
max_prompt_tokens 131_072
end
Fabricator(:anthropic_model, from: :llm_model) do
display_name "Claude 3 Opus"
name "claude-3-opus"
max_prompt_tokens 200_000
url "https://api.anthropic.com/v1/messages"
tokenizer "DiscourseAi::Tokenizer::AnthropicTokenizer"
provider "anthropic"
end
Fabricator(:hf_model, from: :llm_model) do
display_name "Llama 3.1"
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
max_prompt_tokens 64_000
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
url "https://test.dev/v1/chat/completions"
provider "hugging_face"
end
Fabricator(:vllm_model, from: :llm_model) do
display_name "Llama 3.1 vLLM"
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
max_prompt_tokens 64_000
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
url "https://test.dev/v1/chat/completions"
provider "vllm"
end
Fabricator(:fake_model, from: :llm_model) do
display_name "Fake model"
name "fake"
provider "fake"
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
max_prompt_tokens 32_000
end
Fabricator(:gemini_model, from: :llm_model) do
display_name "Gemini"
name "gemini-1.5-pro"
provider "google"
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
max_prompt_tokens 800_000
url "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest"
end
Fabricator(:bedrock_model, from: :anthropic_model) do
url ""
provider "aws_bedrock"
api_key "asd-asd-asd"
name "claude-3-sonnet"
provider_params { { region: "us-east-1", access_key_id: "123456" } }
end
Fabricator(:cohere_model, from: :llm_model) do
display_name "Cohere Command R+"
name "command-r-plus"
provider "cohere"
api_key "ABC"
url "https://api.cohere.ai/v1/chat"
end

View File

@ -3,8 +3,8 @@
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
let(:model_name) { "gpt-4" }
let(:context) { DialectContext.new(described_class, model_name) }
fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
let(:context) { DialectContext.new(described_class, llm_model) }
describe "#translate" do
it "translates a prompt written in our generic format to the ChatGPT format" do

View File

@ -2,9 +2,11 @@
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
let :opus_dialect_klass do
DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
end
fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }
describe "#translate" do
it "can insert OKs to make stuff interleve properly" do
messages = [
@ -17,7 +19,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
dialect = opus_dialect_klass.new(prompt, llm_model)
translated = dialect.translate
expected_messages = [
@ -62,7 +64,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
tools: tools,
)
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
dialect = opus_dialect_klass.new(prompt, llm_model)
translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot")
@ -114,7 +116,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
messages: messages,
tools: tools,
)
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
dialect = opus_dialect_klass.new(prompt, llm_model)
translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot")

View File

@ -1,13 +1,13 @@
# frozen_string_literal: true
class DialectContext
def initialize(dialect_klass, model_name)
def initialize(dialect_klass, llm_model)
@dialect_klass = dialect_klass
@model_name = model_name
@llm_model = llm_model
end
def dialect(prompt)
@dialect_klass.new(prompt, @model_name)
@dialect_klass.new(prompt, @llm_model)
end
def prompt

View File

@ -13,6 +13,8 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
end
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
fab!(:llm_model)
describe "#trim_messages" do
let(:five_token_msg) { "This represents five tokens." }
@ -23,7 +25,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
prompt.push(type: :tool, content: five_token_msg, id: 1)
prompt.push(type: :user, content: five_token_msg)
dialect = TestDialect.new(prompt, "test")
dialect = TestDialect.new(prompt, llm_model)
dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message
trimmed = dialect.trim(prompt.messages)
@ -37,7 +39,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens")
prompt.push(type: :user, content: five_token_msg)
dialect = TestDialect.new(prompt, "test")
dialect = TestDialect.new(prompt, llm_model)
dialect.max_prompt_tokens = 15
trimmed = dialect.trim(prompt.messages)

View File

@ -3,8 +3,8 @@
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
let(:model_name) { "gemini-1.5-pro" }
let(:context) { DialectContext.new(described_class, model_name) }
fab!(:model) { Fabricate(:gemini_model) }
let(:context) { DialectContext.new(described_class, model) }
describe "#translate" do
it "translates a prompt written in our generic format to the Gemini format" do
@ -86,11 +86,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
it "trims content if it's getting too long" do
# testing truncation on 800k tokens is slow use model with less
context = DialectContext.new(described_class, "gemini-pro")
model.max_prompt_tokens = 16_384
context = DialectContext.new(described_class, model)
translated = context.long_user_input_scenario(length: 5_000)
expect(translated[:messages].last[:role]).to eq("user")
expect(translated[:messages].last.dig(:parts, :text).length).to be <
expect(translated[:messages].last.dig(:parts, 0, :text).length).to be <
context.long_message_text(length: 5_000).length
end
end

View File

@ -3,16 +3,7 @@ require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
let(:url) { "https://api.anthropic.com/v1/messages" }
fab!(:model) do
Fabricate(
:llm_model,
url: "https://api.anthropic.com/v1/messages",
name: "claude-3-opus",
provider: "anthropic",
api_key: "123",
vision_enabled: true,
)
end
fab!(:model) { Fabricate(:anthropic_model, name: "claude-3-opus", vision_enabled: true) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
let(:upload100x100) do
@ -204,6 +195,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
it "supports non streaming tool calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
tool = {
name: "calculate",
description: "calculate something",
@ -224,8 +217,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
tools: [tool],
)
proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
body = {
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
type: "message",
@ -252,7 +243,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
stub_request(:post, url).to_return(body: body)
result = proxy.generate(prompt, user: Discourse.system_user)
result = llm.generate(prompt, user: Discourse.system_user)
expected = <<~TEXT.strip
<function_calls>
@ -370,7 +361,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
},
).to_return(status: 200, body: body)
result = llm.generate(prompt, user: Discourse.system_user)
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
result = proxy.generate(prompt, user: Discourse.system_user)
expect(result).to eq("Hello!")
expected_body = {

View File

@ -8,9 +8,10 @@ class BedrockMock < EndpointMock
end
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
subject(:endpoint) { described_class.new(model) }
fab!(:user)
fab!(:model) { Fabricate(:bedrock_model) }
let(:bedrock_mock) { BedrockMock.new(endpoint) }
@ -25,16 +26,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
Aws::EventStream::Encoder.new.encode(aws_message)
end
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
SiteSetting.ai_bedrock_region = "us-east-1"
end
describe "function calling" do
it "supports old school xml function calls" do
SiteSetting.ai_anthropic_native_tool_call_models = ""
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
incomplete_tool_call = <<~XML.strip
<thinking>I should be ignored</thinking>
@ -112,7 +107,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end
it "supports streaming function calls" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil
@ -124,7 +119,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
type: "message",
role: "assistant",
model: "claude-3-haiku-20240307",
model: "claude-3-sonnet-20240307",
stop_sequence: nil,
usage: {
input_tokens: 840,
@ -281,9 +276,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end
end
describe "Claude 3 Sonnet support" do
it "supports the sonnet model" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
describe "Claude 3 support" do
it "supports regular completions" do
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil
@ -325,8 +320,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(log.response_tokens).to eq(20)
end
it "supports claude 3 sonnet streaming" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
it "supports claude 3 streaming" do
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
request = nil

View File

@ -2,7 +2,8 @@
require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") }
fab!(:cohere_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{cohere_model.id}") }
fab!(:user)
let(:prompt) do
@ -57,8 +58,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
prompt
end
before { SiteSetting.ai_cohere_api_key = "ABC" }
it "is able to trigger a tool" do
body = (<<~TEXT).strip
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
@ -184,7 +183,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
expect(audit.request_tokens).to eq(17)
expect(audit.response_tokens).to eq(22)
expect(audit.language_model).to eq("Cohere - command-r-plus")
expect(audit.language_model).to eq("command-r-plus")
end
it "is able to perform streaming completions" do

View File

@ -158,7 +158,7 @@ class EndpointsCompliance
end
def dialect(prompt: generic_prompt)
dialect_klass.new(prompt, endpoint.model)
dialect_klass.new(prompt, endpoint.llm_model)
end
def regular_mode_simple_prompt(mock)
@ -176,7 +176,7 @@ class EndpointsCompliance
expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
expect(log.response_tokens).to eq(endpoint.llm_model.tokenizer_class.size(completion_response))
end
def regular_mode_tools(mock)
@ -206,7 +206,7 @@ class EndpointsCompliance
expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
)
end
end

View File

@ -128,18 +128,9 @@ class GeminiMock < EndpointMock
end
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
subject(:endpoint) { described_class.new(model) }
fab!(:model) do
Fabricate(
:llm_model,
url: "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest",
name: "gemini-1.5-pro",
provider: "google",
api_key: "ABC",
vision_enabled: true,
)
end
fab!(:model) { Fabricate(:gemini_model, vision_enabled: true) }
fab!(:user)
@ -168,7 +159,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
req_body = nil
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:generateContent?key=ABC"
url = "#{model.url}:generateContent?key=123"
stub_request(:post, url).with(
body:
@ -221,7 +212,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
split = data.split("|")
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=ABC"
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
output = +""
gemini_mock.with_chunk_array_support do

View File

@ -22,7 +22,7 @@ class HuggingFaceMock < EndpointMock
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: request_body(prompt))
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@ -40,7 +40,7 @@ class HuggingFaceMock < EndpointMock
end
def stub_raw(chunks)
WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return(
WebMock.stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
status: 200,
body: chunks,
)
@ -59,7 +59,7 @@ class HuggingFaceMock < EndpointMock
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: request_body(prompt, stream: true))
.to_return(status: 200, body: chunks)
@ -71,8 +71,7 @@ class HuggingFaceMock < EndpointMock
.default_options
.merge(messages: prompt)
.tap do |b|
b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
b[:max_tokens] = 63_991
b[:stream] = true if stream
end
.to_json
@ -80,15 +79,9 @@ class HuggingFaceMock < EndpointMock
end
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:endpoint) do
described_class.new(
"mistralai/Mistral-7B-Instruct-v0.2",
DiscourseAi::Tokenizer::MixtralTokenizer,
)
end
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
subject(:endpoint) { described_class.new(hf_model) }
fab!(:hf_model)
fab!(:user)
let(:hf_mock) { HuggingFaceMock.new(endpoint) }

View File

@ -146,11 +146,10 @@ class OpenAiMock < EndpointMock
end
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:endpoint) do
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
end
subject(:endpoint) { described_class.new(model) }
fab!(:user)
fab!(:model) { Fabricate(:llm_model) }
let(:echo_tool) do
{
@ -175,7 +174,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
describe "repeat calls" do
it "can properly reset context" do
llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo")
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
tools = [
{
@ -258,7 +257,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
describe "image support" do
it "can handle images" do
model = Fabricate(:llm_model, provider: "open_ai", vision_enabled: true)
model = Fabricate(:llm_model, vision_enabled: true)
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
prompt =
DiscourseAi::Completions::Prompt.new(

View File

@ -22,7 +22,7 @@ class VllmMock < EndpointMock
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
.stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@ -50,19 +50,16 @@ class VllmMock < EndpointMock
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
.stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks)
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
subject(:endpoint) do
described_class.new(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
DiscourseAi::Tokenizer::MixtralTokenizer,
)
end
subject(:endpoint) { described_class.new(llm_model) }
fab!(:llm_model) { Fabricate(:vllm_model) }
fab!(:user)
@ -78,15 +75,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
end
let(:dialect) do
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, model_name)
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, llm_model)
end
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do

View File

@ -5,12 +5,13 @@ RSpec.describe DiscourseAi::Completions::Llm do
described_class.new(
DiscourseAi::Completions::Dialects::OpenAiCompatible,
canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2",
model,
gateway: canned_response,
)
end
fab!(:user)
fab!(:model) { Fabricate(:llm_model) }
describe ".proxy" do
it "raises an exception when we can't proxy the model" do
@ -46,7 +47,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
)
result = +""
described_class
.proxy("open_ai:gpt-3.5-turbo")
.proxy("custom:#{model.id}")
.generate(prompt, user: user) { |partial| result << partial }
expect(result).to eq("Hello")
@ -57,12 +58,14 @@ RSpec.describe DiscourseAi::Completions::Llm do
end
describe "#generate with fake model" do
fab!(:fake_model)
before do
DiscourseAi::Completions::Endpoints::Fake.delays = []
DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
end
let(:llm) { described_class.proxy("fake:fake") }
let(:llm) { described_class.proxy("custom:#{fake_model.id}") }
let(:prompt) do
DiscourseAi::Completions::Prompt.new(

View File

@ -5,6 +5,8 @@ return if !defined?(DiscourseAutomation)
describe DiscourseAutomation do
let(:automation) { Fabricate(:automation, script: "llm_report", enabled: true) }
fab!(:llm_model)
fab!(:user)
fab!(:post)
@ -22,7 +24,7 @@ describe DiscourseAutomation do
it "can trigger via automation" do
add_automation_field("sender", user.username, type: "user")
add_automation_field("receivers", [user.username], type: "users")
add_automation_field("model", "gpt-4-turbo")
add_automation_field("model", "custom:#{llm_model.id}")
add_automation_field("title", "Weekly report")
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
@ -36,7 +38,7 @@ describe DiscourseAutomation do
it "can target a topic" do
add_automation_field("sender", user.username, type: "user")
add_automation_field("topic_id", "#{post.topic_id}")
add_automation_field("model", "gpt-4-turbo")
add_automation_field("model", "custom:#{llm_model.id}")
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
automation.trigger!

View File

@ -8,6 +8,8 @@ describe DiscourseAi::Automation::LlmTriage do
let(:automation) { Fabricate(:automation, script: "llm_triage", enabled: true) }
fab!(:llm_model)
def add_automation_field(name, value, type: "text")
automation.fields.create!(
component: type,
@ -23,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
SiteSetting.tagging_enabled = true
add_automation_field("system_prompt", "hello %%POST%%")
add_automation_field("search_for_text", "bad")
add_automation_field("model", "gpt-4")
add_automation_field("model", "custom:#{llm_model.id}")
add_automation_field("category", category.id, type: "category")
add_automation_field("tags", %w[aaa bbb], type: "tags")
add_automation_field("hide_topic", true, type: "boolean")

View File

@ -12,7 +12,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
SiteSetting.ai_bot_enabled = true
end
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-4") }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_4.name) }
let!(:user) { Fabricate(:user) }
@ -38,7 +38,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
toggle_enabled_bots(bots: [fake])
Group.refresh_automatic_groups!
bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model("fake")
bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model(fake.name)
AiPersona.create!(
name: "TestPersona",
top_p: 0.5,

View File

@ -336,6 +336,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
context "when RAG is running with a question consolidator" do
let(:consolidated_question) { "what is the time in france?" }
fab!(:llm_model) { Fabricate(:fake_model) }
it "will run the question consolidator" do
context_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
@ -350,7 +352,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
name: "custom",
rag_conversation_chunks: 3,
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
question_consolidator_llm: "fake:fake",
question_consolidator_llm: "custom:#{llm_model.id}",
)
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])

View File

@ -4,6 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
subject(:playground) { described_class.new(bot) }
fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") }
fab!(:opus_model) { Fabricate(:anthropic_model) }
fab!(:bot_user) do
toggle_enabled_bots(bots: [claude_2])
@ -160,7 +161,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
system_prompt: "You are a helpful bot",
vision_enabled: true,
vision_max_pixels: 1_000,
default_llm: "anthropic:claude-3-opus",
default_llm: "custom:#{opus_model.id}",
mentionable: true,
)
end
@ -211,7 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
)
persona.create_user!
persona.update!(default_llm: "anthropic:claude-2", mentionable: true)
persona.update!(default_llm: "custom:#{claude_2.id}", mentionable: true)
persona
end
@ -228,7 +229,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
SiteSetting.ai_bot_enabled = true
SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
Group.refresh_automatic_groups!
persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus")
persona.update!(allow_chat: true, mentionable: true, default_llm: "custom:#{opus_model.id}")
end
it "should behave in a sane way when threading is enabled" do
@ -342,7 +343,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
persona.update!(
allow_chat: true,
mentionable: false,
default_llm: "anthropic:claude-3-opus",
default_llm: "custom:#{opus_model.id}",
)
SiteSetting.ai_bot_enabled = true
end
@ -517,7 +518,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
DiscourseAi::Completions::Llm.with_prepared_responses(
["Magic title", "Yes I can"],
llm: "anthropic:claude-2",
llm: "custom:#{claude_2.id}",
) do
post =
create_post(
@ -552,7 +553,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# title is queued first, ensures it uses the llm targeted via target_usernames not claude
DiscourseAi::Completions::Llm.with_prepared_responses(
["Magic title", "Yes I can"],
llm: "open_ai:gpt-3.5-turbo",
llm: "custom:#{gpt_35_turbo.id}",
) do
post =
create_post(
@ -584,7 +585,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# replies as correct persona if replying direct to persona
DiscourseAi::Completions::Llm.with_prepared_responses(
["Another reply"],
llm: "open_ai:gpt-3.5-turbo",
llm: "custom:#{gpt_35_turbo.id}",
) do
create_post(
raw: "Please ignore this bot, I am replying to a user",

View File

@ -1,7 +1,7 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do
let(:llm) { DiscourseAi::Completions::Llm.proxy("fake:fake") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{Fabricate(:fake_model).id}") }
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
fab!(:user)

View File

@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
SiteSetting.ai_openai_api_key = "abc"
end
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
let(:progress_blk) { Proc.new {} }
let(:dall_e) do

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true }
describe "#process" do

View File

@ -1,12 +1,10 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
before do
SiteSetting.ai_bot_enabled = true
SiteSetting.ai_openai_api_key = "asd"
end
before { SiteSetting.ai_bot_enabled = true }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} }
let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }

View File

@ -3,7 +3,8 @@
require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) do
described_class.new(

View File

@ -4,7 +4,8 @@ require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
let(:bot_user) { Fabricate(:user) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) }
context "with #sort_and_shorten_diff" do

View File

@ -4,7 +4,8 @@ require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
let(:bot_user) { Fabricate(:user) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) }
context "with valid search results" do

View File

@ -3,7 +3,8 @@
require "rails_helper"
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
fab!(:llm_model)
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) do
described_class.new(

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Google do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} }
let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) }

View File

@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
describe "#process" do
it "can generate correct info" do

View File

@ -1,8 +1,9 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::JavascriptEvaluator do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before do
SiteSetting.ai_bot_enabled = true

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Read do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) }
fab!(:parent_category) { Fabricate(:category, name: "animals") }

View File

@ -1,13 +1,13 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
fab!(:gpt_35_bot) { Fabricate(:llm_model, name: "gpt-3.5-turbo") }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before do
SiteSetting.ai_bot_enabled = true
toggle_enabled_bots(bots: [gpt_35_bot])
toggle_enabled_bots(bots: [llm_model])
end
def search_settings(query)

View File

@ -4,10 +4,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
before { SearchIndexer.enable }
after { SearchIndexer.disable }
before { SiteSetting.ai_openai_api_key = "asd" }
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} }
fab!(:admin)

View File

@ -9,8 +9,10 @@ def has_rg?
end
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
let(:progress_blk) { Proc.new {} }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,8 +1,9 @@
#frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Time do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before { SiteSetting.ai_bot_enabled = true }

View File

@ -1,13 +1,11 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo") }
fab!(:llm_model)
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
before do
SiteSetting.ai_openai_api_key = "asd"
SiteSetting.ai_bot_enabled = true
end
before { SiteSetting.ai_bot_enabled = true }
describe "#invoke" do
it "can retrieve the content of a webpage and returns the processed text" do

View File

@ -1,6 +1,7 @@
# frozen_string_literal: true
describe DiscourseAi::Automation::LlmTriage do
fab!(:post)
fab!(:llm_model)
def triage(**args)
DiscourseAi::Automation::LlmTriage.handle(**args)
@ -10,7 +11,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
triage(
post: post,
model: "gpt-4",
model: "custom:#{llm_model.id}",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -24,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "gpt-4",
model: "custom:#{llm_model.id}",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -40,7 +41,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "gpt-4",
model: "custom:#{llm_model.id}",
category_id: category.id,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -55,7 +56,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "gpt-4",
model: "custom:#{llm_model.id}",
system_prompt: "test %%POST%%",
search_for_text: "bad",
canned_reply: "test canned reply 123",
@ -73,7 +74,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "gpt-4",
model: "custom:#{llm_model.id}",
system_prompt: "test %%POST%%",
search_for_text: "bad",
flag_post: true,
@ -89,7 +90,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["Bad.\n\nYo"]) do
triage(
post: post,
model: "gpt-4",
model: "custom:#{llm_model.id}",
system_prompt: "test %%POST%%",
search_for_text: "bad",
flag_post: true,

View File

@ -32,6 +32,8 @@ module DiscourseAi
fab!(:topic_with_tag) { Fabricate(:topic, tags: [tag, hidden_tag]) }
fab!(:post_with_tag) { Fabricate(:post, raw: "I am in a tag", topic: topic_with_tag) }
fab!(:llm_model)
describe "#run!" do
it "is able to generate email reports" do
freeze_time
@ -41,7 +43,7 @@ module DiscourseAi
sender_username: user.username,
receivers: ["fake@discourse.com"],
title: "test report %DATE%",
model: "gpt-4",
model: "custom:#{llm_model.id}",
category_ids: nil,
tags: nil,
allow_secure_categories: false,
@ -78,7 +80,7 @@ module DiscourseAi
sender_username: user.username,
receivers: [receiver.username],
title: "test report",
model: "gpt-4",
model: "custom:#{llm_model.id}",
category_ids: nil,
tags: nil,
allow_secure_categories: false,
@ -123,7 +125,7 @@ module DiscourseAi
sender_username: user.username,
receivers: [receiver.username],
title: "test report",
model: "gpt-4",
model: "custom:#{llm_model.id}",
category_ids: nil,
tags: nil,
allow_secure_categories: false,
@ -166,7 +168,7 @@ module DiscourseAi
sender_username: user.username,
receivers: [receiver.username],
title: "test report",
model: "gpt-4",
model: "custom:#{llm_model.id}",
category_ids: nil,
tags: nil,
allow_secure_categories: false,
@ -194,7 +196,7 @@ module DiscourseAi
sender_username: user.username,
receivers: [receiver.username],
title: "test report",
model: "gpt-4",
model: "custom:#{llm_model.id}",
category_ids: nil,
tags: nil,
allow_secure_categories: false,

View File

@ -2,7 +2,7 @@
RSpec.describe AiTool do
fab!(:llm_model) { Fabricate(:llm_model, name: "claude-2") }
let(:llm) { DiscourseAi::Completions::Llm.proxy_from_obj(llm_model) }
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
def create_tool(parameters: nil, script: nil)
AiTool.create!(

View File

@ -1,64 +0,0 @@
# frozen_string_literal: true
RSpec.describe LlmModel do
describe ".seed_srv_backed_model" do
before do
SiteSetting.ai_vllm_endpoint_srv = "srv.llm.service."
SiteSetting.ai_vllm_api_key = "123"
end
context "when the model doesn't exist yet" do
it "creates it" do
described_class.seed_srv_backed_model
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
expect(llm_model).to be_present
expect(llm_model.name).to eq("mistralai/Mixtral-8x7B-Instruct-v0.1")
expect(llm_model.api_key).to eq(SiteSetting.ai_vllm_api_key)
end
end
context "when the model already exists" do
before { described_class.seed_srv_backed_model }
context "when the API key setting changes" do
it "updates it" do
new_key = "456"
SiteSetting.ai_vllm_api_key = new_key
described_class.seed_srv_backed_model
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
expect(llm_model.api_key).to eq(new_key)
end
end
context "when the SRV is no longer defined" do
it "deletes the LlmModel" do
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
expect(llm_model).to be_present
SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
expect { llm_model.reload }.to raise_exception(ActiveRecord::RecordNotFound)
end
it "disabled the bot user" do
SiteSetting.ai_bot_enabled = true
llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
llm_model.update!(enabled_chat_bot: true)
llm_model.toggle_companion_user
user = llm_model.user
expect(user).to be_present
SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
expect { user.reload }.to raise_exception(ActiveRecord::RecordNotFound)
end
end
end
end
end

View File

@ -8,7 +8,7 @@ module DiscourseAi::ChatBotHelper
end
def assign_fake_provider_to(setting_name)
Fabricate(:llm_model, provider: "fake", name: "fake").tap do |fake_llm|
Fabricate(:fake_model).tap do |fake_llm|
SiteSetting.public_send("#{setting_name}=", "custom:#{fake_llm.id}")
end
end
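
A minimal usage sketch of the updated helper above; ":ai_example_model" is a hypothetical setting name used only for illustration, real specs pass the feature setting they are exercising.

# Usage sketch, not part of the diff.
RSpec.describe "an LlmModel-backed feature" do
  # Fabricates a fake model and stores "custom:<id>" in the given site setting,
  # so the feature under test resolves to the fake endpoint.
  before { assign_fake_provider_to(:ai_example_model) }
end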