FEATURE: Configurable LLMs. (#606)

This PR introduces the concept of "LlmModel" as a new way to quickly add new LLM models without making any code changes. We are releasing this first version and will add incremental improvements, so expect changes.

The AI Bot can't fully take advantage of this feature yet because its users are hard-coded. We'll fix this in a separate PR.
This commit is contained in:
Roman Rizzi 2024-05-13 12:46:42 -03:00 committed by GitHub
parent 5c02b885ea
commit 62fc7d6ed0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 656 additions and 70 deletions

View File

@ -0,0 +1,16 @@
import DiscourseRoute from "discourse/routes/discourse";
export default DiscourseRoute.extend({
  // Hands the template a fresh "ai-llm" record created through the store.
  async model() {
    return this.store.createRecord("ai-llm");
  },

  // Besides the default setup, exposes the parent route's model to the
  // controller as `allLlms`.
  setupController(controller, model) {
    this._super(controller, model);

    const parentModel = this.modelFor("adminPlugins.show.discourse-ai-llms");
    controller.set("allLlms", parentModel);
  },
});

View File

@ -0,0 +1,17 @@
import DiscourseRoute from "discourse/routes/discourse";
export default DiscourseRoute.extend({
  // Looks up the entry of the parent route's model whose `id` equals the
  // numeric value of the `:id` URL segment.
  async model(params) {
    const parentModel = this.modelFor("adminPlugins.show.discourse-ai-llms");
    const llmId = parseInt(params.id, 10);

    return parentModel.findBy("id", llmId);
  },

  // Besides the default setup, exposes the parent route's model to the
  // controller as `allLlms`.
  setupController(controller, model) {
    this._super(controller, model);

    const parentModel = this.modelFor("adminPlugins.show.discourse-ai-llms");
    controller.set("allLlms", parentModel);
  },
});

View File

@ -0,0 +1,7 @@
import DiscourseRoute from "discourse/routes/discourse";
/**
 * Route whose model is the result of asking the store for all "ai-llm" records.
 */
export default class DiscourseAiAiLlmsRoute extends DiscourseRoute {
  model() {
    const allLlms = this.store.findAll("ai-llm");
    return allLlms;
  }
}

View File

@ -0,0 +1 @@
{{! Renders the LLM list from this route's model; no @currentLlm is passed. }}
<AiLlmsListEditor @llms={{this.model}} />

View File

@ -0,0 +1 @@
{{! Renders the LLM list from the controller's allLlms, passing this route's model as @currentLlm. }}
<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />

View File

@ -0,0 +1 @@
{{! Renders the LLM list from the controller's allLlms, passing this route's model as @currentLlm. }}
<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />

View File

@ -0,0 +1,64 @@
# frozen_string_literal: true
module DiscourseAi
  module Admin
    # JSON endpoints for listing, showing, creating and updating LlmModel records.
    class AiLlmsController < ::Admin::AdminController
      requires_plugin ::DiscourseAi::PLUGIN_NAME

      # Renders every LlmModel under `ai_llms`, plus a `meta` hash holding the
      # provider names and the tokenizer options ({ id:, name: } pairs, where
      # `name` is the last segment of the tokenizer's constant name).
      def index
        llms = LlmModel.all

        render json: {
                 ai_llms:
                   ActiveModel::ArraySerializer.new(
                     llms,
                     each_serializer: LlmModelSerializer,
                     root: false,
                   ).as_json,
                 meta: {
                   providers: DiscourseAi::Completions::Llm.provider_names,
                   tokenizers:
                     DiscourseAi::Completions::Llm.tokenizer_names.map { |tn|
                       { id: tn, name: tn.split("::").last }
                     },
                 },
               }
      end

      # Renders a single LlmModel through LlmModelSerializer.
      def show
        llm_model = LlmModel.find(params[:id])
        render json: LlmModelSerializer.new(llm_model)
      end

      # Builds an LlmModel from the permitted params and saves it.
      # Responds with 201 and the record on success, or the record's errors on failure.
      def create
        # Hold on to the record itself. Assigning the result of
        # `LlmModel.new(...).save` would store a boolean, leaving nothing to
        # render on success and no errors to report on failure.
        llm_model = LlmModel.new(ai_llm_params)

        if llm_model.save
          # NOTE(review): the `ai_persona` root key looks like a copy/paste leftover;
          # it is kept as-is so the response shape does not change - confirm what
          # the client expects before renaming it.
          render json: { ai_persona: llm_model }, status: :created
        else
          render_json_error llm_model
        end
      end

      # Updates an existing LlmModel with the permitted params.
      def update
        llm_model = LlmModel.find(params[:id])

        if llm_model.update(ai_llm_params)
          render json: llm_model
        else
          render_json_error llm_model
        end
      end

      private

      # Attributes a client is allowed to set, nested under the `ai_llm` key.
      def ai_llm_params
        params.require(:ai_llm).permit(
          :display_name,
          :name,
          :provider,
          :tokenizer,
          :max_prompt_tokens,
        )
      end
    end
  end
end

7
app/models/llm_model.rb Normal file
View File

@ -0,0 +1,7 @@
# frozen_string_literal: true
# ActiveRecord model describing a configurable LLM.
class LlmModel < ActiveRecord::Base
# Turns the `tokenizer` attribute (a constant name stored as a string) into
# the constant it names.
# NOTE(review): `constantize` resolves whatever name is stored; presumably the
# value should be limited to known tokenizer classes - confirm where that is enforced.
def tokenizer_class
tokenizer.constantize
end
end

View File

@ -0,0 +1,7 @@
# frozen_string_literal: true
# Serializes an LLM model record under the "llm" root, exposing only the
# attributes listed below.
class LlmModelSerializer < ApplicationSerializer
root "llm"
attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
end

View File

@ -8,5 +8,10 @@ export default {
this.route("new");
this.route("show", { path: "/:id" });
});
this.route("discourse-ai-llms", { path: "ai-llms" }, function () {
this.route("new");
this.route("show", { path: "/:id" });
});
},
};

