FEATURE: Configurable LLMs. (#606)
This PR introduces the concept of "LlmModel" as a new way to quickly add new LLM models without making any code changes. We are releasing this first version and will add incremental improvements, so expect changes. The AI Bot can't fully take advantage of this feature as users are hard-coded. We'll fix this in a separate PR.
This commit is contained in:
parent
5c02b885ea
commit
62fc7d6ed0
|
@ -0,0 +1,16 @@
|
|||
import DiscourseRoute from "discourse/routes/discourse";
|
||||
|
||||
// Route that supplies a blank, unsaved "ai-llm" record as its model.
export default DiscourseRoute.extend({
  // Builds the new record through the store; nothing is persisted here.
  async model() {
    return this.store.createRecord("ai-llm");
  },

  // Besides the default setup, exposes the parent route's model (the full
  // list of LLMs) on the controller as `allLlms`.
  setupController(controller, model) {
    this._super(controller, model);

    const parentModel = this.modelFor("adminPlugins.show.discourse-ai-llms");
    controller.set("allLlms", parentModel);
  },
});
|
|
@ -0,0 +1,17 @@
|
|||
import DiscourseRoute from "discourse/routes/discourse";
|
||||
|
||||
// Route that resolves a single LLM from the already-loaded parent list.
export default DiscourseRoute.extend({
  // Looks the record up by numeric id in the parent route's model rather
  // than issuing another request. `params.id` comes from the URL, so it is
  // parsed as a base-10 integer before the lookup.
  async model(params) {
    const llmId = parseInt(params.id, 10);
    const parentModel = this.modelFor("adminPlugins.show.discourse-ai-llms");

    return parentModel.findBy("id", llmId);
  },

  // Besides the default setup, exposes the parent route's model (the full
  // list of LLMs) on the controller as `allLlms`.
  setupController(controller, model) {
    this._super(controller, model);

    const parentModel = this.modelFor("adminPlugins.show.discourse-ai-llms");
    controller.set("allLlms", parentModel);
  },
});
|
|
@ -0,0 +1,7 @@
|
|||
import DiscourseRoute from "discourse/routes/discourse";
|
||||
|
||||
// Parent route for the LLM admin pages; its model is the full "ai-llm" list.
export default class DiscourseAiAiLlmsRoute extends DiscourseRoute {
  // Returns whatever the store's findAll yields for "ai-llm".
  model() {
    const allLlms = this.store.findAll("ai-llm");
    return allLlms;
  }
}
|
|
@ -0,0 +1 @@
|
|||
<AiLlmsListEditor @llms={{this.model}} />
|
|
@ -0,0 +1 @@
|
|||
<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />
|
|
@ -0,0 +1 @@
|
|||
<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />
|
|
@ -0,0 +1,64 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Admin
|
||||
class AiLlmsController < ::Admin::AdminController
|
||||
requires_plugin ::DiscourseAi::PLUGIN_NAME
|
||||
|
||||
# Renders every LlmModel (serialized with LlmModelSerializer, no root key)
# under `ai_llms`, plus a `meta` section listing the known providers and
# the available tokenizers.
def index
  serialized_llms =
    ActiveModel::ArraySerializer.new(
      LlmModel.all,
      each_serializer: LlmModelSerializer,
      root: false,
    ).as_json

  tokenizer_options =
    DiscourseAi::Completions::Llm.tokenizer_names.map do |tokenizer_name|
      # The label is the last segment of the fully qualified class name.
      { id: tokenizer_name, name: tokenizer_name.split("::").last }
    end

  render json: {
           ai_llms: serialized_llms,
           meta: {
             providers: DiscourseAi::Completions::Llm.provider_names,
             tokenizers: tokenizer_options,
           },
         }
end
|
||||
|
||||
# Renders a single LlmModel through LlmModelSerializer.
# The record is looked up with `find`, so an unknown id raises there.
def show
  render json: LlmModelSerializer.new(LlmModel.find(params[:id]))
end
|
||||
|
||||
# Creates an LlmModel from the permitted request parameters.
#
# On success the saved record is rendered under the `ai_llm` key with a
# 201 status; on failure the record (carrying its errors) is passed to
# render_json_error.
def create
  # Keep a reference to the record itself. `save` only returns true/false,
  # so assigning its result would lose the model needed by both branches.
  llm_model = LlmModel.new(ai_llm_params)

  if llm_model.save
    # NOTE(review): the key mirrors the `ai_llm` key required on the way in
    # (see ai_llm_params); it previously read `ai_persona`, which looks like
    # a leftover from the personas controller — confirm against the client.
    render json: { ai_llm: llm_model }, status: :created
  else
    render_json_error llm_model
  end
end
|
||||
|
||||
# Updates an existing LlmModel with the permitted request parameters.
# Renders the record as JSON when the update succeeds; otherwise hands the
# record to render_json_error. The lookup uses `find`, so an unknown id
# raises before any update is attempted.
def update
  llm_model = LlmModel.find(params[:id])
  updated = llm_model.update(ai_llm_params)

  if updated
    render json: llm_model
  else
    render_json_error llm_model
  end
end
|
||||
|
||||
private
|
||||
|
||||
# Strong-parameter filter: requires the `ai_llm` key and permits only the
# attributes that may be set on an LlmModel.
def ai_llm_params
  permitted_attributes = %i[display_name name provider tokenizer max_prompt_tokens]

  params.require(:ai_llm).permit(*permitted_attributes)
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# Persisted description of an LLM: display name, model name, provider,
# tokenizer class name and prompt token limit.
class LlmModel < ActiveRecord::Base
  # Mirrors the NOT NULL columns of the llm_models table so an incomplete
  # record fails validation (and exposes errors) instead of raising at the
  # database level on save.
  validates :name, :provider, :tokenizer, :max_prompt_tokens, presence: true

  # Resolves the stored tokenizer class name into the constant it names.
  #
  # NOTE(review): `tokenizer` is stored from a request parameter and is
  # constantized as-is; consider restricting it to the names returned by
  # DiscourseAi::Completions::Llm.tokenizer_names.
  #
  # @return [Class] the constant named by `tokenizer`
  # @raise [NameError] when the stored name does not resolve to a constant
  def tokenizer_class
    tokenizer.constantize
  end
end
|
|
@ -0,0 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# JSON representation of an LlmModel, rooted at "llm" unless the caller
# overrides the root.
class LlmModelSerializer < ApplicationSerializer
  root "llm"

  attributes :id,
             :display_name,
             :name,
             :provider,
             :max_prompt_tokens,
             :tokenizer
end
|
|
@ -8,5 +8,10 @@ export default {
|
|||
this.route("new");
|
||||
this.route("show", { path: "/:id" });
|
||||
});
|
||||
|
||||
this.route("discourse-ai-llms", { path: "ai-llms" }, function () {
|
||||
this.route("new");
|
||||
this.route("show", { path: "/:id" });
|
||||
});
|
||||
},
|
||||
};
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
import RestAdapter from "discourse/adapters/rest";
|
||||
|
||||
// REST adapter for "ai-llm" records, rooted at the discourse-ai admin path.
export default class Adapter extends RestAdapter {
  jsonMode = true;

