FEATURE: LLM Triage support for systemless models. (#757)

* FEATURE: LLM Triage support for systemless models.

This change adds support for OSS models that don't accept system messages. LlmTriage's system prompt field is no longer mandatory; the post contents are now sent in a separate user message.
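As a rough sketch of the new flow (paraphrasing the triage hunk further down; `post` and `llm` are whatever the automation script already has in scope), the post is pushed as its own user message instead of being interpolated into the system prompt:

    # The system prompt is now optional; a legacy %%POST%% placeholder is simply stripped.
    s_prompt = system_prompt.to_s.sub("%%POST%%", "")
    prompt = DiscourseAi::Completions::Prompt.new(s_prompt)
    # Post contents travel in a separate user message.
    prompt.push(type: :user, content: "title: #{post.topic.title}\n#{post.raw}")
    result = llm.generate(prompt, temperature: 0, max_tokens: 700, user: Discourse.system_user)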

* Models using Ollama can also disable system prompts
Roman Rizzi 2024-08-21 11:41:55 -03:00 committed by GitHub
parent 97fc822cb6
commit 64641b6175
12 changed files with 138 additions and 42 deletions


@@ -4,7 +4,9 @@ export default DiscourseRoute.extend({
async model(params) {
const allLlms = this.modelFor("adminPlugins.show.discourse-ai-llms");
const id = parseInt(params.id, 10);
return allLlms.findBy("id", id);
const record = allLlms.findBy("id", id);
record.provider_params = record.provider_params || {};
return record;
},
setupController(controller, model) {


@@ -117,11 +117,21 @@ module DiscourseAi
new_url = params.dig(:ai_llm, :url)
permitted[:url] = new_url if permit_url && new_url
extra_field_names = LlmModel.provider_params.dig(provider&.to_sym, :fields).to_a
received_prov_params = params.dig(:ai_llm, :provider_params)
permitted[:provider_params] = received_prov_params.slice(
*extra_field_names,
).permit! if !extra_field_names.empty? && received_prov_params.present?
extra_field_names = LlmModel.provider_params.dig(provider&.to_sym)
if extra_field_names.present?
received_prov_params =
params.dig(:ai_llm, :provider_params)&.slice(*extra_field_names.keys)
if received_prov_params.present?
received_prov_params.each do |pname, value|
if extra_field_names[pname.to_sym] == :checkbox
received_prov_params[pname] = ActiveModel::Type::Boolean.new.cast(value)
end
end
permitted[:provider_params] = received_prov_params.permit!
end
end
permitted
end
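Checkbox values may arrive from the admin form as strings, so they are normalized with ActiveModel::Type::Boolean before being persisted. For reference, the cast behaves like this (standalone; only needs the activemodel gem):

    require "active_model"

    caster = ActiveModel::Type::Boolean.new
    caster.cast("true") # => true
    caster.cast(true)   # => true
    caster.cast("f")    # => false
    caster.cast("0")    # => false
    caster.cast(nil)    # => nil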


@@ -17,12 +17,20 @@ class LlmModel < ActiveRecord::Base
def self.provider_params
{
aws_bedrock: {
url_editable: false,
fields: %i[access_key_id region],
access_key_id: :text,
region: :text,
},
open_ai: {
url_editable: true,
fields: %i[organization],
organization: :text,
},
hugging_face: {
disable_system_prompt: :checkbox,
},
vllm: {
disable_system_prompt: :checkbox,
},
ollama: {
disable_system_prompt: :checkbox,
},
}
end
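Each provider entry now maps a param name to the input type used to render and cast it, rather than a flat fields: list plus url_editable. A hypothetical lookup against the hash above, with lookup_custom_param (used in the controller spec below) reading the value saved on a model record:

    LlmModel.provider_params.dig(:ollama)
    # => { disable_system_prompt: :checkbox }

    # assuming `model` is an LlmModel saved with provider_params: { disable_system_prompt: true }
    model.lookup_custom_param("disable_system_prompt")
    # => true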


@@ -7,6 +7,7 @@ import { action, computed } from "@ember/object";
import { LinkTo } from "@ember/routing";
import { later } from "@ember/runloop";
import { inject as service } from "@ember/service";
import { eq } from "truth-helpers";
import DButton from "discourse/components/d-button";
import DToggleSwitch from "discourse/components/d-toggle-switch";
import Avatar from "discourse/helpers/bound-avatar-template";
@@ -52,9 +53,9 @@ export default class AiLlmEditorForm extends Component {
return this.testRunning || this.testResult !== null;
}
@computed("args.model.provider")
get canEditURL() {
// Explicitly false.
return this.metaProviderParams.url_editable !== false;
return this.args.model.provider !== "aws_bedrock";
}
get modulesUsingModel() {
@@ -227,18 +228,24 @@
<DButton @action={{this.toggleApiKeySecret}} @icon="far-eye-slash" />
</div>
</div>
{{#each this.metaProviderParams.fields as |field|}}
<div class="control-group">
{{#each-in this.metaProviderParams as |field type|}}
<div class="control-group ai-llm-editor-provider-param__{{type}}">
<label>{{I18n.t
(concat "discourse_ai.llms.provider_fields." field)
}}</label>
<Input
@type="text"
@value={{mut (get @model.provider_params field)}}
class="ai-llm-editor-input ai-llm-editor__{{field}}"
/>
{{#if (eq type "checkbox")}}
<Input
@type={{type}}
@checked={{mut (get @model.provider_params field)}}
/>
{{else}}
<Input
@type={{type}}
@value={{mut (get @model.provider_params field)}}
/>
{{/if}}
</div>
{{/each}}
{{/each-in}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
<ComboBox


@@ -18,6 +18,15 @@
width: 350px;
}
.ai-llm-editor-provider-param {
&__checkbox {
display: flex;
align-items: flex-start;
flex-direction: row-reverse;
justify-content: left;
}
}
.fk-d-tooltip__icon {
padding-left: 0.25em;
color: var(--primary-medium);


@@ -273,6 +273,7 @@ en:
access_key_id: "AWS Bedrock Access key ID"
region: "AWS Bedrock Region"
organization: "Optional OpenAI Organization ID"
disable_system_prompt: "Disable system message in prompts"
related_topics:
title: "Related Topics"


@@ -4,7 +4,6 @@ en:
llm_triage:
title: Triage posts using AI
description: "Triage posts using a large language model"
system_prompt_missing_post_placeholder: "System prompt must contain a placeholder for the post: %%POST%%"
flagged_post: |
<div>Response from the model:</div>
<p>%%LLM_RESPONSE%%</p>


@@ -9,17 +9,7 @@ if defined?(DiscourseAutomation)
triggerables %i[post_created_edited]
field :system_prompt,
component: :message,
required: true,
validator: ->(input) do
if !input.include?("%%POST%%")
I18n.t(
"discourse_automation.scriptables.llm_triage.system_prompt_missing_post_placeholder",
)
end
end,
accepts_placeholders: true
field :system_prompt, component: :message, required: false
field :search_for_text, component: :text, required: true
field :model,
component: :choices,


@@ -21,15 +21,9 @@ module DiscourseAi
raise ArgumentError, "llm_triage: no action specified!"
end
post_template = +""
post_template << "title: #{post.topic.title}\n"
post_template << "#{post.raw}"
filled_system_prompt = system_prompt.sub("%%POST%%", post_template)
if filled_system_prompt == system_prompt
raise ArgumentError, "llm_triage: system_prompt does not contain %%POST%% placeholder"
end
s_prompt = system_prompt.to_s.sub("%%POST%%", "") # Backwards-compat. We no longer sub this.
prompt = DiscourseAi::Completions::Prompt.new(s_prompt)
prompt.push(type: :user, content: "title: #{post.topic.title}\n#{post.raw}")
result = nil
@@ -37,7 +31,7 @@ module DiscourseAi
result =
llm.generate(
filled_system_prompt,
prompt,
temperature: 0,
max_tokens: 700, # ~500 words
user: Discourse.system_user,


@@ -24,6 +24,18 @@ module DiscourseAi
32_000
end
def translate
translated = super
return translated unless llm_model.lookup_custom_param("disable_system_prompt")
system_and_user_msgs = translated.shift(2)
user_msg = system_and_user_msgs.last
user_msg[:content] = [system_and_user_msgs.first[:content], user_msg[:content]].join("\n")
translated.unshift(user_msg)
end
private
def system_msg(msg)
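When disable_system_prompt is set, the net effect of the translate override above is that the leading system/user pair collapses into a single user message. A sketch of the merge it performs, using the same example contents as the spec below:

    system_msg = { role: "system", content: "This is a system message" }
    user_msg   = { role: "user", content: "user message" }

    merged = { role: "user", content: [system_msg[:content], user_msg[:content]].join("\n") }
    # => { role: "user", content: "This is a system message\nuser message" }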


@@ -0,0 +1,46 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
context "when system prompts are disabled" do
it "merges the system prompt into the first message" do
system_msg = "This is a system message"
user_msg = "user message"
prompt =
DiscourseAi::Completions::Prompt.new(
system_msg,
messages: [{ type: :user, content: user_msg }],
)
model = Fabricate(:vllm_model, provider_params: { disable_system_prompt: true })
translated_messages = described_class.new(prompt, model).translate
expect(translated_messages.length).to eq(1)
expect(translated_messages).to contain_exactly(
{ role: "user", content: [system_msg, user_msg].join("\n") },
)
end
end
context "when system prompts are enabled" do
it "includes system and user messages separately" do
system_msg = "This is a system message"
user_msg = "user message"
prompt =
DiscourseAi::Completions::Prompt.new(
system_msg,
messages: [{ type: :user, content: user_msg }],
)
model = Fabricate(:vllm_model, provider_params: { disable_system_prompt: false })
translated_messages = described_class.new(prompt, model).translate
expect(translated_messages.length).to eq(2)
expect(translated_messages).to contain_exactly(
{ role: "system", content: system_msg },
{ role: "user", content: user_msg },
)
end
end
end


@@ -136,6 +136,24 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
expect(created_model.lookup_custom_param("region")).to eq("us-east-1")
expect(created_model.lookup_custom_param("access_key_id")).to eq("test")
end
it "supports boolean values" do
post "/admin/plugins/discourse-ai/ai-llms.json",
params: {
ai_llm:
valid_attrs.merge(
provider: "vllm",
provider_params: {
disable_system_prompt: true,
},
),
}
created_model = LlmModel.last
expect(response.status).to eq(201)
expect(created_model.lookup_custom_param("disable_system_prompt")).to eq(true)
end
end
end