View File

@ -0,0 +1,21 @@
import RestAdapter from "discourse/adapters/rest";
/**
 * REST adapter for "ai-llm" records, rooted at the plugin's admin path.
 */
export default class Adapter extends RestAdapter {
  jsonMode = true;

  basePath() {
    return "/admin/plugins/discourse-ai/";
  }

  // Builds the path from the dash-separated API name returned by
  // `apiNameFor`, pluralized by the store, then appends the query params.
  pathFor(store, type, findArgs) {
    const prefix = this.basePath(store, type, findArgs);
    const resource = store.pluralize(this.apiNameFor(type));

    return this.appendQueryParams(prefix + resource, findArgs);
  }

  apiNameFor() {
    return "ai-llm";
  }
}

View File

@ -0,0 +1,20 @@
import RestModel from "discourse/models/rest";
/**
 * Client-side model for an LLM configuration.
 */
export default class AiLlm extends RestModel {
  // Properties sent when the record is created.
  createProperties() {
    const fields = [
      "display_name",
      "name",
      "provider",
      "tokenizer",
      "max_prompt_tokens",
    ];

    return this.getProperties(...fields);
  }

  // Same properties as on create, plus the record's id.
  updateProperties() {
    return Object.assign(this.createProperties(), { id: this.id });
  }
}

View File