  // All requests are made relative to the plugin's admin namespace.
  basePath() {
    return "/admin/plugins/discourse-ai/";
  }

  // Builds `<basePath><pluralized api name>` and appends any query params.
  // Per the original author's note, this override exists to drop the
  // underscores that the base implementation introduces.
  pathFor(store, type, findArgs) {
    const prefix = this.basePath(store, type, findArgs);
    const resource = store.pluralize(this.apiNameFor(type));

    return this.appendQueryParams(`${prefix}${resource}`, findArgs);
  }

  // The API name is fixed regardless of the type passed in.
  apiNameFor() {
    return "ai-llm";
  }
}
|
|
@ -0,0 +1,20 @@
|
|||
import RestModel from "discourse/models/rest";
|
||||
|
||||
// Client-side model for an LLM record.
export default class AiLlm extends RestModel {
  // Attributes sent when the record is created.
  createProperties() {
    const fields = [
      "display_name",
      "name",
      "provider",
      "tokenizer",
      "max_prompt_tokens",
    ];

    return this.getProperties(...fields);
  }

  // Attributes sent on update: the create attributes plus the record id.
  updateProperties() {
    return { ...this.createProperties(), id: this.id };
  }
}
|
|
@ -0,0 +1,122 @@
|
|||
import Component from "@glimmer/component";
|
||||
import { tracked } from "@glimmer/tracking";
|
||||
import { Input } from "@ember/component";
|
||||
import { action } from "@ember/object";
|
||||
import { later } from "@ember/runloop";
|
||||
import { inject as service } from "@ember/service";
|
||||
import DButton from "discourse/components/d-button";
|
||||
import { popupAjaxError } from "discourse/lib/ajax-error";
|
||||
import i18n from "discourse-common/helpers/i18n";
|
||||
import I18n from "discourse-i18n";
|
||||
import ComboBox from "select-kit/components/combo-box";
|
||||
import DTooltip from "float-kit/components/d-tooltip";
|
||||
|
||||
export default class AiLlmEditor extends Component {
|
||||
@service toasts;
|
||||
@service router;
|
||||
|
||||
@tracked isSaving = false;
|
||||
|
||||
get selectedProviders() {
|
||||
const t = (provName) => {
|
||||
return I18n.t(`discourse_ai.llms.providers.${provName}`);
|
||||
};
|
||||
|
||||
return this.args.llms.resultSetMeta.providers.map((prov) => {
|
||||
return { id: prov, name: t(prov) };
|
||||
});
|
||||
}
|
||||
|
||||
@action
|
||||
async save() {
|
||||
this.isSaving = true;
|
||||
const isNew = this.args.model.isNew;
|
||||
|
||||
try {
|
||||
await this.args.model.save();
|
||||
|
||||
if (isNew) {
|
||||
this.args.llms.addObject(this.args.model);
|
||||
this.router.transitionTo(
|
||||
"adminPlugins.show.discourse-ai-llms.show",
|
||||
this.args.model
|
||||
);
|
||||
} else {
|
||||
this.toasts.success({
|
||||
data: { message: I18n.t("discourse_ai.llms.saved") },
|
||||
duration: 2000,
|
||||
});
|
||||
}
|
||||
} catch (e) {
|
||||
popupAjaxError(e);
|
||||
} finally {
|
||||
later(() => {
|
||||
this.isSaving = false;
|
||||
}, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
<template>
|
||||
<form class="form-horizontal ai-llm-editor">
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.llms.display_name"}}</label>
|
||||
<Input
|
||||
class="ai-llm-editor__display-name"
|
||||
@type="text"
|
||||
@value={{@model.display_name}}
|
||||
/>
|
||||
</div>
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.llms.name"}}</label>
|
||||
<Input
|
||||
class="ai-llm-editor__name"
|
||||
@type="text"
|
||||
@value={{@model.name}}
|
||||
/>
|
||||
<DTooltip
|
||||
@icon="question-circle"
|
||||
@content={{I18n.t "discourse_ai.llms.hints.name"}}
|
||||
/>
|
||||
</div>
|
||||
<div class="control-group">
|
||||
<label>{{I18n.t "discourse_ai.llms.provider"}}</label>
|
||||
<ComboBox
|
||||
@value={{@model.provider}}
|
||||
@content={{this.selectedProviders}}
|
||||
/>
|
||||
</div>
|
||||
<div class="control-group">
|
||||
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
|
||||
<ComboBox
|
||||
@value={{@model.tokenizer}}
|
||||
@content={{@llms.resultSetMeta.tokenizers}}
|
||||
/>
|
||||
</div>
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.llms.max_prompt_tokens"}}</label>
|
||||
<Input
|
||||
@type="number"
|
||||
class="ai-llm-editor__max-prompt-tokens"
|
||||
step="any"
|
||||
min="0"
|
||||
lang="en"
|
||||
@value={{@model.max_prompt_tokens}}
|
||||
/>
|
||||
<DTooltip
|
||||
@icon="question-circle"
|
||||
@content={{I18n.t "discourse_ai.llms.hints.max_prompt_tokens"}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="control-group ai-llm-editor__action_panel">
|
||||
<DButton
|
||||
class="btn-primary ai-llm-editor__save"
|
||||
@action={{this.save}}
|
||||
@disabled={{this.isSaving}}
|
||||
>
|
||||
{{I18n.t "discourse_ai.llms.save"}}
|
||||
</DButton>
|
||||
</div>
|
||||
</form>
|
||||
</template>
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
import Component from "@glimmer/component";
|
||||
import { LinkTo } from "@ember/routing";
|
||||
import icon from "discourse-common/helpers/d-icon";
|
||||
import i18n from "discourse-common/helpers/i18n";
|
||||
import I18n from "discourse-i18n";
|
||||
import AiLlmEditor from "./ai-llm-editor";
|
||||
|
||||
export default class AiLlmsListEditor extends Component {
|
||||
get hasNoLLMElements() {
|
||||
this.args.llms.length !== 0;
|
||||
}
|
||||
|
||||
<template>
|
||||
<section class="ai-llms-list-editor admin-detail pull-left">
|
||||
|
||||
<div class="ai-llms-list-editor__header">
|
||||
<h3>{{i18n "discourse_ai.llms.short_title"}}</h3>
|
||||
{{#unless @currentLlm.isNew}}
|
||||
<LinkTo
|
||||
@route="adminPlugins.show.discourse-ai-llms.new"
|
||||
class="btn btn-small btn-primary"
|
||||
>
|
||||
{{icon "plus"}}
|
||||
<span>{{I18n.t "discourse_ai.llms.new"}}</span>
|
||||
</LinkTo>
|
||||
{{/unless}}
|
||||
</div>
|
||||
|
||||
<div class="ai-llms-list-editor__container">
|
||||
{{#if this.hasNoLLMElements}}
|
||||
<div class="ai-llms-list-editor__empty_list">
|
||||
{{icon "robot"}}
|
||||
{{i18n "discourse_ai.llms.no_llms"}}
|
||||
</div>
|
||||
{{else}}
|
||||
<div class="content-list ai-llms-list-editor__content_list">
|
||||
<ul>
|
||||
{{#each @llms as |llm|}}
|
||||
<li>
|
||||
<LinkTo
|
||||
@route="adminPlugins.show.discourse-ai-llms.show"
|
||||
current-when="true"
|
||||
@model={{llm}}
|
||||
>
|
||||
{{llm.display_name}}
|
||||
</LinkTo>
|
||||
</li>
|
||||
{{/each}}
|
||||
</ul>
|
||||
</div>
|
||||
{{/if}}
|
||||
|
||||
<div class="ai-llms-list-editor__current">
|
||||
{{#if @currentLlm}}
|
||||
<AiLlmEditor @model={{@currentLlm}} @llms={{@llms}} />
|
||||
{{/if}}
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
}
|
|
@ -16,6 +16,10 @@ export default {
|
|||
label: "discourse_ai.ai_persona.short_title",
|
||||
route: "adminPlugins.show.discourse-ai-personas",
|
||||
},
|
||||
{
|
||||
label: "discourse_ai.llms.short_title",
|
||||
route: "adminPlugins.show.discourse-ai-llms",
|
||||
},
|
||||
]);
|
||||
});
|
||||
},
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
// Layout for the LLM list editor: a header row above a two-column container
// (list or empty state on one side, current editor on the other).
.ai-llms-list-editor {
  &__header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin: 0 0 1em 0;

    h3 {
      margin: 0;
    }
  }

  &__container {
    display: flex;
    flex-direction: row;
    gap: 20px;
    width: 100%;
    // Single align-items declaration: an earlier `center` value in this rule
    // was overridden by this one and has been removed.
    align-items: stretch;
  }

  &__empty_list,
  &__content_list {
    min-width: 300px;
  }

  &__empty_list {
    align-content: center;
    text-align: center;
    font-size: var(--font-up-1);
  }
}
|
|
@ -194,6 +194,32 @@ en:
|
|||
uploading: "Uploading..."
|
||||
remove: "Remove upload"
|
||||
|
||||
llms:
|
||||
short_title: "LLMs"
|
||||
no_llms: "No LLMs yet"
|
||||
new: "New Model"
|
||||
display_name: "Name to display:"
|
||||
name: "Model name:"
|
||||
provider: "Service hosting the model:"
|
||||
tokenizer: "Tokenizer:"
|
||||
max_prompt_tokens: "Number of tokens for the prompt:"
|
||||
save: "Save"
|
||||
saved: "LLM Model Saved"
|
||||
|
||||
hints:
|
||||
max_prompt_tokens: "Max number of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window."
|
||||
name: "We include this in the API call to specify which model we'll use."
|
||||
|
||||
providers:
|
||||
aws_bedrock: "AWS Bedrock"
|
||||
anthropic: "Anthropic"
|
||||
vllm: "vLLM"
|
||||
hugging_face: "Hugging Face"
|
||||
cohere: "Cohere"
|
||||
open_ai: "OpenAI"
|
||||
google: "Google"
|
||||
azure: "Azure"
|
||||
|
||||
related_topics:
|
||||
title: "Related Topics"
|
||||
pill: "Related"
|
||||
|
|
|
@ -45,5 +45,10 @@ Discourse::Application.routes.draw do
|
|||
post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file"
|
||||
put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file"
|
||||
get "/ai-personas/:id/files/status", to: "discourse_ai/admin/ai_personas#indexing_status_check"
|
||||
|
||||
resources :ai_llms,
|
||||
only: %i[index create show update],
|
||||
path: "ai-llms",
|
||||
controller: "discourse_ai/admin/ai_llms"
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# Creates the llm_models table: an optional display name, three required
# string columns, a required integer token limit, and timestamps.
class CreateLlmModelTable < ActiveRecord::Migration[7.0]
  def change
    create_table :llm_models do |table|
      table.string :display_name
      # The same options apply to each of the listed columns.
      table.string :name, :provider, :tokenizer, null: false
      table.integer :max_prompt_tokens, null: false
      table.timestamps
    end
  end
end
|
|
@ -150,6 +150,7 @@ module DiscourseAi
|
|||
|
||||
def self.guess_model(bot_user)
|
||||
# HACK(roman): We'll do this until we define how we represent different providers in the bot settings
|
||||
guess =
|
||||
case bot_user.id
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
|
||||
|
@ -179,7 +180,9 @@ module DiscourseAi
|
|||
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
|
||||
"fake:fake"
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-opus")
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
|
||||
"claude-3-opus",
|
||||
)
|
||||
"aws_bedrock:claude-3-opus"
|
||||
else
|
||||
"anthropic:claude-3-opus"
|
||||
|
@ -195,7 +198,9 @@ module DiscourseAi
|
|||
"anthropic:claude-3-sonnet"
|
||||
end
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-haiku")
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
|
||||
"claude-3-haiku",
|
||||
)
|
||||
"aws_bedrock:claude-3-haiku"
|
||||
else
|
||||
"anthropic:claude-3-haiku"
|
||||
|
@ -203,6 +208,15 @@ module DiscourseAi
|
|||
else
|
||||
nil
|
||||
end
|
||||
|
||||
if guess
|
||||
provider, model_name = guess.split(":")
|
||||
llm_model = LlmModel.find_by(provider: provider, name: model_name)
|
||||
|
||||
return "custom:#{llm_model.id}" if llm_model
|
||||
end
|
||||
|
||||
guess
|
||||
end
|
||||
|
||||
def build_placeholder(summary, details, custom_raw: nil)
|
||||
|
|
|
@ -25,6 +25,9 @@ module DiscourseAi
|
|||
]
|
||||
|
||||
def self.translate_model(model)
|
||||
llm_model = LlmModel.find_by(name: model)
|
||||
return "custom:#{llm_model.id}" if llm_model
|
||||
|
||||
return "google:#{model}" if model.start_with? "gemini"
|
||||
return "open_ai:#{model}" if model.start_with? "gpt"
|
||||
return "cohere:#{model}" if model.start_with? "command"
|
||||
|
|
|
@ -30,6 +30,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
|
||||
|
||||
# provide a buffer of 120 tokens - our function counting is not
|
||||
# 100% accurate and getting numbers to align exactly is very hard
|
||||
buffer = (opts[:max_tokens] || 2500) + 50
|
||||
|
|
|
@ -50,6 +50,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
# Prompt token budget: an explicitly supplied `opts[:max_prompt_tokens]`
# wins; otherwise the built-in default below is used.
def max_prompt_tokens
  configured_limit = opts[:max_prompt_tokens]
  return configured_limit if configured_limit.present?

  # Longer term it will have over 1 million
  200_000 # Claude-3 has a 200k context window for now
end
|
||||
|
|
|
@ -38,6 +38,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
|
||||
|
||||
case model_name
|
||||
when "command-light"
|
||||
4096
|
||||
|
@ -62,10 +64,6 @@ module DiscourseAi
|
|||
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
|
||||
def tools_dialect
|
||||
@tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
|
||||
end
|
||||
|
||||
def system_msg(msg)
|
||||
cmd_msg = { role: "SYSTEM", message: msg[:content] }
|
||||
|
||||
|
|
|
@ -9,14 +9,22 @@ module DiscourseAi
|
|||
raise NotImplemented
|
||||
end
|
||||
|
||||
def dialect_for(model_name)
|
||||
dialects = [
|
||||
# The dialect classes known to this module, in lookup order.
# A new array is built on every call, so callers may append to the result
# without affecting later calls.
def all_dialects
  dialects_namespace = DiscourseAi::Completions::Dialects

  [
    dialects_namespace::ChatGpt,
    dialects_namespace::Gemini,
    dialects_namespace::Mistral,
    dialects_namespace::Claude,
    dialects_namespace::Command,
  ]
end
|
||||
|
||||
# The tokenizer of every dialect in all_dialects, in the same order
# (duplicates are not removed here).
def available_tokenizers
  all_dialects.map { |dialect| dialect.tokenizer }
end
|
||||
|
||||
def dialect_for(model_name)
|
||||
dialects = all_dialects
|
||||
|
||||
if Rails.env.test? || Rails.env.development?
|
||||
dialects << DiscourseAi::Completions::Dialects::Fake
|
||||
|
|
|
@ -68,6 +68,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
|
||||
|
||||
if model_name == "gemini-1.5-pro"
|
||||
# technically we support 1 million tokens, but we're being conservative
|
||||
800_000
|
||||
|
|
|
@ -23,6 +23,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
# Prompt token budget: an explicitly supplied `opts[:max_prompt_tokens]`
# wins; otherwise falls back to 32_000.
def max_prompt_tokens
  configured_limit = opts[:max_prompt_tokens]
  return configured_limit if configured_limit.present?

  32_000
end
|
||||
|
||||
|
|
|
@ -18,6 +18,14 @@ module DiscourseAi
|
|||
UNKNOWN_MODEL = Class.new(StandardError)
|
||||
|
||||
class << self
|
||||
# Identifiers of the providers a model can be served from.
# Returns a new array of strings on each call.
def provider_names
  %w[
    aws_bedrock
    anthropic
    vllm
    hugging_face
    cohere
    open_ai
    google
    azure
  ]
end
|
||||
|
||||
# Class names of the tokenizers used by the known dialects, de-duplicated
# while keeping first-seen order.
def tokenizer_names
  tokenizer_classes = DiscourseAi::Completions::Dialects::Dialect.available_tokenizers

  tokenizer_classes.map { |tokenizer_class| tokenizer_class.name }.uniq
end
|
||||
|
||||
def models_by_provider
|
||||
# ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
|
||||
# However, since they use the same URL/key settings, there's no reason to duplicate them.
|
||||
|
@ -80,36 +88,54 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def proxy(model_name)
|
||||
# We are in the process of transitioning to always use objects here.
|
||||
# We'll live with this hack for a while.
|
||||
provider_and_model_name = model_name.split(":")
|
||||
|
||||
provider_name = provider_and_model_name.first
|
||||
model_name_without_prov = provider_and_model_name[1..].join
|
||||
is_custom_model = provider_name == "custom"
|
||||
|
||||
if is_custom_model
|
||||
llm_model = LlmModel.find(model_name_without_prov)
|
||||
provider_name = llm_model.provider
|
||||
model_name_without_prov = llm_model.name
|
||||
end
|
||||
|
||||
dialect_klass =
|
||||
DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
|
||||
|
||||
if is_custom_model
|
||||
tokenizer = llm_model.tokenizer_class
|
||||
else
|
||||
tokenizer = dialect_klass.tokenizer
|
||||
end
|
||||
|
||||
if @canned_response
|
||||
if @canned_llm && @canned_llm != model_name
|
||||
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
|
||||
end
|
||||
return new(dialect_klass, nil, model_name, gateway: @canned_response)
|
||||
return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
|
||||
end
|
||||
|
||||
opts = {}
|
||||
opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
|
||||
|
||||
gateway_klass =
|
||||
DiscourseAi::Completions::Endpoints::Base.endpoint_for(
|
||||
provider_name,
|
||||
model_name_without_prov,
|
||||
)
|
||||
|
||||
new(dialect_klass, gateway_klass, model_name_without_prov)
|
||||
new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(dialect_klass, gateway_klass, model_name, gateway: nil)
|
||||
def initialize(dialect_klass, gateway_klass, model_name, opts: {})
|
||||
@dialect_klass = dialect_klass
|
||||
@gateway_klass = gateway_klass
|
||||
@model_name = model_name
|
||||
@gateway = gateway
|
||||
@gateway = opts[:gateway]
|
||||
@max_prompt_tokens = opts[:max_prompt_tokens]
|
||||
end
|
||||
|
||||
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
|
||||
|
@ -166,11 +192,18 @@ module DiscourseAi
|
|||
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
|
||||
|
||||
gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
|
||||
dialect = dialect_klass.new(prompt, model_name, opts: model_params)
|
||||
dialect =
|
||||
dialect_klass.new(
|
||||
prompt,
|
||||
model_name,
|
||||
opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
|
||||
)
|
||||
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
|
||||
end
|
||||
|
||||
# Prompt token budget for this LLM.
# Uses the limit given at construction time when present; otherwise asks a
# dialect instance built around an empty prompt for its own limit.
def max_prompt_tokens
  return @max_prompt_tokens if @max_prompt_tokens.present?

  empty_prompt = DiscourseAi::Completions::Prompt.new("")
  dialect_klass.new(empty_prompt, model_name).max_prompt_tokens
end
|
||||
|
||||
|
|
|
@ -10,7 +10,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def self.values
|
||||
# do not cache cause settings can change this
|
||||
begin
|
||||
llm_models =
|
||||
DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
|
||||
endpoint =
|
||||
DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
|
||||
|
@ -19,6 +20,13 @@ module DiscourseAi
|
|||
{ name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
|
||||
end
|
||||
end
|
||||
|
||||
LlmModel.all.each do |model|
|
||||
llm_models << { name: model.display_name, value: "custom:#{model.id}" }
|
||||
end
|
||||
|
||||
llm_models
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -26,6 +26,8 @@ register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
|
|||
register_asset "stylesheets/modules/sentiment/desktop/dashboard.scss", :desktop
|
||||
register_asset "stylesheets/modules/sentiment/mobile/dashboard.scss", :mobile
|
||||
|
||||
register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
|
||||
|
||||
module ::DiscourseAi
|
||||
PLUGIN_NAME = "discourse-ai"
|
||||
end
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# Default LlmModel used in specs.
# The tokenizer must name a real constant because LlmModel#tokenizer_class
# constantizes it; the other tokenizers referenced in this plugin's specs
# live under DiscourseAi::Tokenizer and end in "Tokenizer".
# NOTE(review): confirm DiscourseAi::Tokenizer::OpenAiTokenizer is the
# intended class.
Fabricator(:llm_model) do
  display_name "A good model"
  name "gpt-4-turbo"
  provider "open_ai"
  tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
  max_prompt_tokens 32_000
end
|
|
@ -81,7 +81,10 @@ end
|
|||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||
subject(:endpoint) do
|
||||
described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
|
||||
described_class.new(
|
||||
"mistralai/Mistral-7B-Instruct-v0.2",
|
||||
DiscourseAi::Tokenizer::MixtralTokenizer,
|
||||
)
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
|
||||
|
|
|
@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
|||
DiscourseAi::Completions::Dialects::Mistral,
|
||||
canned_response,
|
||||
"hugging_face:Upstage-Llama-2-*-instruct-v2",
|
||||
opts: {
|
||||
gateway: canned_response,
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# Request specs for the admin LLM endpoints: index metadata, creation and
# update (including the unknown-id case).
RSpec.describe DiscourseAi::Admin::AiLlmsController do
  fab!(:admin)

  before { sign_in(admin) }

  describe "GET #index" do
    it "includes all available providers metadata" do
      get "/admin/plugins/discourse-ai/ai-llms.json"
      expect(response).to be_successful

      expect(response.parsed_body["meta"]["providers"]).to contain_exactly(
        *DiscourseAi::Completions::Llm.provider_names,
      )
    end
  end

  describe "POST #create" do
    context "with valid attributes" do
      let(:valid_attrs) do
        {
          display_name: "My cool LLM",
          name: "gpt-3.5",
          provider: "open_ai",
          # Uses the same namespace as the tokenizers referenced by the other
          # specs in this plugin (DiscourseAi::Tokenizer::*Tokenizer).
          tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer",
          max_prompt_tokens: 16_000,
        }
      end

      it "creates a new LLM model" do
        post "/admin/plugins/discourse-ai/ai-llms.json", params: { ai_llm: valid_attrs }

        # The controller answers a successful create with status :created.
        expect(response.status).to eq(201)

        created_model = LlmModel.last

        expect(created_model.display_name).to eq(valid_attrs[:display_name])
        expect(created_model.name).to eq(valid_attrs[:name])
        expect(created_model.provider).to eq(valid_attrs[:provider])
        expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer])
        expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens])
      end
    end
  end

  describe "PUT #update" do
    fab!(:llm_model)

    context "with valid update params" do
      let(:update_attrs) { { provider: "anthropic" } }

      it "updates the model" do
        put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
            params: {
              ai_llm: update_attrs,
            }

        expect(response.status).to eq(200)
        expect(llm_model.reload.provider).to eq(update_attrs[:provider])
      end

      it "returns a 404 if there is no model with the given Id" do
        put "/admin/plugins/discourse-ai/ai-llms/9999999.json"

        expect(response.status).to eq(404)
      end
    end
  end
end
|
Loading…
Reference in New Issue