DEV: Remove old code now that features rely on LlmModels. (#729)

* DEV: Remove old code now that features rely on LlmModels.
* Hide old settings and migrate persona LLM overrides.
* Remove the shadowing special URL and seeding code. Use the srv:// prefix instead.

This commit is contained in:
parent 73a2b15e91
commit bed044448c
@@ -110,9 +110,7 @@ module DiscourseAi
        )

      provider = updating ? updating.provider : permitted[:provider]

-      permit_url =
-        (updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) ||
-          provider != LlmModel::BEDROCK_PROVIDER_NAME
+      permit_url = provider != LlmModel::BEDROCK_PROVIDER_NAME

      permitted[:url] = params.dig(:ai_llm, :url) if permit_url
@@ -2,44 +2,10 @@
 class LlmModel < ActiveRecord::Base
   FIRST_BOT_USER_ID = -1200
-  RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid"
   BEDROCK_PROVIDER_NAME = "aws_bedrock"

   belongs_to :user

-  validates :url, exclusion: { in: [RESERVED_VLLM_SRV_URL] }, if: :url_changed?
-
-  def self.seed_srv_backed_model
-    srv = SiteSetting.ai_vllm_endpoint_srv
-    srv_model = find_by(url: RESERVED_VLLM_SRV_URL)
-
-    if srv.present?
-      if srv_model.present?
-        current_key = SiteSetting.ai_vllm_api_key
-        srv_model.update!(api_key: current_key) if current_key != srv_model.api_key
-      else
-        record =
-          new(
-            display_name: "vLLM SRV LLM",
-            name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            provider: "vllm",
-            tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer",
-            url: RESERVED_VLLM_SRV_URL,
-            max_prompt_tokens: 8000,
-            api_key: SiteSetting.ai_vllm_api_key,
-          )
-
-        record.save(validate: false) # Ignore reserved URL validation
-      end
-    else
-      # Clean up companion users
-      srv_model&.enabled_chat_bot = false
-      srv_model&.toggle_companion_user
-
-      srv_model&.destroy!
-    end
-  end
-
   def self.provider_params
     {
       aws_bedrock: {
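
Note: the removed seeder above materialized a model pinned to a reserved sentinel URL. Under the new scheme there is no reserved URL or exclusion validation; a SRV-backed model is just a regular record whose URL carries the srv:// prefix. A hedged sketch of the manual equivalent (the hostname and attribute values are illustrative, not taken from this commit):

    # Illustrative only: creating a SRV-backed vLLM model by hand.
    LlmModel.create!(
      display_name: "vLLM via SRV",                         # hypothetical name
      name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
      provider: "vllm",
      tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer",
      url: "srv://_vllm._tcp.example.internal",             # hypothetical SRV record
      max_prompt_tokens: 8000,
      api_key: "secret",                                    # placeholder
    )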
@@ -54,7 +20,7 @@ class LlmModel < ActiveRecord::Base
   end

   def to_llm
-    DiscourseAi::Completions::Llm.proxy_from_obj(self)
+    DiscourseAi::Completions::Llm.proxy("custom:#{id}")
   end

   def toggle_companion_user
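
With this change `LlmModel#to_llm` goes through the same `custom:<id>` identifier every feature now uses. A small equivalence sketch (the record lookup is illustrative):

    # Both calls resolve to the same proxy after this commit:
    model = LlmModel.last
    model.to_llm
    DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")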
@@ -19,6 +19,6 @@ class LlmModelSerializer < ApplicationSerializer
   has_one :user, serializer: BasicUserSerializer, embed: :object

   def shadowed_by_srv
-    object.url == LlmModel::RESERVED_VLLM_SRV_URL
+    object.url.to_s.starts_with?("srv://")
   end
 end
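
The serializer flag is now a plain prefix check, so any model can be SRV-backed without a reserved URL. A minimal sketch (URLs illustrative):

    LlmModel.new(url: "srv://_vllm._tcp.example.internal").url.to_s.starts_with?("srv://") # => true
    LlmModel.new(url: "https://api.openai.com/v1/chat/completions").url.to_s.starts_with?("srv://") # => false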
@@ -60,17 +60,9 @@ export default class AiLlmEditorForm extends Component {
    return this.testRunning || this.testResult !== null;
  }

-  get displaySRVWarning() {
-    return this.args.model.shadowed_by_srv && !this.args.model.isNew;
-  }
-
  get canEditURL() {
    // Explicitly false.
-    if (this.metaProviderParams.url_editable === false) {
-      return false;
-    }
-
-    return !this.args.model.shadowed_by_srv || this.args.model.isNew;
+    return this.metaProviderParams.url_editable !== false;
  }

  @computed("args.model.provider")
@@ -174,12 +166,6 @@ export default class AiLlmEditorForm extends Component {
  }

  <template>
-    {{#if this.displaySRVWarning}}
-      <div class="alert alert-info">
-        {{icon "exclamation-circle"}}
-        {{I18n.t "discourse_ai.llms.srv_warning"}}
-      </div>
-    {{/if}}
    <form class="form-horizontal ai-llm-editor">
      <div class="control-group">
        <label>{{i18n "discourse_ai.llms.display_name"}}</label>
@@ -237,7 +237,6 @@ en:
      confirm_delete: Are you sure you want to delete this model?
      delete: Delete

-      srv_warning: This LLM points to an SRV record, and its URL is not editable. You have to update the hidden "ai_vllm_endpoint_srv" setting instead.
      preconfigured_llms: "Select your LLM"
      preconfigured:
        none: "Configure manually..."
@@ -96,21 +96,35 @@ discourse_ai:
    - opennsfw2
    - nsfw_detector

-  ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
-  ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
-  ai_openai_gpt4o_url: "https://api.openai.com/v1/chat/completions"
-  ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
-  ai_openai_gpt4_32k_url: "https://api.openai.com/v1/chat/completions"
-  ai_openai_gpt4_turbo_url: "https://api.openai.com/v1/chat/completions"
+  ai_openai_gpt35_url:
+    default: "https://api.openai.com/v1/chat/completions"
+    hidden: true
+  ai_openai_gpt35_16k_url:
+    default: "https://api.openai.com/v1/chat/completions"
+    hidden: true
+  ai_openai_gpt4o_url:
+    default: "https://api.openai.com/v1/chat/completions"
+    hidden: true
+  ai_openai_gpt4_url:
+    default: "https://api.openai.com/v1/chat/completions"
+    hidden: true
+  ai_openai_gpt4_32k_url:
+    default: "https://api.openai.com/v1/chat/completions"
+    hidden: true
+  ai_openai_gpt4_turbo_url:
+    default: "https://api.openai.com/v1/chat/completions"
+    hidden: true
  ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
  ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
-  ai_openai_organization: ""
+  ai_openai_organization:
+    default: ""
+    hidden: true
  ai_openai_api_key:
    default: ""
    secret: true
  ai_anthropic_api_key:
    default: ""
-    secret: true
+    hidden: true
  ai_anthropic_native_tool_call_models:
    type: list
    list_type: compact
|
@ -123,7 +137,7 @@ discourse_ai:
|
||||||
- claude-3-5-sonnet
|
- claude-3-5-sonnet
|
||||||
ai_cohere_api_key:
|
ai_cohere_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
hidden: true
|
||||||
ai_stability_api_key:
|
ai_stability_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
secret: true
|
||||||
|
@ -140,13 +154,16 @@ discourse_ai:
|
||||||
- "stable-diffusion-v1-5"
|
- "stable-diffusion-v1-5"
|
||||||
ai_hugging_face_api_url:
|
ai_hugging_face_api_url:
|
||||||
default: ""
|
default: ""
|
||||||
|
hidden: true
|
||||||
ai_hugging_face_api_key:
|
ai_hugging_face_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
hidden: true
|
||||||
ai_hugging_face_token_limit:
|
ai_hugging_face_token_limit:
|
||||||
default: 4096
|
default: 4096
|
||||||
|
hidden: true
|
||||||
ai_hugging_face_model_display_name:
|
ai_hugging_face_model_display_name:
|
||||||
default: ""
|
default: ""
|
||||||
|
hidden: true
|
||||||
ai_hugging_face_tei_endpoint:
|
ai_hugging_face_tei_endpoint:
|
||||||
default: ""
|
default: ""
|
||||||
ai_hugging_face_tei_endpoint_srv:
|
ai_hugging_face_tei_endpoint_srv:
|
||||||
|
@@ -167,11 +184,13 @@ discourse_ai:
  ai_bedrock_access_key_id:
    default: ""
    secret: true
+    hidden: true
  ai_bedrock_secret_access_key:
    default: ""
-    secret: true
+    hidden: true
  ai_bedrock_region:
    default: "us-east-1"
+    hidden: true
  ai_cloudflare_workers_account_id:
    default: ""
    secret: true
@@ -180,13 +199,16 @@ discourse_ai:
    secret: true
  ai_gemini_api_key:
    default: ""
-    secret: true
+    hidden: true
  ai_vllm_endpoint:
    default: ""
+    hidden: true
  ai_vllm_endpoint_srv:
    default: ""
    hidden: true
-  ai_vllm_api_key: ""
+  ai_vllm_api_key:
+    default: ""
+    hidden: true
  ai_llava_endpoint:
    default: ""
    hidden: true
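
The per-provider URL and key settings above are kept but marked hidden so existing installs can still be migrated; they simply disappear from the admin UI. Hidden Discourse site settings remain accessible programmatically; a hedged sketch (setting name taken from this diff):

    # Hidden settings can still be read or changed from a Rails console:
    SiteSetting.ai_vllm_api_key           # readable despite hidden: true
    SiteSetting.ai_vllm_api_key = "other" # still assignable programmatically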
@@ -1,8 +0,0 @@
-# frozen_string_literal: true
-
-begin
-  LlmModel.seed_srv_backed_model
-rescue PG::UndefinedColumn => e
-  # If this code runs before migrations, an attribute might be missing.
-  Rails.logger.warn("Failed to seed SRV-Backed LLM: #{e.meesage}")
-end
@@ -25,8 +25,9 @@ class MigrateVisionLlms < ActiveRecord::Migration[7.1]
    ).first

    if current_value && current_value != "llava"
+      model_name = current_value.split(":").last
      llm_model =
-        DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: current_value).first
+        DB.query_single("SELECT id FROM llm_models WHERE name = :model", model: model_name).first

      if llm_model
        DB.exec(<<~SQL, new: "custom:#{llm_model}") if llm_model
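
The fix above is that the old query matched the full `provider:model` setting value against `llm_models.name`, which never matched; only the model-name part after the colon is stored there. A tiny sketch (value illustrative):

    current_value = "vllm:llava-v1.6" # illustrative stored setting value
    current_value.split(":").last     # => "llava-v1.6", what llm_models.name actually holds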
@@ -0,0 +1,46 @@
+# frozen_string_literal: true
+
+class MigratePersonaLlmOverride < ActiveRecord::Migration[7.1]
+  def up
+    fields_to_update = DB.query(<<~SQL)
+      SELECT id, default_llm
+      FROM ai_personas
+      WHERE default_llm IS NOT NULL
+    SQL
+
+    return if fields_to_update.empty?
+
+    updated_fields =
+      fields_to_update
+        .map do |field|
+          llm_model_id = matching_llm_model(field.default_llm)
+
+          "(#{field.id}, 'custom:#{llm_model_id}')" if llm_model_id
+        end
+        .compact
+
+    return if updated_fields.empty?
+
+    DB.exec(<<~SQL)
+      UPDATE ai_personas
+      SET default_llm = new_fields.new_default_llm
+      FROM (VALUES #{updated_fields.join(", ")}) AS new_fields(id, new_default_llm)
+      WHERE new_fields.id::bigint = ai_personas.id
+    SQL
+  end
+
+  def matching_llm_model(model)
+    provider = model.split(":").first
+    model_name = model.split(":").last
+
+    return if provider == "custom"
+
+    DB.query_single(
+      "SELECT id FROM llm_models WHERE name = :name AND provider = :provider",
+      { name: model_name, provider: provider },
+    ).first
+  end
+
+  def down
+    raise ActiveRecord::IrreversibleMigration
+  end
+end
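
The migration rewrites each persona's `default_llm` from the old `provider:model` pair to the new `custom:<id>` form, skipping values already migrated. An illustrative before/after (the id is invented):

    # Illustrative check of the lookup the migration performs:
    DB.query_single(
      "SELECT id FROM llm_models WHERE name = :name AND provider = :provider",
      { name: "claude-3-opus", provider: "anthropic" },
    ).first
    # => e.g. 42; the persona row then goes from "anthropic:claude-3-opus" to "custom:42"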
@@ -5,19 +5,15 @@ module DiscourseAi
    module Dialects
      class ChatGpt < Dialect
        class << self
-          def can_translate?(model_name)
-            model_name.starts_with?("gpt-")
+          def can_translate?(model_provider)
+            model_provider == "open_ai" || model_provider == "azure"
          end
        end

        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

-        def tokenizer
-          llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
-
        def native_tool_support?
-          true
+          llm_model.provider == "open_ai" || llm_model.provider == "azure"
        end

        def translate
@@ -30,19 +26,17 @@ module DiscourseAi
        end

        def max_prompt_tokens
-          return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
-
          # provide a buffer of 120 tokens - our function counting is not
          # 100% accurate and getting numbers to align exactly is very hard
          buffer = (opts[:max_tokens] || 2500) + 50

          if tools.present?
            # note this is about 100 tokens over, OpenAI have a more optimal representation
-            @function_size ||= self.tokenizer.size(tools.to_json.to_s)
+            @function_size ||= llm_model.tokenizer_class.size(tools.to_json.to_s)
            buffer += @function_size
          end

-          model_max_tokens - buffer
+          llm_model.max_prompt_tokens - buffer
        end

        private
@@ -105,24 +99,7 @@ module DiscourseAi
        end

        def calculate_message_token(context)
-          self.tokenizer.size(context[:content].to_s + context[:name].to_s)
-        end
-
-        def model_max_tokens
-          case model_name
-          when "gpt-3.5-turbo-16k"
-            16_384
-          when "gpt-4"
-            8192
-          when "gpt-4-32k"
-            32_768
-          when "gpt-4-turbo"
-            131_072
-          when "gpt-4o"
-            131_072
-          else
-            8192
-          end
+          llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
        end
      end
    end
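
With the hard-coded per-model table gone, token budgets come entirely from the LlmModel record; a sketch (the value is illustrative):

    # max_prompt_tokens is now data, not code:
    llm_model = LlmModel.new(max_prompt_tokens: 131_072)
    llm_model.max_prompt_tokens # => 131_072, consumed directly by the dialect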
@@ -5,8 +5,8 @@ module DiscourseAi
    module Dialects
      class Claude < Dialect
        class << self
-          def can_translate?(model_name)
-            model_name.start_with?("claude") || model_name.start_with?("anthropic")
+          def can_translate?(provider_name)
+            provider_name == "anthropic" || provider_name == "aws_bedrock"
          end
        end
@@ -26,10 +26,6 @@ module DiscourseAi
        end
      end

-      def tokenizer
-        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
-      end
-
      def translate
        messages = super
@@ -61,14 +57,11 @@ module DiscourseAi
        end

        def max_prompt_tokens
-          return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
-
-          # Longer term it will have over 1 million
-          200_000 # Claude-3 has a 200k context window for now
+          llm_model.max_prompt_tokens
        end

        def native_tool_support?
-          SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name)
+          SiteSetting.ai_anthropic_native_tool_call_models_map.include?(llm_model.name)
        end

        private
@@ -6,18 +6,12 @@ module DiscourseAi
  module Completions
    module Dialects
      class Command < Dialect
-        class << self
-          def can_translate?(model_name)
-            %w[command-light command command-r command-r-plus].include?(model_name)
-          end
+        def self.can_translate?(model_provider)
+          model_provider == "cohere"
        end

        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

-        def tokenizer
-          llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
-
        def translate
          messages = super
@@ -68,20 +62,7 @@ module DiscourseAi
        end

        def max_prompt_tokens
-          return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
-
-          case model_name
-          when "command-light"
-            4096
-          when "command"
-            8192
-          when "command-r"
-            131_072
-          when "command-r-plus"
-            131_072
-          else
-            8192
-          end
+          llm_model.max_prompt_tokens
        end

        def native_tool_support?
@@ -99,7 +80,7 @@ module DiscourseAi
        end

        def calculate_message_token(context)
-          self.tokenizer.size(context[:content].to_s + context[:name].to_s)
+          llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
        end

        def system_msg(msg)
@@ -5,7 +5,7 @@ module DiscourseAi
    module Dialects
      class Dialect
        class << self
-          def can_translate?(_model_name)
+          def can_translate?(model_provider)
            raise NotImplemented
          end

@@ -19,7 +19,7 @@ module DiscourseAi
          ]
        end

-        def dialect_for(model_name)
+        def dialect_for(model_provider)
          dialects = []

          if Rails.env.test? || Rails.env.development?
@@ -28,26 +28,21 @@ module DiscourseAi

          dialects = dialects.concat(all_dialects)

-          dialect = dialects.find { |d| d.can_translate?(model_name) }
+          dialect = dialects.find { |d| d.can_translate?(model_provider) }
          raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect

          dialect
        end
      end

-      def initialize(generic_prompt, model_name, opts: {}, llm_model: nil)
+      def initialize(generic_prompt, llm_model, opts: {})
        @prompt = generic_prompt
-        @model_name = model_name
        @opts = opts
        @llm_model = llm_model
      end

      VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

-      def tokenizer
-        raise NotImplemented
-      end
-
      def can_end_with_assistant_msg?
        false
      end
@@ -57,7 +52,7 @@ module DiscourseAi
      end

      def vision_support?
-        llm_model&.vision_enabled?
+        llm_model.vision_enabled?
      end

      def tools
@@ -88,12 +83,12 @@ module DiscourseAi

      private

-      attr_reader :model_name, :opts, :llm_model
+      attr_reader :opts, :llm_model

      def trim_messages(messages)
        prompt_limit = max_prompt_tokens
        current_token_count = 0
-        message_step_size = (max_prompt_tokens / 25).to_i * -1
+        message_step_size = (prompt_limit / 25).to_i * -1

        trimmed_messages = []

@@ -157,7 +152,7 @@ module DiscourseAi
      end

      def calculate_message_token(msg)
-        self.tokenizer.size(msg[:content].to_s)
+        llm_model.tokenizer_class.size(msg[:content].to_s)
      end

      def tools_dialect
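
Dialect lookup now keys off the provider string instead of pattern-matching model names; a sketch built from the predicates in this diff:

    # Provider strings, not model names, pick the dialect after this change:
    DiscourseAi::Completions::Dialects::Dialect.dialect_for("open_ai")
    # => DiscourseAi::Completions::Dialects::ChatGpt (also matches "azure")
    DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
    # => DiscourseAi::Completions::Dialects::Claude (also matches "aws_bedrock")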
@@ -5,8 +5,8 @@ module DiscourseAi
    module Dialects
      class Gemini < Dialect
        class << self
-          def can_translate?(model_name)
-            %w[gemini-pro gemini-1.5-pro gemini-1.5-flash].include?(model_name)
+          def can_translate?(model_provider)
+            model_provider == "google"
          end
        end

@@ -14,10 +14,6 @@ module DiscourseAi
          true
        end

-        def tokenizer
-          llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
-        end
-
        def translate
          # Gemini complains if we don't alternate model/user roles.
          noop_model_response = { role: "model", parts: { text: "Ok." } }
@@ -74,24 +70,17 @@ module DiscourseAi
        end

        def max_prompt_tokens
-          return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
-
-          if model_name.start_with?("gemini-1.5")
-            # technically we support 1 million tokens, but we're being conservative
-            800_000
-          else
-            16_384 # 50% of model tokens
-          end
+          llm_model.max_prompt_tokens
        end

        protected

        def calculate_message_token(context)
-          self.tokenizer.size(context[:content].to_s + context[:name].to_s)
+          llm_model.tokenizer_class.size(context[:content].to_s + context[:name].to_s)
        end

        def beta_api?
-          @beta_api ||= model_name.start_with?("gemini-1.5")
+          @beta_api ||= llm_model.name.start_with?("gemini-1.5")
        end

        def system_msg(msg)
@@ -4,22 +4,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Anthropic < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "anthropic"
-          end
-
-          def dependant_setting_names
-            %w[ai_anthropic_api_key]
-          end
-
-          def correctly_configured?(_model_name)
-            SiteSetting.ai_anthropic_api_key.present?
-          end
-
-          def endpoint_name(model_name)
-            "Anthropic - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          model_provider == "anthropic"
        end

        def normalize_model_params(model_params)
@@ -29,7 +15,7 @@ module DiscourseAi

        def default_options(dialect)
          mapped_model =
-            case model
+            case llm_model.name
            when "claude-2"
              "claude-2.1"
            when "claude-instant-1"
@@ -43,7 +29,7 @@ module DiscourseAi
            when "claude-3-5-sonnet"
              "claude-3-5-sonnet-20240620"
            else
-              model
+              llm_model.name
            end

          options = { model: mapped_model, max_tokens: 3_000 }
@@ -74,9 +60,7 @@ module DiscourseAi
        end

        def model_uri
-          url = llm_model&.url || "https://api.anthropic.com/v1/messages"
-
-          URI(url)
+          URI(llm_model.url)
        end

        def prepare_payload(prompt, model_params, dialect)
@@ -94,7 +78,7 @@ module DiscourseAi
        def prepare_request(payload)
          headers = {
            "anthropic-version" => "2023-06-01",
-            "x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key,
+            "x-api-key" => llm_model.api_key,
            "content-type" => "application/json",
          }
@@ -6,24 +6,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class AwsBedrock < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "aws_bedrock"
-          end
-
-          def dependant_setting_names
-            %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
-          end
-
-          def correctly_configured?(_model)
-            SiteSetting.ai_bedrock_access_key_id.present? &&
-              SiteSetting.ai_bedrock_secret_access_key.present? &&
-              SiteSetting.ai_bedrock_region.present?
-          end
-
-          def endpoint_name(model_name)
-            "AWS Bedrock - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          model_provider == "aws_bedrock"
        end

        def normalize_model_params(model_params)
@@ -62,37 +46,28 @@ module DiscourseAi
        end

        def model_uri
-          if llm_model
-            region = llm_model.lookup_custom_param("region")
-
-            api_url =
-              "https://bedrock-runtime.#{region}.amazonaws.com/model/#{llm_model.name}/invoke"
-          else
-            # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
-            #
-            # FYI there is a 2.0 version of Claude, very little need to support it given
-            # haiku/sonnet are better fits anyway, we map to claude-2.1
-            bedrock_model_id =
-              case model
-              when "claude-2"
-                "anthropic.claude-v2:1"
-              when "claude-3-haiku"
-                "anthropic.claude-3-haiku-20240307-v1:0"
-              when "claude-3-sonnet"
-                "anthropic.claude-3-sonnet-20240229-v1:0"
-              when "claude-instant-1"
-                "anthropic.claude-instant-v1"
-              when "claude-3-opus"
-                "anthropic.claude-3-opus-20240229-v1:0"
-              when "claude-3-5-sonnet"
-                "anthropic.claude-3-5-sonnet-20240620-v1:0"
-              else
-                model
-              end
-
-            api_url =
-              "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
-          end
+          region = llm_model.lookup_custom_param("region")
+
+          bedrock_model_id =
+            case llm_model.name
+            when "claude-2"
+              "anthropic.claude-v2:1"
+            when "claude-3-haiku"
+              "anthropic.claude-3-haiku-20240307-v1:0"
+            when "claude-3-sonnet"
+              "anthropic.claude-3-sonnet-20240229-v1:0"
+            when "claude-instant-1"
+              "anthropic.claude-instant-v1"
+            when "claude-3-opus"
+              "anthropic.claude-3-opus-20240229-v1:0"
+            when "claude-3-5-sonnet"
+              "anthropic.claude-3-5-sonnet-20240620-v1:0"
+            else
+              llm_model.name
+            end
+
+          api_url =
+            "https://bedrock-runtime.#{region}.amazonaws.com/model/#{bedrock_model_id}/invoke"

          api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url

@@ -114,11 +89,9 @@ module DiscourseAi

          signer =
            Aws::Sigv4::Signer.new(
-              access_key_id:
-                llm_model&.lookup_custom_param("access_key_id") ||
-                  SiteSetting.ai_bedrock_access_key_id,
-              region: llm_model&.lookup_custom_param("region") || SiteSetting.ai_bedrock_region,
-              secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
+              access_key_id: llm_model.lookup_custom_param("access_key_id"),
+              region: llm_model.lookup_custom_param("region"),
+              secret_access_key: llm_model.api_key,
              service: "bedrock",
            )
@@ -30,39 +30,12 @@ module DiscourseAi
        end
      end

-      def configuration_hint
-        settings = dependant_setting_names
-        I18n.t(
-          "discourse_ai.llm.endpoints.configuration_hint",
-          settings: settings.join(", "),
-          count: settings.length,
-        )
-      end
-
-      def display_name(model_name)
-        to_display = endpoint_name(model_name)
-
-        return to_display if correctly_configured?(model_name)
-
-        I18n.t("discourse_ai.llm.endpoints.not_configured", display_name: to_display)
-      end
-
-      def dependant_setting_names
-        raise NotImplementedError
-      end
-
-      def endpoint_name(_model_name)
-        raise NotImplementedError
-      end
-
-      def can_contact?(_endpoint_name)
+      def can_contact?(_model_provider)
        raise NotImplementedError
      end
    end

-    def initialize(model_name, tokenizer, llm_model: nil)
-      @model = model_name
-      @tokenizer = tokenizer
+    def initialize(llm_model)
      @llm_model = llm_model
    end

@@ -136,7 +109,7 @@ module DiscourseAi
        topic_id: dialect.prompt.topic_id,
        post_id: dialect.prompt.post_id,
        feature_name: feature_name,
-        language_model: self.class.endpoint_name(@model),
+        language_model: llm_model.name,
      )

      if !@streaming_mode
@@ -323,10 +296,14 @@ module DiscourseAi
      tokenizer.size(extract_prompt_for_tokenizer(prompt))
    end

-    attr_reader :tokenizer, :model, :llm_model
+    attr_reader :llm_model

    protected

+    def tokenizer
+      llm_model.tokenizer_class
+    end
+
    # should normalize temperature, max_tokens, stop_words to endpoint specific values
    def normalize_model_params(model_params)
      raise NotImplementedError
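
Endpoints follow the same pattern as dialects: `Base.endpoint_for` matches on the provider string and the gateway constructor now takes just the model record, deriving its tokenizer from it. A hedged construction sketch (the record lookup is illustrative):

    llm_model = LlmModel.last # illustrative record
    gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(llm_model.provider)
    gateway = gateway_klass.new(llm_model) # tokenizer comes from llm_model.tokenizer_class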
@@ -6,10 +6,6 @@ module DiscourseAi
    class CannedResponse
      CANNED_RESPONSE_ERROR = Class.new(StandardError)

-      def self.can_contact?(_)
-        Rails.env.test?
-      end
-
      def initialize(responses)
        @responses = responses
        @completions = 0
@@ -4,22 +4,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Cohere < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "cohere"
-          end
-
-          def dependant_setting_names
-            %w[ai_cohere_api_key]
-          end
-
-          def correctly_configured?(_model_name)
-            SiteSetting.ai_cohere_api_key.present?
-          end
-
-          def endpoint_name(model_name)
-            "Cohere - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          model_provider == "cohere"
        end

        def normalize_model_params(model_params)
@@ -39,9 +25,7 @@ module DiscourseAi
        private

        def model_uri
-          url = llm_model&.url || "https://api.cohere.ai/v1/chat"
-
-          URI(url)
+          URI(llm_model.url)
        end

        def prepare_payload(prompt, model_params, dialect)
@@ -59,7 +43,7 @@ module DiscourseAi
        def prepare_request(payload)
          headers = {
            "Content-Type" => "application/json",
-            "Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}",
+            "Authorization" => "Bearer #{llm_model.api_key}",
          }

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
@@ -4,20 +4,6 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Fake < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "fake"
-          end
-
-          def correctly_configured?(_model_name)
-            true
-          end
-
-          def endpoint_name(_model_name)
-            "Test - fake model"
-          end
-        end
-
        STOCK_CONTENT = <<~TEXT
          # Discourse Markdown Styles Showcase

@@ -75,6 +61,10 @@ module DiscourseAi
          Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers.
        TEXT

+        def self.can_contact?(model_provider)
+          model_provider == "fake"
+        end
+
        def self.with_fake_content(content)
          @fake_content = content
          yield
@@ -4,22 +4,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Gemini < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "google"
-          end
-
-          def dependant_setting_names
-            %w[ai_gemini_api_key]
-          end
-
-          def correctly_configured?(_model_name)
-            SiteSetting.ai_gemini_api_key.present?
-          end
-
-          def endpoint_name(model_name)
-            "Google - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          model_provider == "google"
        end

        def default_options
@@ -59,21 +45,8 @@ module DiscourseAi
        private

        def model_uri
-          if llm_model
-            url = llm_model.url
-          else
-            mapped_model = model
-            if model == "gemini-1.5-pro"
-              mapped_model = "gemini-1.5-pro-latest"
-            elsif model == "gemini-1.5-flash"
-              mapped_model = "gemini-1.5-flash-latest"
-            elsif model == "gemini-1.0-pro"
-              mapped_model = "gemini-pro-latest"
-            end
-            url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
-          end
-
-          key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
+          url = llm_model.url
+          key = llm_model.api_key

          if @streaming_mode
            url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"
@@ -4,22 +4,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class HuggingFace < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "hugging_face"
-          end
-
-          def dependant_setting_names
-            %w[ai_hugging_face_api_url]
-          end
-
-          def correctly_configured?(_model_name)
-            SiteSetting.ai_hugging_face_api_url.present?
-          end
-
-          def endpoint_name(model_name)
-            "Hugging Face - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          model_provider == "hugging_face"
        end

        def normalize_model_params(model_params)
@@ -34,7 +20,7 @@ module DiscourseAi
        end

        def default_options
-          { model: model, temperature: 0.7 }
+          { model: llm_model.name, temperature: 0.7 }
        end

        def provider_id
@@ -44,7 +30,7 @@ module DiscourseAi
        private

        def model_uri
-          URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url)
+          URI(llm_model.url)
        end

        def prepare_payload(prompt, model_params, _dialect)
@@ -53,8 +39,7 @@ module DiscourseAi
          .merge(messages: prompt)
          .tap do |payload|
            if !payload[:max_tokens]
-              token_limit =
-                llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit
+              token_limit = llm_model.max_prompt_tokens

              payload[:max_tokens] = token_limit - prompt_size(prompt)
            end
@@ -64,7 +49,7 @@ module DiscourseAi
        end

        def prepare_request(payload)
-          api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key
+          api_key = llm_model.api_key

          headers =
            { "Content-Type" => "application/json" }.tap do |h|
@@ -4,22 +4,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Ollama < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "ollama"
-          end
-
-          def dependant_setting_names
-            %w[ai_ollama_endpoint]
-          end
-
-          def correctly_configured?(_model_name)
-            SiteSetting.ai_ollama_endpoint.present?
-          end
-
-          def endpoint_name(model_name)
-            "Ollama - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          model_provider == "ollama"
        end

        def normalize_model_params(model_params)
@@ -34,7 +20,7 @@ module DiscourseAi
        end

        def default_options
-          { max_tokens: 2000, model: model }
+          { max_tokens: 2000, model: llm_model.name }
        end

        def provider_id
@@ -48,7 +34,7 @@ module DiscourseAi
        private

        def model_uri
-          URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
+          URI(llm_model.url)
        end

        def prepare_payload(prompt, model_params, _dialect)
@@ -4,56 +4,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class OpenAi < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "open_ai"
-          end
-
-          def dependant_setting_names
-            %w[
-              ai_openai_api_key
-              ai_openai_gpt4o_url
-              ai_openai_gpt4_32k_url
-              ai_openai_gpt4_turbo_url
-              ai_openai_gpt4_url
-              ai_openai_gpt4_url
-              ai_openai_gpt35_16k_url
-              ai_openai_gpt35_url
-            ]
-          end
-
-          def correctly_configured?(model_name)
-            SiteSetting.ai_openai_api_key.present? && has_url?(model_name)
-          end
-
-          def has_url?(model)
-            url =
-              if model.include?("gpt-4")
-                if model.include?("32k")
-                  SiteSetting.ai_openai_gpt4_32k_url
-                else
-                  if model.include?("1106") || model.include?("turbo")
-                    SiteSetting.ai_openai_gpt4_turbo_url
-                  elsif model.include?("gpt-4o")
-                    SiteSetting.ai_openai_gpt4o_url
-                  else
-                    SiteSetting.ai_openai_gpt4_url
-                  end
-                end
-              else
-                if model.include?("16k")
-                  SiteSetting.ai_openai_gpt35_16k_url
-                else
-                  SiteSetting.ai_openai_gpt35_url
-                end
-              end
-
-            url.present?
-          end
-
-          def endpoint_name(model_name)
-            "OpenAI - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          %w[open_ai azure].include?(model_provider)
        end

        def normalize_model_params(model_params)
@@ -68,7 +20,7 @@ module DiscourseAi
        end

        def default_options
-          { model: model }
+          { model: llm_model.name }
        end

        def provider_id
@@ -78,28 +30,7 @@ module DiscourseAi
        private

        def model_uri
-          return URI(llm_model.url) if llm_model&.url
-
-          url =
-            if model.include?("gpt-4")
-              if model.include?("32k")
-                SiteSetting.ai_openai_gpt4_32k_url
-              else
-                if model.include?("1106") || model.include?("turbo")
-                  SiteSetting.ai_openai_gpt4_turbo_url
-                else
-                  SiteSetting.ai_openai_gpt4_url
-                end
-              end
-            else
-              if model.include?("16k")
-                SiteSetting.ai_openai_gpt35_16k_url
-              else
-                SiteSetting.ai_openai_gpt35_url
-              end
-            end
-
-          URI(url)
+          URI(llm_model.url)
        end

        def prepare_payload(prompt, model_params, dialect)
@@ -110,7 +41,7 @@ module DiscourseAi

            # Usage is not available in Azure yet.
            # We'll fallback to guess this using the tokenizer.
-            payload[:stream_options] = { include_usage: true } if model_uri.host.exclude?("azure")
+            payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
          end

          payload[:tools] = dialect.tools if dialect.tools.present?
@@ -119,19 +50,16 @@ module DiscourseAi

        def prepare_request(payload)
          headers = { "Content-Type" => "application/json" }
+          api_key = llm_model.api_key

-          api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key
-
-          if model_uri.host.include?("azure")
+          if llm_model.provider == "azure"
            headers["api-key"] = api_key
          else
            headers["Authorization"] = "Bearer #{api_key}"
+            org_id = llm_model.lookup_custom_param("organization")
+            headers["OpenAI-Organization"] = org_id if org_id.present?
          end

-          org_id =
-            llm_model&.lookup_custom_param("organization") || SiteSetting.ai_openai_organization
-          headers["OpenAI-Organization"] = org_id if org_id.present?
-
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end
@@ -4,22 +4,8 @@ module DiscourseAi
  module Completions
    module Endpoints
      class Vllm < Base
-        class << self
-          def can_contact?(endpoint_name)
-            endpoint_name == "vllm"
-          end
-
-          def dependant_setting_names
-            %w[ai_vllm_endpoint_srv ai_vllm_endpoint]
-          end
-
-          def correctly_configured?(_model_name)
-            SiteSetting.ai_vllm_endpoint_srv.present? || SiteSetting.ai_vllm_endpoint.present?
-          end
-
-          def endpoint_name(model_name)
-            "vLLM - #{model_name}"
-          end
+        def self.can_contact?(model_provider)
+          model_provider == "vllm"
        end

        def normalize_model_params(model_params)
@@ -34,7 +20,7 @@ module DiscourseAi
        end

        def default_options
-          { max_tokens: 2000, model: model }
+          { max_tokens: 2000, model: llm_model.name }
        end

        def provider_id
@@ -44,16 +30,13 @@ module DiscourseAi
        private

        def model_uri
-          if llm_model&.url && !llm_model&.url == LlmModel::RESERVED_VLLM_SRV_URL
-            return URI(llm_model.url)
-          end
-
-          service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
-          if service.present?
+          if llm_model.url.to_s.starts_with?("srv://")
+            service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
            api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
          else
-            api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions"
+            api_endpoint = llm_model.url
          end

          @uri ||= URI(api_endpoint)
        end
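
The SRV resolution that used to hang off `SiteSetting.ai_vllm_endpoint_srv` is now driven by the model's own URL. A hedged sketch of the lookup (the hostname is illustrative):

    # Illustrative: a srv:// URL delegates endpoint discovery to DNS at request time.
    service = DiscourseAi::Utils::DnsSrv.lookup("_vllm._tcp.example.internal")
    "https://#{service.target}:#{service.port}/v1/chat/completions" # the effective endpoint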
@@ -64,8 +64,8 @@ module DiscourseAi
          id: "open_ai",
          models: [
            { name: "gpt-4o", tokens: 131_072, display_name: "GPT-4 Omni" },
+            { name: "gpt-4o-mini", tokens: 131_072, display_name: "GPT-4 Omni Mini" },
            { name: "gpt-4-turbo", tokens: 131_072, display_name: "GPT-4 Turbo" },
-            { name: "gpt-3.5-turbo", tokens: 16_385, display_name: "GPT-3.5 Turbo" },
          ],
          tokenizer: DiscourseAi::Tokenizer::OpenAiTokenizer,
          endpoint: "https://api.openai.com/v1/chat/completions",
@@ -89,41 +89,6 @@ module DiscourseAi
      DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
    end

-    def models_by_provider
-      # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
-      # However, since they use the same URL/key settings, there's no reason to duplicate them.
-      @models_by_provider ||=
-        {
-          aws_bedrock: %w[
-            claude-instant-1
-            claude-2
-            claude-3-haiku
-            claude-3-sonnet
-            claude-3-opus
-          ],
-          anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
-          vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2],
-          hugging_face: %w[
-            mistralai/Mixtral-8x7B-Instruct-v0.1
-            mistralai/Mistral-7B-Instruct-v0.2
-          ],
-          cohere: %w[command-light command command-r command-r-plus],
-          open_ai: %w[
-            gpt-3.5-turbo
-            gpt-4
-            gpt-3.5-turbo-16k
-            gpt-4-32k
-            gpt-4-turbo
-            gpt-4-vision-preview
-            gpt-4o
-          ],
-          google: %w[gemini-pro gemini-1.5-pro gemini-1.5-flash],
-        }.tap do |h|
-          h[:ollama] = ["mistral"] if Rails.env.development?
-          h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
-        end
-    end
-
    def valid_provider_models
      return @valid_provider_models if defined?(@valid_provider_models)
@@ -151,61 +116,38 @@ module DiscourseAi
        @prompts << prompt if @prompts
      end

-      def proxy(model_name)
-        provider_and_model_name = model_name.split(":")
-        provider_name = provider_and_model_name.first
-        model_name_without_prov = provider_and_model_name[1..].join
-
-        # We are in the process of transitioning to always use objects here.
-        # We'll live with this hack for a while.
-        if provider_name == "custom"
-          llm_model = LlmModel.find(model_name_without_prov)
-          raise UNKNOWN_MODEL if !llm_model
-          return proxy_from_obj(llm_model)
-        end
-
-        dialect_klass =
-          DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
-
-        if @canned_response
-          if @canned_llm && @canned_llm != model_name
-            raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
-          end
-
-          return new(dialect_klass, nil, model_name, gateway: @canned_response)
-        end
-
-        gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
-
-        new(dialect_klass, gateway_klass, model_name_without_prov)
-      end
-
-      def proxy_from_obj(llm_model)
-        provider_name = llm_model.provider
-        model_name = llm_model.name
-
-        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
+      def proxy(model)
+        llm_model =
+          if model.is_a?(LlmModel)
+            model
+          else
+            model_name_without_prov = model.split(":").last.to_i
+
+            LlmModel.find_by(id: model_name_without_prov)
+          end
+
+        raise UNKNOWN_MODEL if llm_model.nil?
+
+        model_provider = llm_model.provider
+        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_provider)

        if @canned_response
-          if @canned_llm && @canned_llm != [provider_name, model_name].join(":")
-            raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
+          if @canned_llm && @canned_llm != model
+            raise "Invalid call LLM call, expected #{@canned_llm} but got #{model}"
          end

-          return(
-            new(dialect_klass, nil, model_name, gateway: @canned_response, llm_model: llm_model)
-          )
+          return new(dialect_klass, nil, llm_model, gateway: @canned_response)
        end

-        gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
+        gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_provider)

-        new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
+        new(dialect_klass, gateway_klass, llm_model)
      end
    end

-    def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil)
+    def initialize(dialect_klass, gateway_klass, llm_model, gateway: nil)
      @dialect_klass = dialect_klass
      @gateway_klass = gateway_klass
-      @model_name = model_name
      @gateway = gateway
      @llm_model = llm_model
    end
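
After this refactor `Llm.proxy` accepts either an LlmModel record or the `custom:<id>` string, and `proxy_from_obj` is gone. An equivalence sketch (the id is illustrative):

    # Both forms build the same proxy:
    DiscourseAi::Completions::Llm.proxy(LlmModel.find(1))
    DiscourseAi::Completions::Llm.proxy("custom:1")
    # Unknown ids raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL.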
@@ -264,9 +206,9 @@ module DiscourseAi

      model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }

-      dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model)
+      dialect = dialect_klass.new(prompt, llm_model, opts: model_params)

-      gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model)
+      gateway = @gateway || gateway_klass.new(llm_model)
      gateway.perform_completion!(
        dialect,
        user,
@@ -277,16 +219,14 @@ module DiscourseAi
    end

    def max_prompt_tokens
-      llm_model&.max_prompt_tokens ||
-        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
+      llm_model.max_prompt_tokens
    end

    def tokenizer
-      llm_model&.tokenizer_class ||
-        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
+      llm_model.tokenizer_class
    end

-    attr_reader :model_name, :llm_model
+    attr_reader :llm_model

    private
@@ -15,19 +15,16 @@ module DiscourseAi
        return !@parent_enabled
      end

-      llm_model_id = val.split(":")&.last
-      llm_model = LlmModel.find_by(id: llm_model_id)
-      return false if llm_model.nil?
-
-      run_test(llm_model).tap { |result| @unreachable = result }
-    rescue StandardError
+      run_test(val).tap { |result| @unreachable = result }
+    rescue StandardError => e
+      raise e if Rails.env.test?
      @unreachable = true
      false
    end

-    def run_test(llm_model)
+    def run_test(val)
      DiscourseAi::Completions::Llm
-        .proxy_from_obj(llm_model)
+        .proxy(val)
        .generate("How much is 1 + 1?", user: nil, feature_name: "llm_validator")
        .present?
    end
@ -80,8 +80,4 @@ after_initialize do
|
||||||
nil
|
nil
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
on(:site_setting_changed) do |name, _old_value, _new_value|
|
|
||||||
LlmModel.seed_srv_backed_model if name == :ai_vllm_endpoint_srv || name == :ai_vllm_api_key
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
|
@ -5,5 +5,67 @@ Fabricator(:llm_model) do
|
||||||
name "gpt-4-turbo"
|
name "gpt-4-turbo"
|
||||||
provider "open_ai"
|
provider "open_ai"
|
||||||
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||||
|
api_key "123"
|
||||||
|
url "https://api.openai.com/v1/chat/completions"
|
||||||
|
max_prompt_tokens 131_072
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:anthropic_model, from: :llm_model) do
|
||||||
|
display_name "Claude 3 Opus"
|
||||||
|
name "claude-3-opus"
|
||||||
|
max_prompt_tokens 200_000
|
||||||
|
url "https://api.anthropic.com/v1/messages"
|
||||||
|
tokenizer "DiscourseAi::Tokenizer::AnthropicTokenizer"
|
||||||
|
provider "anthropic"
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:hf_model, from: :llm_model) do
|
||||||
|
display_name "Llama 3.1"
|
||||||
|
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
|
||||||
|
max_prompt_tokens 64_000
|
||||||
|
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
|
||||||
|
url "https://test.dev/v1/chat/completions"
|
||||||
|
provider "hugging_face"
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:vllm_model, from: :llm_model) do
|
||||||
|
display_name "Llama 3.1 vLLM"
|
||||||
|
name "meta-llama/Meta-Llama-3.1-70B-Instruct"
|
||||||
|
max_prompt_tokens 64_000
|
||||||
|
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
|
||||||
|
url "https://test.dev/v1/chat/completions"
|
||||||
|
provider "vllm"
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:fake_model, from: :llm_model) do
|
||||||
|
display_name "Fake model"
|
||||||
|
name "fake"
|
||||||
|
provider "fake"
|
||||||
|
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||||
max_prompt_tokens 32_000
|
max_prompt_tokens 32_000
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Fabricator(:gemini_model, from: :llm_model) do
|
||||||
|
display_name "Gemini"
|
||||||
|
name "gemini-1.5-pro"
|
||||||
|
provider "google"
|
||||||
|
tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
|
||||||
|
max_prompt_tokens 800_000
|
||||||
|
url "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest"
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:bedrock_model, from: :anthropic_model) do
|
||||||
|
url ""
|
||||||
|
provider "aws_bedrock"
|
||||||
|
api_key "asd-asd-asd"
|
||||||
|
name "claude-3-sonnet"
|
||||||
|
provider_params { { region: "us-east-1", access_key_id: "123456" } }
|
||||||
|
end
|
||||||
|
|
||||||
|
Fabricator(:cohere_model, from: :llm_model) do
|
||||||
|
display_name "Cohere Command R+"
|
||||||
|
name "command-r-plus"
|
||||||
|
provider "cohere"
|
||||||
|
api_key "ABC"
|
||||||
|
url "https://api.cohere.ai/v1/chat"
|
||||||
|
end
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
require_relative "dialect_context"
|
require_relative "dialect_context"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||||
let(:model_name) { "gpt-4" }
|
fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
|
||||||
let(:context) { DialectContext.new(described_class, model_name) }
|
let(:context) { DialectContext.new(described_class, llm_model) }
|
||||||
|
|
||||||
describe "#translate" do
|
describe "#translate" do
|
||||||
it "translates a prompt written in our generic format to the ChatGPT format" do
|
it "translates a prompt written in our generic format to the ChatGPT format" do
|
||||||
|
|
|
@ -2,9 +2,11 @@
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
let :opus_dialect_klass do
|
let :opus_dialect_klass do
|
||||||
DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }
|
||||||
|
|
||||||
describe "#translate" do
|
describe "#translate" do
|
||||||
it "can insert OKs to make stuff interleve properly" do
|
it "can insert OKs to make stuff interleve properly" do
|
||||||
messages = [
|
messages = [
|
||||||
|
@ -17,7 +19,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
|
|
||||||
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
|
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
|
||||||
|
|
||||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
dialect = opus_dialect_klass.new(prompt, llm_model)
|
||||||
translated = dialect.translate
|
translated = dialect.translate
|
||||||
|
|
||||||
expected_messages = [
|
expected_messages = [
|
||||||
|
@ -62,7 +64,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
tools: tools,
|
tools: tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
dialect = opus_dialect_klass.new(prompt, llm_model)
|
||||||
translated = dialect.translate
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||||
|
@ -114,7 +116,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
messages: messages,
|
messages: messages,
|
||||||
tools: tools,
|
tools: tools,
|
||||||
)
|
)
|
||||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
dialect = opus_dialect_klass.new(prompt, llm_model)
|
||||||
translated = dialect.translate
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
class DialectContext
|
class DialectContext
|
||||||
def initialize(dialect_klass, model_name)
|
def initialize(dialect_klass, llm_model)
|
||||||
@dialect_klass = dialect_klass
|
@dialect_klass = dialect_klass
|
||||||
@model_name = model_name
|
@llm_model = llm_model
|
||||||
end
|
end
|
||||||
|
|
||||||
def dialect(prompt)
|
def dialect(prompt)
|
||||||
@dialect_klass.new(prompt, @model_name)
|
@dialect_klass.new(prompt, @llm_model)
|
||||||
end
|
end
|
||||||
|
|
||||||
def prompt
|
def prompt
|
||||||
|
|
|
@ -13,6 +13,8 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
|
||||||
end
|
end
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
|
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
|
||||||
|
fab!(:llm_model)
|
||||||
|
|
||||||
describe "#trim_messages" do
|
describe "#trim_messages" do
|
||||||
let(:five_token_msg) { "This represents five tokens." }
|
let(:five_token_msg) { "This represents five tokens." }
|
||||||
|
|
||||||
|
@ -23,7 +25,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
|
||||||
prompt.push(type: :tool, content: five_token_msg, id: 1)
|
prompt.push(type: :tool, content: five_token_msg, id: 1)
|
||||||
prompt.push(type: :user, content: five_token_msg)
|
prompt.push(type: :user, content: five_token_msg)
|
||||||
|
|
||||||
dialect = TestDialect.new(prompt, "test")
|
dialect = TestDialect.new(prompt, llm_model)
|
||||||
dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message
|
dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message
|
||||||
|
|
||||||
trimmed = dialect.trim(prompt.messages)
|
trimmed = dialect.trim(prompt.messages)
|
||||||
|
@ -37,7 +39,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
|
||||||
prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens")
|
prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens")
|
||||||
prompt.push(type: :user, content: five_token_msg)
|
prompt.push(type: :user, content: five_token_msg)
|
||||||
|
|
||||||
dialect = TestDialect.new(prompt, "test")
|
dialect = TestDialect.new(prompt, llm_model)
|
||||||
dialect.max_prompt_tokens = 15
|
dialect.max_prompt_tokens = 15
|
||||||
|
|
||||||
trimmed = dialect.trim(prompt.messages)
|
trimmed = dialect.trim(prompt.messages)
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
require_relative "dialect_context"
|
require_relative "dialect_context"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
let(:model_name) { "gemini-1.5-pro" }
|
fab!(:model) { Fabricate(:gemini_model) }
|
||||||
let(:context) { DialectContext.new(described_class, model_name) }
|
let(:context) { DialectContext.new(described_class, model) }
|
||||||
|
|
||||||
describe "#translate" do
|
describe "#translate" do
|
||||||
it "translates a prompt written in our generic format to the Gemini format" do
|
it "translates a prompt written in our generic format to the Gemini format" do
|
||||||
|
@ -86,11 +86,12 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
|
|
||||||
it "trims content if it's getting too long" do
|
it "trims content if it's getting too long" do
|
||||||
# testing truncation on 800k tokens is slow use model with less
|
# testing truncation on 800k tokens is slow use model with less
|
||||||
context = DialectContext.new(described_class, "gemini-pro")
|
model.max_prompt_tokens = 16_384
|
||||||
|
context = DialectContext.new(described_class, model)
|
||||||
translated = context.long_user_input_scenario(length: 5_000)
|
translated = context.long_user_input_scenario(length: 5_000)
|
||||||
|
|
||||||
expect(translated[:messages].last[:role]).to eq("user")
|
expect(translated[:messages].last[:role]).to eq("user")
|
||||||
expect(translated[:messages].last.dig(:parts, :text).length).to be <
|
expect(translated[:messages].last.dig(:parts, 0, :text).length).to be <
|
||||||
context.long_message_text(length: 5_000).length
|
context.long_message_text(length: 5_000).length
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,16 +3,7 @@ require_relative "endpoint_compliance"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
let(:url) { "https://api.anthropic.com/v1/messages" }
|
let(:url) { "https://api.anthropic.com/v1/messages" }
|
||||||
fab!(:model) do
|
fab!(:model) { Fabricate(:anthropic_model, name: "claude-3-opus", vision_enabled: true) }
|
||||||
Fabricate(
|
|
||||||
:llm_model,
|
|
||||||
url: "https://api.anthropic.com/v1/messages",
|
|
||||||
name: "claude-3-opus",
|
|
||||||
provider: "anthropic",
|
|
||||||
api_key: "123",
|
|
||||||
vision_enabled: true,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") }
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") }
|
||||||
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
|
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
|
||||||
let(:upload100x100) do
|
let(:upload100x100) do
|
||||||
|
@ -204,6 +195,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
end
|
end
|
||||||
|
|
||||||
it "supports non streaming tool calls" do
|
it "supports non streaming tool calls" do
|
||||||
|
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
|
||||||
|
|
||||||
tool = {
|
tool = {
|
||||||
name: "calculate",
|
name: "calculate",
|
||||||
description: "calculate something",
|
description: "calculate something",
|
||||||
|
@ -224,8 +217,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
tools: [tool],
|
tools: [tool],
|
||||||
)
|
)
|
||||||
|
|
||||||
proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
|
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
|
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
|
||||||
type: "message",
|
type: "message",
|
||||||
|
@ -252,7 +243,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
|
|
||||||
stub_request(:post, url).to_return(body: body)
|
stub_request(:post, url).to_return(body: body)
|
||||||
|
|
||||||
result = proxy.generate(prompt, user: Discourse.system_user)
|
result = llm.generate(prompt, user: Discourse.system_user)
|
||||||
|
|
||||||
expected = <<~TEXT.strip
|
expected = <<~TEXT.strip
|
||||||
<function_calls>
|
<function_calls>
|
||||||
|
@ -370,7 +361,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
},
|
},
|
||||||
).to_return(status: 200, body: body)
|
).to_return(status: 200, body: body)
|
||||||
|
|
||||||
result = llm.generate(prompt, user: Discourse.system_user)
|
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
|
result = proxy.generate(prompt, user: Discourse.system_user)
|
||||||
expect(result).to eq("Hello!")
|
expect(result).to eq("Hello!")
|
||||||
|
|
||||||
expected_body = {
|
expected_body = {
|
||||||
|
|
|
@ -8,9 +8,10 @@ class BedrockMock < EndpointMock
|
||||||
end
|
end
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
subject(:endpoint) { described_class.new(model) }
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
fab!(:model) { Fabricate(:bedrock_model) }
|
||||||
|
|
||||||
let(:bedrock_mock) { BedrockMock.new(endpoint) }
|
let(:bedrock_mock) { BedrockMock.new(endpoint) }
|
||||||
|
|
||||||
|
@ -25,16 +26,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
Aws::EventStream::Encoder.new.encode(aws_message)
|
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||||
end
|
end
|
||||||
|
|
||||||
before do
|
|
||||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
|
||||||
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
|
||||||
SiteSetting.ai_bedrock_region = "us-east-1"
|
|
||||||
end
|
|
||||||
|
|
||||||
describe "function calling" do
|
describe "function calling" do
|
||||||
it "supports old school xml function calls" do
|
it "supports old school xml function calls" do
|
||||||
SiteSetting.ai_anthropic_native_tool_call_models = ""
|
SiteSetting.ai_anthropic_native_tool_call_models = ""
|
||||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
|
|
||||||
incomplete_tool_call = <<~XML.strip
|
incomplete_tool_call = <<~XML.strip
|
||||||
<thinking>I should be ignored</thinking>
|
<thinking>I should be ignored</thinking>
|
||||||
|
@ -112,7 +107,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
end
|
end
|
||||||
|
|
||||||
it "supports streaming function calls" do
|
it "supports streaming function calls" do
|
||||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
|
|
||||||
request = nil
|
request = nil
|
||||||
|
|
||||||
|
@ -124,7 +119,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
|
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
model: "claude-3-haiku-20240307",
|
model: "claude-3-sonnet-20240307",
|
||||||
stop_sequence: nil,
|
stop_sequence: nil,
|
||||||
usage: {
|
usage: {
|
||||||
input_tokens: 840,
|
input_tokens: 840,
|
||||||
|
@ -281,9 +276,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "Claude 3 Sonnet support" do
|
describe "Claude 3 support" do
|
||||||
it "supports the sonnet model" do
|
it "supports regular completions" do
|
||||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
|
|
||||||
request = nil
|
request = nil
|
||||||
|
|
||||||
|
@ -325,8 +320,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
expect(log.response_tokens).to eq(20)
|
expect(log.response_tokens).to eq(20)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "supports claude 3 sonnet streaming" do
|
it "supports claude 3 streaming" do
|
||||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
|
|
||||||
request = nil
|
request = nil
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,8 @@
|
||||||
require_relative "endpoint_compliance"
|
require_relative "endpoint_compliance"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") }
|
fab!(:cohere_model)
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{cohere_model.id}") }
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
|
||||||
let(:prompt) do
|
let(:prompt) do
|
||||||
|
@ -57,8 +58,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
||||||
prompt
|
prompt
|
||||||
end
|
end
|
||||||
|
|
||||||
before { SiteSetting.ai_cohere_api_key = "ABC" }
|
|
||||||
|
|
||||||
it "is able to trigger a tool" do
|
it "is able to trigger a tool" do
|
||||||
body = (<<~TEXT).strip
|
body = (<<~TEXT).strip
|
||||||
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
|
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
|
||||||
|
@ -184,7 +183,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
||||||
expect(audit.request_tokens).to eq(17)
|
expect(audit.request_tokens).to eq(17)
|
||||||
expect(audit.response_tokens).to eq(22)
|
expect(audit.response_tokens).to eq(22)
|
||||||
|
|
||||||
expect(audit.language_model).to eq("Cohere - command-r-plus")
|
expect(audit.language_model).to eq("command-r-plus")
|
||||||
end
|
end
|
||||||
|
|
||||||
it "is able to perform streaming completions" do
|
it "is able to perform streaming completions" do
|
||||||
|
|
|
@ -158,7 +158,7 @@ class EndpointsCompliance
|
||||||
end
|
end
|
||||||
|
|
||||||
def dialect(prompt: generic_prompt)
|
def dialect(prompt: generic_prompt)
|
||||||
dialect_klass.new(prompt, endpoint.model)
|
dialect_klass.new(prompt, endpoint.llm_model)
|
||||||
end
|
end
|
||||||
|
|
||||||
def regular_mode_simple_prompt(mock)
|
def regular_mode_simple_prompt(mock)
|
||||||
|
@ -176,7 +176,7 @@ class EndpointsCompliance
|
||||||
expect(log.raw_request_payload).to be_present
|
expect(log.raw_request_payload).to be_present
|
||||||
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
|
expect(log.raw_response_payload).to eq(mock.response(completion_response).to_json)
|
||||||
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||||
expect(log.response_tokens).to eq(endpoint.tokenizer.size(completion_response))
|
expect(log.response_tokens).to eq(endpoint.llm_model.tokenizer_class.size(completion_response))
|
||||||
end
|
end
|
||||||
|
|
||||||
def regular_mode_tools(mock)
|
def regular_mode_tools(mock)
|
||||||
|
@ -206,7 +206,7 @@ class EndpointsCompliance
|
||||||
expect(log.raw_response_payload).to be_present
|
expect(log.raw_response_payload).to be_present
|
||||||
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
|
||||||
expect(log.response_tokens).to eq(
|
expect(log.response_tokens).to eq(
|
||||||
endpoint.tokenizer.size(mock.streamed_simple_deltas[0...-1].join),
|
endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -128,18 +128,9 @@ class GeminiMock < EndpointMock
|
||||||
end
|
end
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
subject(:endpoint) { described_class.new("gemini-pro", DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
subject(:endpoint) { described_class.new(model) }
|
||||||
|
|
||||||
fab!(:model) do
|
fab!(:model) { Fabricate(:gemini_model, vision_enabled: true) }
|
||||||
Fabricate(
|
|
||||||
:llm_model,
|
|
||||||
url: "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest",
|
|
||||||
name: "gemini-1.5-pro",
|
|
||||||
provider: "google",
|
|
||||||
api_key: "ABC",
|
|
||||||
vision_enabled: true,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
|
||||||
|
@ -168,7 +159,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
req_body = nil
|
req_body = nil
|
||||||
|
|
||||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
url = "#{model.url}:generateContent?key=ABC"
|
url = "#{model.url}:generateContent?key=123"
|
||||||
|
|
||||||
stub_request(:post, url).with(
|
stub_request(:post, url).with(
|
||||||
body:
|
body:
|
||||||
|
@ -221,7 +212,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
split = data.split("|")
|
split = data.split("|")
|
||||||
|
|
||||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
url = "#{model.url}:streamGenerateContent?alt=sse&key=ABC"
|
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
|
||||||
|
|
||||||
output = +""
|
output = +""
|
||||||
gemini_mock.with_chunk_array_support do
|
gemini_mock.with_chunk_array_support do
|
||||||
|
|
|
@ -22,7 +22,7 @@ class HuggingFaceMock < EndpointMock
|
||||||
|
|
||||||
def stub_response(prompt, response_text, tool_call: false)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||||
.with(body: request_body(prompt))
|
.with(body: request_body(prompt))
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
end
|
end
|
||||||
|
@ -40,7 +40,7 @@ class HuggingFaceMock < EndpointMock
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_raw(chunks)
|
def stub_raw(chunks)
|
||||||
WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return(
|
WebMock.stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
|
||||||
status: 200,
|
status: 200,
|
||||||
body: chunks,
|
body: chunks,
|
||||||
)
|
)
|
||||||
|
@ -59,7 +59,7 @@ class HuggingFaceMock < EndpointMock
|
||||||
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||||
.with(body: request_body(prompt, stream: true))
|
.with(body: request_body(prompt, stream: true))
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
|
|
||||||
|
@ -71,8 +71,7 @@ class HuggingFaceMock < EndpointMock
|
||||||
.default_options
|
.default_options
|
||||||
.merge(messages: prompt)
|
.merge(messages: prompt)
|
||||||
.tap do |b|
|
.tap do |b|
|
||||||
b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
b[:max_tokens] = 63_991
|
||||||
model.prompt_size(prompt)
|
|
||||||
b[:stream] = true if stream
|
b[:stream] = true if stream
|
||||||
end
|
end
|
||||||
.to_json
|
.to_json
|
||||||
|
@ -80,15 +79,9 @@ class HuggingFaceMock < EndpointMock
|
||||||
end
|
end
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
subject(:endpoint) do
|
subject(:endpoint) { described_class.new(hf_model) }
|
||||||
described_class.new(
|
|
||||||
"mistralai/Mistral-7B-Instruct-v0.2",
|
|
||||||
DiscourseAi::Tokenizer::MixtralTokenizer,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
|
|
||||||
|
|
||||||
|
fab!(:hf_model)
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
|
||||||
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
|
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
|
||||||
|
|
|
@ -146,11 +146,10 @@ class OpenAiMock < EndpointMock
|
||||||
end
|
end
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
subject(:endpoint) do
|
subject(:endpoint) { described_class.new(model) }
|
||||||
described_class.new("gpt-3.5-turbo", DiscourseAi::Tokenizer::OpenAiTokenizer)
|
|
||||||
end
|
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
fab!(:model) { Fabricate(:llm_model) }
|
||||||
|
|
||||||
let(:echo_tool) do
|
let(:echo_tool) do
|
||||||
{
|
{
|
||||||
|
@ -175,7 +174,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
|
|
||||||
describe "repeat calls" do
|
describe "repeat calls" do
|
||||||
it "can properly reset context" do
|
it "can properly reset context" do
|
||||||
llm = DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo")
|
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -258,7 +257,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
|
|
||||||
describe "image support" do
|
describe "image support" do
|
||||||
it "can handle images" do
|
it "can handle images" do
|
||||||
model = Fabricate(:llm_model, provider: "open_ai", vision_enabled: true)
|
model = Fabricate(:llm_model, vision_enabled: true)
|
||||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
prompt =
|
prompt =
|
||||||
DiscourseAi::Completions::Prompt.new(
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
|
|
@ -22,7 +22,7 @@ class VllmMock < EndpointMock
|
||||||
|
|
||||||
def stub_response(prompt, response_text, tool_call: false)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
|
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||||
.with(body: model.default_options.merge(messages: prompt).to_json)
|
.with(body: model.default_options.merge(messages: prompt).to_json)
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
end
|
end
|
||||||
|
@ -50,19 +50,16 @@ class VllmMock < EndpointMock
|
||||||
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
|
.stub_request(:post, "https://test.dev/v1/chat/completions")
|
||||||
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
|
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
subject(:endpoint) do
|
subject(:endpoint) { described_class.new(llm_model) }
|
||||||
described_class.new(
|
|
||||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
fab!(:llm_model) { Fabricate(:vllm_model) }
|
||||||
DiscourseAi::Tokenizer::MixtralTokenizer,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
|
||||||
|
@ -78,15 +75,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:dialect) do
|
let(:dialect) do
|
||||||
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, model_name)
|
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, llm_model)
|
||||||
end
|
end
|
||||||
let(:prompt) { dialect.translate }
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
|
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
|
||||||
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
|
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
|
||||||
|
|
||||||
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
|
||||||
|
|
||||||
describe "#perform_completion!" do
|
describe "#perform_completion!" do
|
||||||
context "when using regular mode" do
|
context "when using regular mode" do
|
||||||
context "with simple prompts" do
|
context "with simple prompts" do
|
||||||
|
|
|
@ -5,12 +5,13 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
||||||
described_class.new(
|
described_class.new(
|
||||||
DiscourseAi::Completions::Dialects::OpenAiCompatible,
|
DiscourseAi::Completions::Dialects::OpenAiCompatible,
|
||||||
canned_response,
|
canned_response,
|
||||||
"hugging_face:Upstage-Llama-2-*-instruct-v2",
|
model,
|
||||||
gateway: canned_response,
|
gateway: canned_response,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
fab!(:model) { Fabricate(:llm_model) }
|
||||||
|
|
||||||
describe ".proxy" do
|
describe ".proxy" do
|
||||||
it "raises an exception when we can't proxy the model" do
|
it "raises an exception when we can't proxy the model" do
|
||||||
|
@ -46,7 +47,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
||||||
)
|
)
|
||||||
result = +""
|
result = +""
|
||||||
described_class
|
described_class
|
||||||
.proxy("open_ai:gpt-3.5-turbo")
|
.proxy("custom:#{model.id}")
|
||||||
.generate(prompt, user: user) { |partial| result << partial }
|
.generate(prompt, user: user) { |partial| result << partial }
|
||||||
|
|
||||||
expect(result).to eq("Hello")
|
expect(result).to eq("Hello")
|
||||||
|
@ -57,12 +58,14 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#generate with fake model" do
|
describe "#generate with fake model" do
|
||||||
|
fab!(:fake_model)
|
||||||
|
|
||||||
before do
|
before do
|
||||||
DiscourseAi::Completions::Endpoints::Fake.delays = []
|
DiscourseAi::Completions::Endpoints::Fake.delays = []
|
||||||
DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
|
DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:llm) { described_class.proxy("fake:fake") }
|
let(:llm) { described_class.proxy("custom:#{fake_model.id}") }
|
||||||
|
|
||||||
let(:prompt) do
|
let(:prompt) do
|
||||||
DiscourseAi::Completions::Prompt.new(
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
|
|
@ -5,6 +5,8 @@ return if !defined?(DiscourseAutomation)
|
||||||
describe DiscourseAutomation do
|
describe DiscourseAutomation do
|
||||||
let(:automation) { Fabricate(:automation, script: "llm_report", enabled: true) }
|
let(:automation) { Fabricate(:automation, script: "llm_report", enabled: true) }
|
||||||
|
|
||||||
|
fab!(:llm_model)
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
fab!(:post)
|
fab!(:post)
|
||||||
|
|
||||||
|
@ -22,7 +24,7 @@ describe DiscourseAutomation do
|
||||||
it "can trigger via automation" do
|
it "can trigger via automation" do
|
||||||
add_automation_field("sender", user.username, type: "user")
|
add_automation_field("sender", user.username, type: "user")
|
||||||
add_automation_field("receivers", [user.username], type: "users")
|
add_automation_field("receivers", [user.username], type: "users")
|
||||||
add_automation_field("model", "gpt-4-turbo")
|
add_automation_field("model", "custom:#{llm_model.id}")
|
||||||
add_automation_field("title", "Weekly report")
|
add_automation_field("title", "Weekly report")
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
|
||||||
|
@ -36,7 +38,7 @@ describe DiscourseAutomation do
|
||||||
it "can target a topic" do
|
it "can target a topic" do
|
||||||
add_automation_field("sender", user.username, type: "user")
|
add_automation_field("sender", user.username, type: "user")
|
||||||
add_automation_field("topic_id", "#{post.topic_id}")
|
add_automation_field("topic_id", "#{post.topic_id}")
|
||||||
add_automation_field("model", "gpt-4-turbo")
|
add_automation_field("model", "custom:#{llm_model.id}")
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses(["An Amazing Report!!!"]) do
|
||||||
automation.trigger!
|
automation.trigger!
|
||||||
|
|
|
@ -8,6 +8,8 @@ describe DiscourseAi::Automation::LlmTriage do
|
||||||
|
|
||||||
let(:automation) { Fabricate(:automation, script: "llm_triage", enabled: true) }
|
let(:automation) { Fabricate(:automation, script: "llm_triage", enabled: true) }
|
||||||
|
|
||||||
|
fab!(:llm_model)
|
||||||
|
|
||||||
def add_automation_field(name, value, type: "text")
|
def add_automation_field(name, value, type: "text")
|
||||||
automation.fields.create!(
|
automation.fields.create!(
|
||||||
component: type,
|
component: type,
|
||||||
|
@ -23,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
|
||||||
SiteSetting.tagging_enabled = true
|
SiteSetting.tagging_enabled = true
|
||||||
add_automation_field("system_prompt", "hello %%POST%%")
|
add_automation_field("system_prompt", "hello %%POST%%")
|
||||||
add_automation_field("search_for_text", "bad")
|
add_automation_field("search_for_text", "bad")
|
||||||
add_automation_field("model", "gpt-4")
|
add_automation_field("model", "custom:#{llm_model.id}")
|
||||||
add_automation_field("category", category.id, type: "category")
|
add_automation_field("category", category.id, type: "category")
|
||||||
add_automation_field("tags", %w[aaa bbb], type: "tags")
|
add_automation_field("tags", %w[aaa bbb], type: "tags")
|
||||||
add_automation_field("hide_topic", true, type: "boolean")
|
add_automation_field("hide_topic", true, type: "boolean")
|
||||||
|
|
|
@ -12,7 +12,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
||||||
SiteSetting.ai_bot_enabled = true
|
SiteSetting.ai_bot_enabled = true
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-4") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_4.name) }
|
||||||
|
|
||||||
let!(:user) { Fabricate(:user) }
|
let!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
||||||
toggle_enabled_bots(bots: [fake])
|
toggle_enabled_bots(bots: [fake])
|
||||||
Group.refresh_automatic_groups!
|
Group.refresh_automatic_groups!
|
||||||
|
|
||||||
bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model("fake")
|
bot_user = DiscourseAi::AiBot::EntryPoint.find_user_from_model(fake.name)
|
||||||
AiPersona.create!(
|
AiPersona.create!(
|
||||||
name: "TestPersona",
|
name: "TestPersona",
|
||||||
top_p: 0.5,
|
top_p: 0.5,
|
||||||
|
|
|
@ -336,6 +336,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||||
context "when RAG is running with a question consolidator" do
|
context "when RAG is running with a question consolidator" do
|
||||||
let(:consolidated_question) { "what is the time in france?" }
|
let(:consolidated_question) { "what is the time in france?" }
|
||||||
|
|
||||||
|
fab!(:llm_model) { Fabricate(:fake_model) }
|
||||||
|
|
||||||
it "will run the question consolidator" do
|
it "will run the question consolidator" do
|
||||||
context_embedding = [0.049382, 0.9999]
|
context_embedding = [0.049382, 0.9999]
|
||||||
EmbeddingsGenerationStubs.discourse_service(
|
EmbeddingsGenerationStubs.discourse_service(
|
||||||
|
@ -350,7 +352,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||||
name: "custom",
|
name: "custom",
|
||||||
rag_conversation_chunks: 3,
|
rag_conversation_chunks: 3,
|
||||||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||||
question_consolidator_llm: "fake:fake",
|
question_consolidator_llm: "custom:#{llm_model.id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])
|
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])
|
||||||
|
|
|
@ -4,6 +4,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
subject(:playground) { described_class.new(bot) }
|
subject(:playground) { described_class.new(bot) }
|
||||||
|
|
||||||
fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") }
|
fab!(:claude_2) { Fabricate(:llm_model, name: "claude-2") }
|
||||||
|
fab!(:opus_model) { Fabricate(:anthropic_model) }
|
||||||
|
|
||||||
fab!(:bot_user) do
|
fab!(:bot_user) do
|
||||||
toggle_enabled_bots(bots: [claude_2])
|
toggle_enabled_bots(bots: [claude_2])
|
||||||
|
@ -160,7 +161,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
system_prompt: "You are a helpful bot",
|
system_prompt: "You are a helpful bot",
|
||||||
vision_enabled: true,
|
vision_enabled: true,
|
||||||
vision_max_pixels: 1_000,
|
vision_max_pixels: 1_000,
|
||||||
default_llm: "anthropic:claude-3-opus",
|
default_llm: "custom:#{opus_model.id}",
|
||||||
mentionable: true,
|
mentionable: true,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
@ -211,7 +212,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
)
|
)
|
||||||
|
|
||||||
persona.create_user!
|
persona.create_user!
|
||||||
persona.update!(default_llm: "anthropic:claude-2", mentionable: true)
|
persona.update!(default_llm: "custom:#{claude_2.id}", mentionable: true)
|
||||||
persona
|
persona
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -228,7 +229,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
SiteSetting.ai_bot_enabled = true
|
SiteSetting.ai_bot_enabled = true
|
||||||
SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
|
SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
|
||||||
Group.refresh_automatic_groups!
|
Group.refresh_automatic_groups!
|
||||||
persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus")
|
persona.update!(allow_chat: true, mentionable: true, default_llm: "custom:#{opus_model.id}")
|
||||||
end
|
end
|
||||||
|
|
||||||
it "should behave in a sane way when threading is enabled" do
|
it "should behave in a sane way when threading is enabled" do
|
||||||
|
@ -342,7 +343,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
persona.update!(
|
persona.update!(
|
||||||
allow_chat: true,
|
allow_chat: true,
|
||||||
mentionable: false,
|
mentionable: false,
|
||||||
default_llm: "anthropic:claude-3-opus",
|
default_llm: "custom:#{opus_model.id}",
|
||||||
)
|
)
|
||||||
SiteSetting.ai_bot_enabled = true
|
SiteSetting.ai_bot_enabled = true
|
||||||
end
|
end
|
||||||
|
@ -517,7 +518,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||||
["Magic title", "Yes I can"],
|
["Magic title", "Yes I can"],
|
||||||
llm: "anthropic:claude-2",
|
llm: "custom:#{claude_2.id}",
|
||||||
) do
|
) do
|
||||||
post =
|
post =
|
||||||
create_post(
|
create_post(
|
||||||
|
@ -552,7 +553,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
# title is queued first, ensures it uses the llm targeted via target_usernames not claude
|
# title is queued first, ensures it uses the llm targeted via target_usernames not claude
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||||
["Magic title", "Yes I can"],
|
["Magic title", "Yes I can"],
|
||||||
llm: "open_ai:gpt-3.5-turbo",
|
llm: "custom:#{gpt_35_turbo.id}",
|
||||||
) do
|
) do
|
||||||
post =
|
post =
|
||||||
create_post(
|
create_post(
|
||||||
|
@ -584,7 +585,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
# replies as correct persona if replying direct to persona
|
# replies as correct persona if replying direct to persona
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||||
["Another reply"],
|
["Another reply"],
|
||||||
llm: "open_ai:gpt-3.5-turbo",
|
llm: "custom:#{gpt_35_turbo.id}",
|
||||||
) do
|
) do
|
||||||
create_post(
|
create_post(
|
||||||
raw: "Please ignore this bot, I am replying to a user",
|
raw: "Please ignore this bot, I am replying to a user",
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do
|
RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("fake:fake") }
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{Fabricate(:fake_model).id}") }
|
||||||
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
|
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
|
||||||
|
|
||||||
fab!(:user)
|
fab!(:user)
|
||||||
|
|
|
@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
|
||||||
SiteSetting.ai_openai_api_key = "abc"
|
SiteSetting.ai_openai_api_key = "abc"
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
|
||||||
let(:progress_blk) { Proc.new {} }
|
let(:progress_blk) { Proc.new {} }
|
||||||
|
|
||||||
let(:dall_e) do
|
let(:dall_e) do
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
|
RSpec.describe DiscourseAi::AiBot::Tools::DbSchema do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
|
|
||||||
before { SiteSetting.ai_bot_enabled = true }
|
before { SiteSetting.ai_bot_enabled = true }
|
||||||
describe "#process" do
|
describe "#process" do
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
|
RSpec.describe DiscourseAi::AiBot::Tools::DiscourseMetaSearch do
|
||||||
before do
|
before { SiteSetting.ai_bot_enabled = true }
|
||||||
SiteSetting.ai_bot_enabled = true
|
|
||||||
SiteSetting.ai_openai_api_key = "asd"
|
|
||||||
end
|
|
||||||
|
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model) { Fabricate(:llm_model, max_prompt_tokens: 8192) }
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:progress_blk) { Proc.new {} }
|
let(:progress_blk) { Proc.new {} }
|
||||||
|
|
||||||
let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }
|
let(:mock_search_json) { plugin_file_from_fixtures("search.json", "search_meta").read }
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
require "rails_helper"
|
require "rails_helper"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
|
RSpec.describe DiscourseAi::AiBot::Tools::GithubFileContent do
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
fab!(:llm_model)
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
|
|
||||||
let(:tool) do
|
let(:tool) do
|
||||||
described_class.new(
|
described_class.new(
|
||||||
|
|
|
@ -4,7 +4,8 @@ require "rails_helper"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
|
RSpec.describe DiscourseAi::AiBot::Tools::GithubPullRequestDiff do
|
||||||
let(:bot_user) { Fabricate(:user) }
|
let(:bot_user) { Fabricate(:user) }
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
fab!(:llm_model)
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) }
|
let(:tool) { described_class.new({ repo: repo, pull_id: pull_id }, bot_user: bot_user, llm: llm) }
|
||||||
|
|
||||||
context "with #sort_and_shorten_diff" do
|
context "with #sort_and_shorten_diff" do
|
||||||
|
|
|
@ -4,7 +4,8 @@ require "rails_helper"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
|
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchCode do
|
||||||
let(:bot_user) { Fabricate(:user) }
|
let(:bot_user) { Fabricate(:user) }
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
fab!(:llm_model)
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) }
|
let(:tool) { described_class.new({ repo: repo, query: query }, bot_user: bot_user, llm: llm) }
|
||||||
|
|
||||||
context "with valid search results" do
|
context "with valid search results" do
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
require "rails_helper"
|
require "rails_helper"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
|
RSpec.describe DiscourseAi::AiBot::Tools::GithubSearchFiles do
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4") }
|
fab!(:llm_model)
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
|
|
||||||
let(:tool) do
|
let(:tool) do
|
||||||
described_class.new(
|
described_class.new(
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::Google do
|
RSpec.describe DiscourseAi::AiBot::Tools::Google do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:progress_blk) { Proc.new {} }
|
let(:progress_blk) { Proc.new {} }
|
||||||
let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) }
|
let(:search) { described_class.new({ query: "some search term" }, bot_user: bot_user, llm: llm) }
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Image do
|
||||||
|
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{gpt_35_turbo.id}") }
|
||||||
|
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(gpt_35_turbo.name) }
|
||||||
|
|
||||||
describe "#process" do
|
describe "#process" do
|
||||||
it "can generate correct info" do
|
it "can generate correct info" do
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::JavascriptEvaluator do
|
RSpec.describe DiscourseAi::AiBot::Tools::JavascriptEvaluator do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:progress_blk) { Proc.new {} }
|
let(:progress_blk) { Proc.new {} }
|
||||||
|
|
||||||
before { SiteSetting.ai_bot_enabled = true }
|
before { SiteSetting.ai_bot_enabled = true }
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
|
RSpec.describe DiscourseAi::AiBot::Tools::ListCategories do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
|
|
||||||
before { SiteSetting.ai_bot_enabled = true }
|
before { SiteSetting.ai_bot_enabled = true }
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
|
RSpec.describe DiscourseAi::AiBot::Tools::ListTags do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_bot_enabled = true
|
SiteSetting.ai_bot_enabled = true
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
RSpec.describe DiscourseAi::AiBot::Tools::Read do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) }
|
let(:tool) { described_class.new({ topic_id: topic_with_tags.id }, bot_user: bot_user, llm: llm) }
|
||||||
|
|
||||||
fab!(:parent_category) { Fabricate(:category, name: "animals") }
|
fab!(:parent_category) { Fabricate(:category, name: "animals") }
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
|
RSpec.describe DiscourseAi::AiBot::Tools::SearchSettings do
|
||||||
fab!(:gpt_35_bot) { Fabricate(:llm_model, name: "gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_bot_enabled = true
|
SiteSetting.ai_bot_enabled = true
|
||||||
toggle_enabled_bots(bots: [gpt_35_bot])
|
toggle_enabled_bots(bots: [llm_model])
|
||||||
end
|
end
|
||||||
|
|
||||||
def search_settings(query)
|
def search_settings(query)
|
||||||
|
|
|
@ -4,10 +4,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
||||||
before { SearchIndexer.enable }
|
before { SearchIndexer.enable }
|
||||||
after { SearchIndexer.disable }
|
after { SearchIndexer.disable }
|
||||||
|
|
||||||
before { SiteSetting.ai_openai_api_key = "asd" }
|
fab!(:llm_model)
|
||||||
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
|
||||||
let(:progress_blk) { Proc.new {} }
|
let(:progress_blk) { Proc.new {} }
|
||||||
|
|
||||||
fab!(:admin)
|
fab!(:admin)
|
||||||
|
|
|
@ -9,8 +9,10 @@ def has_rg?
|
||||||
end
|
end
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
|
RSpec.describe DiscourseAi::AiBot::Tools::SettingContext, if: has_rg? do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
|
||||||
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
|
|
||||||
before { SiteSetting.ai_bot_enabled = true }
|
before { SiteSetting.ai_bot_enabled = true }
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
|
RSpec.describe DiscourseAi::AiBot::Tools::Summarize do
|
||||||
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
|
fab!(:llm_model)
|
||||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
|
let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
|
||||||
|
let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
|
||||||
let(:progress_blk) { Proc.new {} }
|
let(:progress_blk) { Proc.new {} }
|
||||||
|
|
||||||
before { SiteSetting.ai_bot_enabled = true }
|
before { SiteSetting.ai_bot_enabled = true }
|
||||||
|
|
|
@@ -1,8 +1,9 @@
 #frozen_string_literal: true

 RSpec.describe DiscourseAi::AiBot::Tools::Time do
-  let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-3.5-turbo") }
+  fab!(:llm_model)
+
+  let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }

   before { SiteSetting.ai_bot_enabled = true }
@@ -1,13 +1,11 @@
 # frozen_string_literal: true

 RSpec.describe DiscourseAi::AiBot::Tools::WebBrowser do
-  let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model("gpt-3.5-turbo") }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4-turbo") }
+  fab!(:llm_model)

-  before do
-    SiteSetting.ai_openai_api_key = "asd"
-    SiteSetting.ai_bot_enabled = true
-  end
+  let(:bot_user) { DiscourseAi::AiBot::EntryPoint.find_user_from_model(llm_model.name) }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
+
+  before { SiteSetting.ai_bot_enabled = true }

   describe "#invoke" do
     it "can retrieve the content of a webpage and returns the processed text" do
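Every tool spec above repeats the same swap: SiteSetting-driven API-key setup gives way to a fabricated LlmModel record. The :llm_model fabricator itself is not part of this diff; a plausible definition, assuming the columns visible elsewhere in the commit, might look like:

    # Hypothetical fabricator sketch: the real :llm_model fabricator is not
    # shown in this diff, and the attribute values are illustrative only.
    Fabricator(:llm_model) do
      display_name "A test model"
      name "gpt-4-test"
      provider "open_ai"
      tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
      api_key "123"
      url "https://api.example.com/v1/chat/completions"
      max_prompt_tokens 8_000
    end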
@@ -1,6 +1,7 @@
 # frozen_string_literal: true

 describe DiscourseAi::Automation::LlmTriage do
   fab!(:post)
+  fab!(:llm_model)

   def triage(**args)
     DiscourseAi::Automation::LlmTriage.handle(**args)
@@ -10,7 +11,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "custom:#{llm_model.id}",
         hide_topic: true,
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
@@ -24,7 +25,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "custom:#{llm_model.id}",
         hide_topic: true,
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
@@ -40,7 +41,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "custom:#{llm_model.id}",
         category_id: category.id,
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
@@ -55,7 +56,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "custom:#{llm_model.id}",
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
         canned_reply: "test canned reply 123",
@@ -73,7 +74,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "custom:#{llm_model.id}",
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
         flag_post: true,
@@ -89,7 +90,7 @@ describe DiscourseAi::Automation::LlmTriage do
     DiscourseAi::Completions::Llm.with_prepared_responses(["Bad.\n\nYo"]) do
       triage(
         post: post,
-        model: "gpt-4",
+        model: "custom:#{llm_model.id}",
         system_prompt: "test %%POST%%",
         search_for_text: "bad",
         flag_post: true,
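The triage specs pass the same custom:<id> string through the model: argument, which suggests the handler forwards it directly to the completions proxy. A hedged sketch of that hand-off; the method body below is an assumption, since only the argument names and the prepared-responses helper appear in the diff:

    # Assumed shape of the hand-off inside LlmTriage.handle; only the model
    # string format and the triage arguments are confirmed by this commit.
    # take_moderation_actions stands in for the hide/flag/canned-reply branches.
    def handle(post:, model:, system_prompt:, search_for_text:, **opts)
      llm = DiscourseAi::Completions::Llm.proxy(model) # e.g. "custom:42"
      reply = llm.generate(system_prompt.sub("%%POST%%", post.raw), user: Discourse.system_user)
      take_moderation_actions(post, opts) if reply.to_s.downcase.include?(search_for_text)
    end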
@@ -32,6 +32,8 @@ module DiscourseAi
     fab!(:topic_with_tag) { Fabricate(:topic, tags: [tag, hidden_tag]) }
     fab!(:post_with_tag) { Fabricate(:post, raw: "I am in a tag", topic: topic_with_tag) }

+    fab!(:llm_model)
+
     describe "#run!" do
       it "is able to generate email reports" do
         freeze_time
@@ -41,7 +43,7 @@ module DiscourseAi
           sender_username: user.username,
           receivers: ["fake@discourse.com"],
           title: "test report %DATE%",
-          model: "gpt-4",
+          model: "custom:#{llm_model.id}",
           category_ids: nil,
           tags: nil,
           allow_secure_categories: false,
@@ -78,7 +80,7 @@ module DiscourseAi
           sender_username: user.username,
           receivers: [receiver.username],
           title: "test report",
-          model: "gpt-4",
+          model: "custom:#{llm_model.id}",
           category_ids: nil,
           tags: nil,
           allow_secure_categories: false,
@@ -123,7 +125,7 @@ module DiscourseAi
           sender_username: user.username,
           receivers: [receiver.username],
           title: "test report",
-          model: "gpt-4",
+          model: "custom:#{llm_model.id}",
           category_ids: nil,
           tags: nil,
           allow_secure_categories: false,
@@ -166,7 +168,7 @@ module DiscourseAi
           sender_username: user.username,
           receivers: [receiver.username],
           title: "test report",
-          model: "gpt-4",
+          model: "custom:#{llm_model.id}",
           category_ids: nil,
           tags: nil,
           allow_secure_categories: false,
@@ -194,7 +196,7 @@ module DiscourseAi
           sender_username: user.username,
           receivers: [receiver.username],
           title: "test report",
-          model: "gpt-4",
+          model: "custom:#{llm_model.id}",
           category_ids: nil,
           tags: nil,
           allow_secure_categories: false,
@@ -2,7 +2,7 @@

 RSpec.describe AiTool do
   fab!(:llm_model) { Fabricate(:llm_model, name: "claude-2") }
-  let(:llm) { DiscourseAi::Completions::Llm.proxy_from_obj(llm_model) }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }

   def create_tool(parameters: nil, script: nil)
     AiTool.create!(
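With the AiTool spec migrated, proxy_from_obj has no remaining callers in the test suite. Spec usage of the id-based proxy stays straightforward; with_prepared_responses (used throughout the triage specs above) injects canned completions around it:

    # Typical spec pattern after this commit; the response text is illustrative.
    llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")

    DiscourseAi::Completions::Llm.with_prepared_responses(["canned reply"]) do
      # exercise the tool or automation under test here
    end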
@@ -1,64 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe LlmModel do
-  describe ".seed_srv_backed_model" do
-    before do
-      SiteSetting.ai_vllm_endpoint_srv = "srv.llm.service."
-      SiteSetting.ai_vllm_api_key = "123"
-    end
-
-    context "when the model doesn't exist yet" do
-      it "creates it" do
-        described_class.seed_srv_backed_model
-
-        llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
-
-        expect(llm_model).to be_present
-        expect(llm_model.name).to eq("mistralai/Mixtral-8x7B-Instruct-v0.1")
-        expect(llm_model.api_key).to eq(SiteSetting.ai_vllm_api_key)
-      end
-    end
-
-    context "when the model already exists" do
-      before { described_class.seed_srv_backed_model }
-
-      context "when the API key setting changes" do
-        it "updates it" do
-          new_key = "456"
-          SiteSetting.ai_vllm_api_key = new_key
-
-          described_class.seed_srv_backed_model
-
-          llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
-
-          expect(llm_model.api_key).to eq(new_key)
-        end
-      end
-
-      context "when the SRV is no longer defined" do
-        it "deletes the LlmModel" do
-          llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
-          expect(llm_model).to be_present
-
-          SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
-
-          expect { llm_model.reload }.to raise_exception(ActiveRecord::RecordNotFound)
-        end
-
-        it "disabled the bot user" do
-          SiteSetting.ai_bot_enabled = true
-          llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL)
-          llm_model.update!(enabled_chat_bot: true)
-          llm_model.toggle_companion_user
-          user = llm_model.user
-
-          expect(user).to be_present
-
-          SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code
-
-          expect { user.reload }.to raise_exception(ActiveRecord::RecordNotFound)
-        end
-      end
-    end
-  end
-end
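The deleted spec covered seeding and tearing down a model behind the ai_vllm_endpoint_srv site setting; that lifecycle is gone, and SRV-backed endpoints are instead written as a plain srv://-prefixed url on the LlmModel record. A minimal sketch of resolving such a URL at request time, using only Ruby's stdlib resolver; the helper name and the https scheme are assumptions:

    require "resolv"

    # Hypothetical helper: expand "srv://service.name." into a concrete
    # host:port endpoint via a DNS SRV lookup. Only the srv:// prefix is
    # taken from this commit; the rest is illustrative.
    def resolve_srv_url(url)
      return url unless url.to_s.start_with?("srv://")

      record =
        Resolv::DNS.open do |dns|
          dns.getresource(url.delete_prefix("srv://"), Resolv::DNS::Resource::IN::SRV)
        end

      "https://#{record.target}:#{record.port}"
    end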
@@ -8,7 +8,7 @@ module DiscourseAi::ChatBotHelper
   end

   def assign_fake_provider_to(setting_name)
-    Fabricate(:llm_model, provider: "fake", name: "fake").tap do |fake_llm|
+    Fabricate(:fake_model).tap do |fake_llm|
       SiteSetting.public_send("#{setting_name}=", "custom:#{fake_llm.id}")
     end
   end
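The helper now leans on a dedicated :fake_model fabricator instead of overriding provider and name inline. Typical call site, unchanged by this commit; the setting name below is illustrative:

    # Point a feature's LLM site setting at a fake model; the helper wires
    # "custom:<id>" into the setting. The setting name is an example.
    assign_fake_provider_to(:ai_helper_model)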