@ -0,0 +1,122 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { Input } from "@ember/component";
import { action } from "@ember/object";
import { later } from "@ember/runloop";
import { inject as service } from "@ember/service";
import DButton from "discourse/components/d-button";
import { popupAjaxError } from "discourse/lib/ajax-error";
import i18n from "discourse-common/helpers/i18n";
import I18n from "discourse-i18n";
import ComboBox from "select-kit/components/combo-box";
import DTooltip from "float-kit/components/d-tooltip";
export default class AiLlmEditor extends Component {
@service toasts;
@service router;
@tracked isSaving = false;
get selectedProviders() {
const t = (provName) => {
return I18n.t(`discourse_ai.llms.providers.${provName}`);
};
return this.args.llms.resultSetMeta.providers.map((prov) => {
return { id: prov, name: t(prov) };
});
}
@action
async save() {
this.isSaving = true;
const isNew = this.args.model.isNew;
try {
await this.args.model.save();
if (isNew) {
this.args.llms.addObject(this.args.model);
this.router.transitionTo(
"adminPlugins.show.discourse-ai-llms.show",
this.args.model
);
} else {
this.toasts.success({
data: { message: I18n.t("discourse_ai.llms.saved") },
duration: 2000,
});
}
} catch (e) {
popupAjaxError(e);
} finally {
later(() => {
this.isSaving = false;
}, 1000);
}
}
<template>
<form class="form-horizontal ai-llm-editor">
<div class="control-group">
<label>{{i18n "discourse_ai.llms.display_name"}}</label>
<Input
class="ai-llm-editor__display-name"
@type="text"
@value={{@model.display_name}}
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.llms.name"}}</label>
<Input
class="ai-llm-editor__name"
@type="text"
@value={{@model.name}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.llms.hints.name"}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.provider"}}</label>
<ComboBox
@value={{@model.provider}}
@content={{this.selectedProviders}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
<ComboBox
@value={{@model.tokenizer}}
@content={{@llms.resultSetMeta.tokenizers}}
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.llms.max_prompt_tokens"}}</label>
<Input
@type="number"
class="ai-llm-editor__max-prompt-tokens"
step="any"
min="0"
lang="en"
@value={{@model.max_prompt_tokens}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.llms.hints.max_prompt_tokens"}}
/>
</div>
<div class="control-group ai-llm-editor__action_panel">
<DButton
class="btn-primary ai-llm-editor__save"
@action={{this.save}}
@disabled={{this.isSaving}}
>
{{I18n.t "discourse_ai.llms.save"}}
</DButton>
</div>
</form>
</template>
}

View File

@ -0,0 +1,61 @@
import Component from "@glimmer/component";
import { LinkTo } from "@ember/routing";
import icon from "discourse-common/helpers/d-icon";
import i18n from "discourse-common/helpers/i18n";
import I18n from "discourse-i18n";
import AiLlmEditor from "./ai-llm-editor";
export default class AiLlmsListEditor extends Component {
get hasNoLLMElements() {
this.args.llms.length !== 0;
}
<template>
<section class="ai-llms-list-editor admin-detail pull-left">
<div class="ai-llms-list-editor__header">
<h3>{{i18n "discourse_ai.llms.short_title"}}</h3>
{{#unless @currentLlm.isNew}}
<LinkTo
@route="adminPlugins.show.discourse-ai-llms.new"
class="btn btn-small btn-primary"
>
{{icon "plus"}}
<span>{{I18n.t "discourse_ai.llms.new"}}</span>
</LinkTo>
{{/unless}}
</div>
<div class="ai-llms-list-editor__container">
{{#if this.hasNoLLMElements}}
<div class="ai-llms-list-editor__empty_list">
{{icon "robot"}}
{{i18n "discourse_ai.llms.no_llms"}}
</div>
{{else}}
<div class="content-list ai-llms-list-editor__content_list">
<ul>
{{#each @llms as |llm|}}
<li>
<LinkTo
@route="adminPlugins.show.discourse-ai-llms.show"
current-when="true"
@model={{llm}}
>
{{llm.display_name}}
</LinkTo>
</li>
{{/each}}
</ul>
</div>
{{/if}}
<div class="ai-llms-list-editor__current">
{{#if @currentLlm}}
<AiLlmEditor @model={{@currentLlm}} @llms={{@llms}} />
{{/if}}
</div>
</div>
</section>
</template>
}

View File

@ -16,6 +16,10 @@ export default {
label: "discourse_ai.ai_persona.short_title",
route: "adminPlugins.show.discourse-ai-personas",
},
{
label: "discourse_ai.llms.short_title",
route: "adminPlugins.show.discourse-ai-llms",
},
]);
});
},

View File

@ -0,0 +1,32 @@
// Layout for the LLM list editor: a header row above a two-column container
// (list or empty state on the left, the current editor on the right).
.ai-llms-list-editor {
  &__header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin: 0 0 1em 0;

    h3 {
      margin: 0;
    }
  }

  &__container {
    display: flex;
    flex-direction: row;
    gap: 20px;
    width: 100%;
    // Declared once: an earlier `align-items: center` in this rule was always
    // overridden by this later declaration, so `stretch` is the effective value.
    align-items: stretch;
  }

  &__empty_list,
  &__content_list {
    min-width: 300px;
  }

  &__empty_list {
    align-content: center;
    text-align: center;
    font-size: var(--font-up-1);
  }
}

View File

@ -194,6 +194,32 @@ en:
uploading: "Uploading..."
remove: "Remove upload"
llms:
short_title: "LLMs"
no_llms: "No LLMs yet"
new: "New Model"
display_name: "Name to display:"
name: "Model name:"
provider: "Service hosting the model:"
tokenizer: "Tokenizer:"
max_prompt_tokens: "Number of tokens for the prompt:"
save: "Save"
saved: "LLM Model Saved"
hints:
max_prompt_tokens: "Maximum number of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window."
name: "We include this in the API call to specify which model we'll use."
providers:
aws_bedrock: "AWS Bedrock"
anthropic: "Anthropic"
vllm: "vLLM"
hugging_face: "Hugging Face"
cohere: "Cohere"
open_ai: "OpenAI"
google: "Google"
azure: "Azure"
related_topics:
title: "Related Topics"
pill: "Related"

View File

@ -45,5 +45,10 @@ Discourse::Application.routes.draw do
post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file"
put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file"
get "/ai-personas/:id/files/status", to: "discourse_ai/admin/ai_personas#indexing_status_check"
resources :ai_llms,
only: %i[index create show update],
path: "ai-llms",
controller: "discourse_ai/admin/ai_llms"
end
end

View File

@ -0,0 +1,14 @@
# frozen_string_literal: true
# Creates the `llm_models` table (presumably the one backing the LlmModel
# class - confirm).
class CreateLlmModelTable < ActiveRecord::Migration[7.0]
def change
create_table :llm_models do |t|
# display_name is the only column declared without `null: false`.
t.string :display_name
t.string :name, null: false
t.string :provider, null: false
t.string :tokenizer, null: false
t.integer :max_prompt_tokens, null: false
t.timestamps
end
end
end

View File

@ -150,6 +150,7 @@ module DiscourseAi
def self.guess_model(bot_user)
# HACK(roman): We'll do this until we define how we represent different providers in the bot settings
guess =
case bot_user.id
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
@ -179,7 +180,9 @@ module DiscourseAi
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake"
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-opus")
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
"claude-3-opus",
)
"aws_bedrock:claude-3-opus"
else
"anthropic:claude-3-opus"
@ -195,7 +198,9 @@ module DiscourseAi
"anthropic:claude-3-sonnet"
end
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-haiku")
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
"claude-3-haiku",
)
"aws_bedrock:claude-3-haiku"
else
"anthropic:claude-3-haiku"
@ -203,6 +208,15 @@ module DiscourseAi
else
nil
end
if guess
provider, model_name = guess.split(":")
llm_model = LlmModel.find_by(provider: provider, name: model_name)
return "custom:#{llm_model.id}" if llm_model
end
guess
end
def build_placeholder(summary, details, custom_raw: nil)

View File

@ -25,6 +25,9 @@ module DiscourseAi
]
def self.translate_model(model)
llm_model = LlmModel.find_by(name: model)
return "custom:#{llm_model.id}" if llm_model
return "google:#{model}" if model.start_with? "gemini"
return "open_ai:#{model}" if model.start_with? "gpt"
return "cohere:#{model}" if model.start_with? "command"

