FEATURE: Configurable LLMs. (#606)

This PR introduces "LlmModel", a new way to add LLM models quickly without making any code changes. We are releasing this first version and will add incremental improvements, so expect changes.

The AI Bot can't fully take advantage of this feature yet because its users are hard-coded. We'll fix this in a separate PR.
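In practice the flow looks roughly like this (a sketch based on the code in this commit; the attribute values are only examples, and the record would normally be created through the new admin UI rather than the console):

llm_model =
  LlmModel.create!(
    display_name: "My Claude 3 Haiku",
    name: "claude-3-haiku",
    provider: "anthropic",
    tokenizer: "DiscourseAi::Tokenizer::AnthropicTokenizer",
    max_prompt_tokens: 100_000,
  )

# Anything that accepts an LLM identifier can now reference the record as "custom:<id>",
# so its provider, tokenizer, and max_prompt_tokens come from the database instead of code.
identifier = "custom:#{llm_model.id}" # => "custom:1"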
Roman Rizzi 2024-05-13 12:46:42 -03:00 committed by GitHub
parent 5c02b885ea
commit 62fc7d6ed0
34 changed files with 656 additions and 70 deletions

@ -0,0 +1,16 @@
import DiscourseRoute from "discourse/routes/discourse";
export default DiscourseRoute.extend({
async model() {
const record = this.store.createRecord("ai-llm");
return record;
},
setupController(controller, model) {
this._super(controller, model);
controller.set(
"allLlms",
this.modelFor("adminPlugins.show.discourse-ai-llms")
);
},
});

@ -0,0 +1,17 @@
import DiscourseRoute from "discourse/routes/discourse";
export default DiscourseRoute.extend({
async model(params) {
const allLlms = this.modelFor("adminPlugins.show.discourse-ai-llms");
const id = parseInt(params.id, 10);
return allLlms.findBy("id", id);
},
setupController(controller, model) {
this._super(controller, model);
controller.set(
"allLlms",
this.modelFor("adminPlugins.show.discourse-ai-llms")
);
},
});

@ -0,0 +1,7 @@
import DiscourseRoute from "discourse/routes/discourse";
export default class DiscourseAiAiLlmsRoute extends DiscourseRoute {
model() {
return this.store.findAll("ai-llm");
}
}

@ -0,0 +1 @@
<AiLlmsListEditor @llms={{this.model}} />

@ -0,0 +1 @@
<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />

@ -0,0 +1 @@
<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />

@ -0,0 +1,64 @@
# frozen_string_literal: true
module DiscourseAi
module Admin
class AiLlmsController < ::Admin::AdminController
requires_plugin ::DiscourseAi::PLUGIN_NAME
def index
llms = LlmModel.all
render json: {
ai_llms:
ActiveModel::ArraySerializer.new(
llms,
each_serializer: LlmModelSerializer,
root: false,
).as_json,
meta: {
providers: DiscourseAi::Completions::Llm.provider_names,
tokenizers:
DiscourseAi::Completions::Llm.tokenizer_names.map { |tn|
{ id: tn, name: tn.split("::").last }
},
},
}
end
def show
llm_model = LlmModel.find(params[:id])
render json: LlmModelSerializer.new(llm_model)
end
def create
  llm_model = LlmModel.new(ai_llm_params)
  if llm_model.save
    render json: { ai_llm: llm_model }, status: :created
  else
    render_json_error llm_model
  end
end
def update
llm_model = LlmModel.find(params[:id])
if llm_model.update(ai_llm_params)
render json: llm_model
else
render_json_error llm_model
end
end
private
def ai_llm_params
params.require(:ai_llm).permit(
:display_name,
:name,
:provider,
:tokenizer,
:max_prompt_tokens,
)
end
end
end
end
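For reference, the index action above responds with the serialized models plus a meta block; the response body looks roughly like this (field names come from the controller and serializer in this commit, values are illustrative):

# GET /admin/plugins/discourse-ai/ai-llms.json
{
  ai_llms: [
    {
      id: 1,
      display_name: "My Claude 3 Haiku",
      name: "claude-3-haiku",
      provider: "anthropic",
      max_prompt_tokens: 100_000,
      tokenizer: "DiscourseAi::Tokenizer::AnthropicTokenizer",
    },
  ],
  meta: {
    providers: %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure],
    tokenizers: [{ id: "DiscourseAi::Tokenizer::OpenAiTokenizer", name: "OpenAiTokenizer" }],
  },
}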

app/models/llm_model.rb
@ -0,0 +1,7 @@
# frozen_string_literal: true
class LlmModel < ActiveRecord::Base
def tokenizer_class
tokenizer.constantize
end
end
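The tokenizer column stores a fully qualified class name, so tokenizer_class simply constantizes it. A minimal sketch of how callers can use it (the class name and text are example values; tokenizers in this plugin expose a size method for counting tokens):

llm_model = LlmModel.new(tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer")
llm_model.tokenizer_class               # => DiscourseAi::Tokenizer::OpenAiTokenizer
llm_model.tokenizer_class.size("Hello") # => number of tokens in "Hello"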

@ -0,0 +1,7 @@
# frozen_string_literal: true
class LlmModelSerializer < ApplicationSerializer
root "llm"
attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
end

@ -8,5 +8,10 @@ export default {
this.route("new"); this.route("new");
this.route("show", { path: "/:id" }); this.route("show", { path: "/:id" });
}); });
this.route("discourse-ai-llms", { path: "ai-llms" }, function () {
this.route("new");
this.route("show", { path: "/:id" });
});
}, },
}; };

@ -0,0 +1,21 @@
import RestAdapter from "discourse/adapters/rest";
export default class Adapter extends RestAdapter {
jsonMode = true;
basePath() {
return "/admin/plugins/discourse-ai/";
}
pathFor(store, type, findArgs) {
// bypass the underscore conversion done by the base implementation so the path keeps its dashes ("ai-llms", not "ai_llms")
let path =
this.basePath(store, type, findArgs) +
store.pluralize(this.apiNameFor(type));
return this.appendQueryParams(path, findArgs);
}
apiNameFor() {
return "ai-llm";
}
}

@ -0,0 +1,20 @@
import RestModel from "discourse/models/rest";
export default class AiLlm extends RestModel {
createProperties() {
return this.getProperties(
"display_name",
"name",
"provider",
"tokenizer",
"max_prompt_tokens"
);
}
updateProperties() {
const attrs = this.createProperties();
attrs.id = this.id;
return attrs;
}
}

@ -0,0 +1,122 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { Input } from "@ember/component";
import { action } from "@ember/object";
import { later } from "@ember/runloop";
import { inject as service } from "@ember/service";
import DButton from "discourse/components/d-button";
import { popupAjaxError } from "discourse/lib/ajax-error";
import i18n from "discourse-common/helpers/i18n";
import I18n from "discourse-i18n";
import ComboBox from "select-kit/components/combo-box";
import DTooltip from "float-kit/components/d-tooltip";
export default class AiLlmEditor extends Component {
@service toasts;
@service router;
@tracked isSaving = false;
get selectedProviders() {
const t = (provName) => {
return I18n.t(`discourse_ai.llms.providers.${provName}`);
};
return this.args.llms.resultSetMeta.providers.map((prov) => {
return { id: prov, name: t(prov) };
});
}
@action
async save() {
this.isSaving = true;
const isNew = this.args.model.isNew;
try {
await this.args.model.save();
if (isNew) {
this.args.llms.addObject(this.args.model);
this.router.transitionTo(
"adminPlugins.show.discourse-ai-llms.show",
this.args.model
);
} else {
this.toasts.success({
data: { message: I18n.t("discourse_ai.llms.saved") },
duration: 2000,
});
}
} catch (e) {
popupAjaxError(e);
} finally {
later(() => {
this.isSaving = false;
}, 1000);
}
}
<template>
<form class="form-horizontal ai-llm-editor">
<div class="control-group">
<label>{{i18n "discourse_ai.llms.display_name"}}</label>
<Input
class="ai-llm-editor__display-name"
@type="text"
@value={{@model.display_name}}
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.llms.name"}}</label>
<Input
class="ai-llm-editor__name"
@type="text"
@value={{@model.name}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.llms.hints.name"}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.provider"}}</label>
<ComboBox
@value={{@model.provider}}
@content={{this.selectedProviders}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
<ComboBox
@value={{@model.tokenizer}}
@content={{@llms.resultSetMeta.tokenizers}}
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.llms.max_prompt_tokens"}}</label>
<Input
@type="number"
class="ai-llm-editor__max-prompt-tokens"
step="any"
min="0"
lang="en"
@value={{@model.max_prompt_tokens}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.llms.hints.max_prompt_tokens"}}
/>
</div>
<div class="control-group ai-llm-editor__action_panel">
<DButton
class="btn-primary ai-llm-editor__save"
@action={{this.save}}
@disabled={{this.isSaving}}
>
{{I18n.t "discourse_ai.llms.save"}}
</DButton>
</div>
</form>
</template>
}

@ -0,0 +1,61 @@
import Component from "@glimmer/component";
import { LinkTo } from "@ember/routing";
import icon from "discourse-common/helpers/d-icon";
import i18n from "discourse-common/helpers/i18n";
import I18n from "discourse-i18n";
import AiLlmEditor from "./ai-llm-editor";
export default class AiLlmsListEditor extends Component {
get hasNoLLMElements() {
  return this.args.llms.length === 0;
}
<template>
<section class="ai-llms-list-editor admin-detail pull-left">
<div class="ai-llms-list-editor__header">
<h3>{{i18n "discourse_ai.llms.short_title"}}</h3>
{{#unless @currentLlm.isNew}}
<LinkTo
@route="adminPlugins.show.discourse-ai-llms.new"
class="btn btn-small btn-primary"
>
{{icon "plus"}}
<span>{{I18n.t "discourse_ai.llms.new"}}</span>
</LinkTo>
{{/unless}}
</div>
<div class="ai-llms-list-editor__container">
{{#if this.hasNoLLMElements}}
<div class="ai-llms-list-editor__empty_list">
{{icon "robot"}}
{{i18n "discourse_ai.llms.no_llms"}}
</div>
{{else}}
<div class="content-list ai-llms-list-editor__content_list">
<ul>
{{#each @llms as |llm|}}
<li>
<LinkTo
@route="adminPlugins.show.discourse-ai-llms.show"
current-when="true"
@model={{llm}}
>
{{llm.display_name}}
</LinkTo>
</li>
{{/each}}
</ul>
</div>
{{/if}}
<div class="ai-llms-list-editor__current">
{{#if @currentLlm}}
<AiLlmEditor @model={{@currentLlm}} @llms={{@llms}} />
{{/if}}
</div>
</div>
</section>
</template>
}

@ -16,6 +16,10 @@ export default {
label: "discourse_ai.ai_persona.short_title", label: "discourse_ai.ai_persona.short_title",
route: "adminPlugins.show.discourse-ai-personas", route: "adminPlugins.show.discourse-ai-personas",
}, },
{
label: "discourse_ai.llms.short_title",
route: "adminPlugins.show.discourse-ai-llms",
},
]); ]);
}); });
}, },

@ -0,0 +1,32 @@
.ai-llms-list-editor {
&__header {
display: flex;
justify-content: space-between;
align-items: center;
margin: 0 0 1em 0;
h3 {
margin: 0;
}
}
&__container {
display: flex;
flex-direction: row;
gap: 20px;
width: 100%;
align-items: stretch;
}
&__empty_list,
&__content_list {
min-width: 300px;
}
&__empty_list {
align-content: center;
text-align: center;
font-size: var(--font-up-1);
}
}

@ -194,6 +194,32 @@ en:
uploading: "Uploading..." uploading: "Uploading..."
remove: "Remove upload" remove: "Remove upload"
llms:
short_title: "LLMs"
no_llms: "No LLMs yet"
new: "New Model"
display_name: "Name to display:"
name: "Model name:"
provider: "Service hosting the model:"
tokenizer: "Tokenizer:"
max_prompt_tokens: "Number of tokens for the prompt:"
save: "Save"
saved: "LLM Model Saved"
hints:
max_prompt_tokens: "Max numbers of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window."
name: "We include this in the API call to specify which model we'll use."
providers:
aws_bedrock: "AWS Bedrock"
anthropic: "Anthropic"
vllm: "vLLM"
hugging_face: "Hugging Face"
cohere: "Cohere"
open_ai: "OpenAI"
google: "Google"
azure: "Azure"
related_topics: related_topics:
title: "Related Topics" title: "Related Topics"
pill: "Related" pill: "Related"

@ -45,5 +45,10 @@ Discourse::Application.routes.draw do
post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file" post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file"
put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file" put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file"
get "/ai-personas/:id/files/status", to: "discourse_ai/admin/ai_personas#indexing_status_check" get "/ai-personas/:id/files/status", to: "discourse_ai/admin/ai_personas#indexing_status_check"
resources :ai_llms,
only: %i[index create show update],
path: "ai-llms",
controller: "discourse_ai/admin/ai_llms"
end end
end end

@ -0,0 +1,14 @@
# frozen_string_literal: true
class CreateLlmModelTable < ActiveRecord::Migration[7.0]
def change
create_table :llm_models do |t|
t.string :display_name
t.string :name, null: false
t.string :provider, null: false
t.string :tokenizer, null: false
t.integer :max_prompt_tokens, null: false
t.timestamps
end
end
end

@ -150,59 +150,73 @@ module DiscourseAi
def self.guess_model(bot_user)
  # HACK(roman): We'll do this until we define how we represent different providers in the bot settings
  guess =
    case bot_user.id
    when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
      if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
        "aws_bedrock:claude-2"
      else
        "anthropic:claude-2"
      end
    when DiscourseAi::AiBot::EntryPoint::GPT4_ID
      "open_ai:gpt-4"
    when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
      "open_ai:gpt-4-turbo"
    when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
      "open_ai:gpt-3.5-turbo-16k"
    when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
      mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
      if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
        "vllm:#{mixtral_model}"
      elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(mixtral_model)
        "hugging_face:#{mixtral_model}"
      else
        "ollama:mistral"
      end
    when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
      "google:gemini-pro"
    when DiscourseAi::AiBot::EntryPoint::FAKE_ID
      "fake:fake"
    when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
      if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-opus")
        "aws_bedrock:claude-3-opus"
      else
        "anthropic:claude-3-opus"
      end
    when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
      "cohere:command-r-plus"
    when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
      if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-sonnet")
        "aws_bedrock:claude-3-sonnet"
      else
        "anthropic:claude-3-sonnet"
      end
    when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID
      if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-haiku")
        "aws_bedrock:claude-3-haiku"
      else
        "anthropic:claude-3-haiku"
      end
    else
      nil
    end

  if guess
    provider, model_name = guess.split(":")
    llm_model = LlmModel.find_by(provider: provider, name: model_name)

    return "custom:#{llm_model.id}" if llm_model
  end

  guess
end

def build_placeholder(summary, details, custom_raw: nil)

@ -25,6 +25,9 @@ module DiscourseAi
]

def self.translate_model(model)
  llm_model = LlmModel.find_by(name: model)
  return "custom:#{llm_model.id}" if llm_model

  return "google:#{model}" if model.start_with? "gemini"
  return "open_ai:#{model}" if model.start_with? "gpt"
  return "cohere:#{model}" if model.start_with? "command"

@ -30,6 +30,8 @@ module DiscourseAi
end

def max_prompt_tokens
  return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?

  # provide a buffer of 120 tokens - our function counting is not
  # 100% accurate and getting numbers to align exactly is very hard
  buffer = (opts[:max_tokens] || 2500) + 50

@ -50,6 +50,7 @@ module DiscourseAi
end

def max_prompt_tokens
  return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?

  # Longer term it will have over 1 million
  200_000 # Claude-3 has a 200k context window for now
end

@ -38,6 +38,8 @@ module DiscourseAi
end

def max_prompt_tokens
  return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?

  case model_name
  when "command-light"
    4096
@ -62,10 +64,6 @@ module DiscourseAi
  self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end

- def tools_dialect
-   @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
- end

def system_msg(msg)
  cmd_msg = { role: "SYSTEM", message: msg[:content] }

@ -9,14 +9,22 @@ module DiscourseAi
      raise NotImplemented
    end

    def all_dialects
      [
        DiscourseAi::Completions::Dialects::ChatGpt,
        DiscourseAi::Completions::Dialects::Gemini,
        DiscourseAi::Completions::Dialects::Mistral,
        DiscourseAi::Completions::Dialects::Claude,
        DiscourseAi::Completions::Dialects::Command,
      ]
    end

    def available_tokenizers
      all_dialects.map(&:tokenizer)
    end

    def dialect_for(model_name)
      dialects = all_dialects

      if Rails.env.test? || Rails.env.development?
        dialects << DiscourseAi::Completions::Dialects::Fake

@ -68,6 +68,8 @@ module DiscourseAi
end

def max_prompt_tokens
  return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?

  if model_name == "gemini-1.5-pro"
    # technically we support 1 million tokens, but we're being conservative
    800_000

@ -23,6 +23,8 @@ module DiscourseAi
end

def max_prompt_tokens
  return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?

  32_000
end

@ -18,6 +18,14 @@ module DiscourseAi
    UNKNOWN_MODEL = Class.new(StandardError)

    class << self
      def provider_names
        %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
      end

      def tokenizer_names
        DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq
      end

      def models_by_provider
        # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
        # However, since they use the same URL/key settings, there's no reason to duplicate them.
@ -80,36 +88,54 @@ end
      end

      def proxy(model_name)
        # We are in the process of transitioning to always use objects here.
        # We'll live with this hack for a while.
        provider_and_model_name = model_name.split(":")
        provider_name = provider_and_model_name.first
        model_name_without_prov = provider_and_model_name[1..].join

        is_custom_model = provider_name == "custom"

        if is_custom_model
          llm_model = LlmModel.find(model_name_without_prov)

          provider_name = llm_model.provider
          model_name_without_prov = llm_model.name
        end

        dialect_klass =
          DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)

        if is_custom_model
          tokenizer = llm_model.tokenizer_class
        else
          tokenizer = dialect_klass.tokenizer
        end

        if @canned_response
          if @canned_llm && @canned_llm != model_name
            raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
          end

          return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
        end

        opts = {}
        opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model

        gateway_klass =
          DiscourseAi::Completions::Endpoints::Base.endpoint_for(
            provider_name,
            model_name_without_prov,
          )

        new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
      end
    end

    def initialize(dialect_klass, gateway_klass, model_name, opts: {})
      @dialect_klass = dialect_klass
      @gateway_klass = gateway_klass
      @model_name = model_name
      @gateway = opts[:gateway]
      @max_prompt_tokens = opts[:max_prompt_tokens]
    end

    # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
@ -166,11 +192,18 @@ module DiscourseAi
      model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }

      gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
      dialect =
        dialect_klass.new(
          prompt,
          model_name,
          opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
        )
      gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
    end

    def max_prompt_tokens
      return @max_prompt_tokens if @max_prompt_tokens.present?

      dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
    end
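Put together, proxying a custom model resolves everything from the LlmModel record. A sketch of the intended usage (record attributes are examples from earlier in this commit):

llm_model = LlmModel.find_by(provider: "anthropic", name: "claude-3-haiku")

llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")
llm.max_prompt_tokens # => llm_model.max_prompt_tokens, overriding the dialect's hard-coded default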

@ -10,14 +10,22 @@ module DiscourseAi
    end

    def self.values
      begin
        llm_models =
          DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
            endpoint =
              DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)

            models.map do |model_name|
              { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
            end
          end

        LlmModel.all.each do |model|
          llm_models << { name: model.display_name, value: "custom:#{model.id}" }
        end

        llm_models
      end
    end
  end

@ -26,6 +26,8 @@ register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
register_asset "stylesheets/modules/sentiment/desktop/dashboard.scss", :desktop register_asset "stylesheets/modules/sentiment/desktop/dashboard.scss", :desktop
register_asset "stylesheets/modules/sentiment/mobile/dashboard.scss", :mobile register_asset "stylesheets/modules/sentiment/mobile/dashboard.scss", :mobile
register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
module ::DiscourseAi module ::DiscourseAi
PLUGIN_NAME = "discourse-ai" PLUGIN_NAME = "discourse-ai"
end end

@ -0,0 +1,9 @@
# frozen_string_literal: true
Fabricator(:llm_model) do
display_name "A good model"
name "gpt-4-turbo"
provider "open_ai"
tokenizer "DiscourseAi::Tokenizers::OpenAi"
max_prompt_tokens 32_000
end

@ -81,7 +81,10 @@ end
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
  subject(:endpoint) do
    described_class.new(
      "mistralai/Mistral-7B-Instruct-v0.2",
      DiscourseAi::Tokenizer::MixtralTokenizer,
    )
  end

  before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }

@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Llm do
      DiscourseAi::Completions::Dialects::Mistral,
      canned_response,
      "hugging_face:Upstage-Llama-2-*-instruct-v2",
      opts: {
        gateway: canned_response,
      },
    )
  end

@ -0,0 +1,68 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Admin::AiLlmsController do
fab!(:admin)
before { sign_in(admin) }
describe "GET #index" do
it "includes all available providers metadata" do
get "/admin/plugins/discourse-ai/ai-llms.json"
expect(response).to be_successful
expect(response.parsed_body["meta"]["providers"]).to contain_exactly(
*DiscourseAi::Completions::Llm.provider_names,
)
end
end
describe "POST #create" do
context "with valid attributes" do
let(:valid_attrs) do
{
display_name: "My cool LLM",
name: "gpt-3.5",
provider: "open_ai",
tokenizer: "DiscourseAi::Tokenizers::OpenAiTokenizer",
max_prompt_tokens: 16_000,
}
end
it "creates a new LLM model" do
post "/admin/plugins/discourse-ai/ai-llms.json", params: { ai_llm: valid_attrs }
created_model = LlmModel.last
expect(created_model.display_name).to eq(valid_attrs[:display_name])
expect(created_model.name).to eq(valid_attrs[:name])
expect(created_model.provider).to eq(valid_attrs[:provider])
expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer])
expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens])
end
end
end
describe "PUT #update" do
fab!(:llm_model)
context "with valid update params" do
let(:update_attrs) { { provider: "anthropic" } }
it "updates the model" do
put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
params: {
ai_llm: update_attrs,
}
expect(response.status).to eq(200)
expect(llm_model.reload.provider).to eq(update_attrs[:provider])
end
it "returns a 404 if there is no model with the given Id" do
put "/admin/plugins/discourse-ai/ai-llms/9999999.json"
expect(response.status).to eq(404)
end
end
end
end