View File

@ -30,6 +30,8 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
buffer = (opts[:max_tokens] || 2500) + 50

View File

@ -50,6 +50,7 @@ module DiscourseAi
end
# Prompt token budget: the value passed in via `opts[:max_prompt_tokens]`
# when present, otherwise a fixed default.
def max_prompt_tokens
  configured_limit = opts[:max_prompt_tokens]
  return configured_limit if configured_limit.present?

  # Claude-3 has a 200k context window for now;
  # longer term it will have over 1 million.
  200_000
end

View File

@ -38,6 +38,8 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
case model_name
when "command-light"
4096
@ -62,10 +64,6 @@ module DiscourseAi
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
end
def system_msg(msg)
cmd_msg = { role: "SYSTEM", message: msg[:content] }

View File

@ -9,14 +9,22 @@ module DiscourseAi
raise NotImplemented
end
def dialect_for(model_name)
dialects = [
# Returns the list of dialect classes, in the order given here.
def all_dialects
[
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mistral,
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command,
]
end
# Collects the `tokenizer` of every dialect in `all_dialects`
# (one entry per dialect, so duplicates are possible).
def available_tokenizers
  all_dialects.map { |dialect_klass| dialect_klass.tokenizer }
end
def dialect_for(model_name)
dialects = all_dialects
if Rails.env.test? || Rails.env.development?
dialects << DiscourseAi::Completions::Dialects::Fake

View File

@ -68,6 +68,8 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
if model_name == "gemini-1.5-pro"
# technically we support 1 million tokens, but we're being conservative
800_000

View File

@ -23,6 +23,8 @@ module DiscourseAi
end
# Prompt token budget: the value passed in via `opts[:max_prompt_tokens]`
# when present, otherwise 32_000.
def max_prompt_tokens
  configured_limit = opts[:max_prompt_tokens]
  return configured_limit if configured_limit.present?

  32_000
end

View File

@ -18,6 +18,14 @@ module DiscourseAi
UNKNOWN_MODEL = Class.new(StandardError)
class << self
# Identifiers of the supported providers, as an array of strings.
def provider_names
  %w[
    aws_bedrock
    anthropic
    vllm
    hugging_face
    cohere
    open_ai
    google
    azure
  ]
end
# Unique class names (as strings) of the tokenizers reported by the dialects.
def tokenizer_names
  tokenizers = DiscourseAi::Completions::Dialects::Dialect.available_tokenizers
  tokenizers.map { |tokenizer| tokenizer.name }.uniq
end
def models_by_provider
# ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
# However, since they use the same URL/key settings, there's no reason to duplicate them.
@ -80,36 +88,54 @@ module DiscourseAi
end
def proxy(model_name)
# We are in the process of transitioning to always use objects here.
# We'll live with this hack for a while.
provider_and_model_name = model_name.split(":")
provider_name = provider_and_model_name.first
model_name_without_prov = provider_and_model_name[1..].join
is_custom_model = provider_name == "custom"
if is_custom_model
llm_model = LlmModel.find(model_name_without_prov)
provider_name = llm_model.provider
model_name_without_prov = llm_model.name
end
dialect_klass =
DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
if is_custom_model
tokenizer = llm_model.tokenizer_class
else
tokenizer = dialect_klass.tokenizer
end
if @canned_response
if @canned_llm && @canned_llm != model_name
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
end
return new(dialect_klass, nil, model_name, gateway: @canned_response)
return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
end
opts = {}
opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
gateway_klass =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(
provider_name,
model_name_without_prov,
)
new(dialect_klass, gateway_klass, model_name_without_prov)
new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
end
end
def initialize(dialect_klass, gateway_klass, model_name, gateway: nil)
def initialize(dialect_klass, gateway_klass, model_name, opts: {})
@dialect_klass = dialect_klass
@gateway_klass = gateway_klass
@model_name = model_name
@gateway = gateway
@gateway = opts[:gateway]
@max_prompt_tokens = opts[:max_prompt_tokens]
end
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
@ -166,11 +192,18 @@ module DiscourseAi
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
dialect = dialect_klass.new(prompt, model_name, opts: model_params)
dialect =
dialect_klass.new(
prompt,
model_name,
opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
)
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end
# Prompt token budget: the limit stored in @max_prompt_tokens when present,
# otherwise whatever the dialect reports for this model.
def max_prompt_tokens
  if @max_prompt_tokens.present?
    @max_prompt_tokens
  else
    empty_prompt = DiscourseAi::Completions::Prompt.new("")
    dialect_klass.new(empty_prompt, model_name).max_prompt_tokens
  end
end

View File

@ -10,7 +10,8 @@ module DiscourseAi
end
def self.values
# do not cache cause settings can change this
begin
llm_models =
DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
endpoint =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
@ -19,6 +20,13 @@ module DiscourseAi
{ name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
end
end
LlmModel.all.each do |model|
llm_models << { name: model.display_name, value: "custom:#{model.id}" }
end
llm_models
end
end
end
end

View File

@ -26,6 +26,8 @@ register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
register_asset "stylesheets/modules/sentiment/desktop/dashboard.scss", :desktop
register_asset "stylesheets/modules/sentiment/mobile/dashboard.scss", :mobile
register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
module ::DiscourseAi
PLUGIN_NAME = "discourse-ai"
end

View File

@ -0,0 +1,9 @@
# frozen_string_literal: true
# Default LlmModel used in specs.
Fabricator(:llm_model) do
  display_name "A good model"
  name "gpt-4-turbo"
  provider "open_ai"
  # Must name a real tokenizer class, since LlmModel#tokenizer_class constantizes
  # this string. Tokenizers live under DiscourseAi::Tokenizer (singular) and
  # carry a "Tokenizer" suffix.
  tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
  max_prompt_tokens 32_000
end

View File

@ -81,7 +81,10 @@ end
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:endpoint) do
described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
described_class.new(
"mistralai/Mistral-7B-Instruct-v0.2",
DiscourseAi::Tokenizer::MixtralTokenizer,
)
end
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }

View File

@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Llm do
DiscourseAi::Completions::Dialects::Mistral,
canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2",
opts: {
gateway: canned_response,
},
)
end

View File

@ -0,0 +1,68 @@
# frozen_string_literal: true
# Request specs for the admin LLM endpoints.
RSpec.describe DiscourseAi::Admin::AiLlmsController do
  fab!(:admin)

  before { sign_in(admin) }

  describe "GET #index" do
    it "includes all available providers metadata" do
      get "/admin/plugins/discourse-ai/ai-llms.json"

      expect(response).to be_successful
      expect(response.parsed_body["meta"]["providers"]).to contain_exactly(
        *DiscourseAi::Completions::Llm.provider_names,
      )
    end
  end

  describe "POST #create" do
    context "with valid attributes" do
      let(:valid_attrs) do
        {
          display_name: "My cool LLM",
          name: "gpt-3.5",
          provider: "open_ai",
          # Tokenizer classes live under DiscourseAi::Tokenizer (singular).
          tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer",
          max_prompt_tokens: 16_000,
        }
      end

      it "creates a new LLM model" do
        post "/admin/plugins/discourse-ai/ai-llms.json", params: { ai_llm: valid_attrs }

        # Check the response too, so a failed request is not masked by
        # LlmModel.last picking up some other record.
        expect(response.status).to eq(201)

        created_model = LlmModel.last

        expect(created_model.display_name).to eq(valid_attrs[:display_name])
        expect(created_model.name).to eq(valid_attrs[:name])
        expect(created_model.provider).to eq(valid_attrs[:provider])
        expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer])
        expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens])
      end
    end
  end

  describe "PUT #update" do
    fab!(:llm_model)

    context "with valid update params" do
      let(:update_attrs) { { provider: "anthropic" } }

      it "updates the model" do
        put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
            params: {
              ai_llm: update_attrs,
            }

        expect(response.status).to eq(200)
        expect(llm_model.reload.provider).to eq(update_attrs[:provider])
      end

      it "returns a 404 if there is no model with the given Id" do
        put "/admin/plugins/discourse-ai/ai-llms/9999999.json"

        expect(response.status).to eq(404)
      end
    end
  end
end