FEATURE: LLM Triage support for systemless models. (#757)

* FEATURE: LLM Triage support for systemless models.

This change adds support for OSS models that don't accept system messages. LlmTriage's system prompt field is no longer mandatory, and we now send the post contents in a separate user message.

* Models using Ollama can also disable system prompts
Roman Rizzi 2024-08-21 11:41:55 -03:00 committed by GitHub
parent 97fc822cb6
commit 64641b6175
12 changed files with 138 additions and 42 deletions
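
The change is easiest to see at the message level. Here is a minimal plain-Ruby sketch (illustration only, not the plugin's API) of the two request shapes, assuming a chat-style API: with system prompts enabled the model receives a system message followed by a user message carrying the post; with disable_system_prompt set, the two are folded into a single user message.

# Illustration only: the two message shapes.
system = { role: "system", content: "Classify this post." }
user = { role: "user", content: "title: Hello\nFirst post!" }

with_system_prompt = [system, user]

without_system_prompt = [
  { role: "user", content: [system[:content], user[:content]].join("\n") },
]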

View File

@@ -4,7 +4,9 @@ export default DiscourseRoute.extend({
   async model(params) {
     const allLlms = this.modelFor("adminPlugins.show.discourse-ai-llms");
     const id = parseInt(params.id, 10);
-    return allLlms.findBy("id", id);
+    const record = allLlms.findBy("id", id);
+    record.provider_params = record.provider_params || {};
+    return record;
   },

   setupController(controller, model) {

View File

@@ -117,11 +117,21 @@ module DiscourseAi
         new_url = params.dig(:ai_llm, :url)
         permitted[:url] = new_url if permit_url && new_url

-        extra_field_names = LlmModel.provider_params.dig(provider&.to_sym, :fields).to_a
-        received_prov_params = params.dig(:ai_llm, :provider_params)
-        permitted[:provider_params] = received_prov_params.slice(
-          *extra_field_names,
-        ).permit! if !extra_field_names.empty? && received_prov_params.present?
+        extra_field_names = LlmModel.provider_params.dig(provider&.to_sym)
+        if extra_field_names.present?
+          received_prov_params =
+            params.dig(:ai_llm, :provider_params)&.slice(*extra_field_names.keys)
+
+          if received_prov_params.present?
+            received_prov_params.each do |pname, value|
+              if extra_field_names[pname.to_sym] == :checkbox
+                received_prov_params[pname] = ActiveModel::Type::Boolean.new.cast(value)
+              end
+            end
+
+            permitted[:provider_params] = received_prov_params.permit!
+          end
+        end

         permitted
       end
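
The checkbox branch above relies on ActiveModel::Type::Boolean to turn form-encoded strings into real booleans before the params are permitted; for reference, its casting behavior (assuming only that activemodel is loaded):

require "active_model"

bool = ActiveModel::Type::Boolean.new
bool.cast("true")  # => true
bool.cast("f")     # => false
bool.cast("0")     # => false
bool.cast(nil)     # => nil (blank values cast to nil)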

View File

@@ -17,12 +17,20 @@ class LlmModel < ActiveRecord::Base
   def self.provider_params
     {
       aws_bedrock: {
-        url_editable: false,
-        fields: %i[access_key_id region],
+        access_key_id: :text,
+        region: :text,
       },
       open_ai: {
-        url_editable: true,
-        fields: %i[organization],
+        organization: :text,
+      },
+      hugging_face: {
+        disable_system_prompt: :checkbox,
+      },
+      vllm: {
+        disable_system_prompt: :checkbox,
+      },
+      ollama: {
+        disable_system_prompt: :checkbox,
       },
     }
   end
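
Each provider entry now maps field names to an input type (:text or :checkbox) instead of listing names under :fields, which is why the controller reads extra_field_names.keys and checks the type per field. A quick sketch of consuming such a map (hypothetical data, not plugin code):

schema = { access_key_id: :text, region: :text, disable_system_prompt: :checkbox }

schema.keys
# => [:access_key_id, :region, :disable_system_prompt]  (what :fields used to hold)

schema.select { |_name, type| type == :checkbox }.keys
# => [:disable_system_prompt]  (fields whose values need boolean casting)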

View File

@@ -7,6 +7,7 @@ import { action, computed } from "@ember/object";
 import { LinkTo } from "@ember/routing";
 import { later } from "@ember/runloop";
 import { inject as service } from "@ember/service";
+import { eq } from "truth-helpers";
 import DButton from "discourse/components/d-button";
 import DToggleSwitch from "discourse/components/d-toggle-switch";
 import Avatar from "discourse/helpers/bound-avatar-template";
@@ -52,9 +53,9 @@ export default class AiLlmEditorForm extends Component {
     return this.testRunning || this.testResult !== null;
   }

+  @computed("args.model.provider")
   get canEditURL() {
-    // Explicitly false.
-    return this.metaProviderParams.url_editable !== false;
+    return this.args.model.provider !== "aws_bedrock";
   }

   get modulesUsingModel() {
@@ -227,18 +228,24 @@ export default class AiLlmEditorForm extends Component {
         <DButton @action={{this.toggleApiKeySecret}} @icon="far-eye-slash" />
       </div>
     </div>
-    {{#each this.metaProviderParams.fields as |field|}}
-      <div class="control-group">
+    {{#each-in this.metaProviderParams as |field type|}}
+      <div class="control-group ai-llm-editor-provider-param__{{type}}">
         <label>{{I18n.t
           (concat "discourse_ai.llms.provider_fields." field)
         }}</label>
-        <Input
-          @type="text"
-          @value={{mut (get @model.provider_params field)}}
-          class="ai-llm-editor-input ai-llm-editor__{{field}}"
-        />
+        {{#if (eq type "checkbox")}}
+          <Input
+            @type={{type}}
+            @checked={{mut (get @model.provider_params field)}}
+          />
+        {{else}}
+          <Input
+            @type={{type}}
+            @value={{mut (get @model.provider_params field)}}
+          />
+        {{/if}}
       </div>
-    {{/each}}
+    {{/each-in}}
     <div class="control-group">
       <label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
       <ComboBox

View File

@@ -18,6 +18,15 @@
     width: 350px;
   }

+  .ai-llm-editor-provider-param {
+    &__checkbox {
+      display: flex;
+      align-items: flex-start;
+      flex-direction: row-reverse;
+      justify-content: left;
+    }
+  }
+
   .fk-d-tooltip__icon {
     padding-left: 0.25em;
     color: var(--primary-medium);

View File

@@ -273,6 +273,7 @@ en:
           access_key_id: "AWS Bedrock Access key ID"
           region: "AWS Bedrock Region"
           organization: "Optional OpenAI Organization ID"
+          disable_system_prompt: "Disable system message in prompts"

       related_topics:
         title: "Related Topics"

View File

@@ -4,7 +4,6 @@ en:
     llm_triage:
       title: Triage posts using AI
       description: "Triage posts using a large language model"
-      system_prompt_missing_post_placeholder: "System prompt must contain a placeholder for the post: %%POST%%"
      flagged_post: |
        <div>Response from the model:</div>
        <p>%%LLM_RESPONSE%%</p>

View File

@@ -9,17 +9,7 @@ if defined?(DiscourseAutomation)
     triggerables %i[post_created_edited]

-    field :system_prompt,
-          component: :message,
-          required: true,
-          validator: ->(input) do
-            if !input.include?("%%POST%%")
-              I18n.t(
-                "discourse_automation.scriptables.llm_triage.system_prompt_missing_post_placeholder",
-              )
-            end
-          end,
-          accepts_placeholders: true
+    field :system_prompt, component: :message, required: false
     field :search_for_text, component: :text, required: true

     field :model,
           component: :choices,

View File

@@ -21,15 +21,9 @@ module DiscourseAi
           raise ArgumentError, "llm_triage: no action specified!"
         end

-        post_template = +""
-        post_template << "title: #{post.topic.title}\n"
-        post_template << "#{post.raw}"
-
-        filled_system_prompt = system_prompt.sub("%%POST%%", post_template)
-
-        if filled_system_prompt == system_prompt
-          raise ArgumentError, "llm_triage: system_prompt does not contain %%POST%% placeholder"
-        end
+        s_prompt = system_prompt.to_s.sub("%%POST%%", "") # Backwards-compat. We no longer sub this.
+        prompt = DiscourseAi::Completions::Prompt.new(s_prompt)
+        prompt.push(type: :user, content: "title: #{post.topic.title}\n#{post.raw}")

         result = nil
@@ -37,7 +31,7 @@ module DiscourseAi
         result =
           llm.generate(
-            filled_system_prompt,
+            prompt,
             temperature: 0,
             max_tokens: 700, # ~500 words
             user: Discourse.system_user,
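
In short, the triage script no longer interpolates the post into the system text; it sends the configured system prompt and the post as two separate messages. A simplified view using plain hashes (the plugin itself uses its Prompt class):

system_prompt = "Reply with a single word. %%POST%%" # may still carry the legacy placeholder
post_title = "Is this spam?"
post_raw = "Buy cheap watches!!!"

s_prompt = system_prompt.to_s.sub("%%POST%%", "") # placeholder stripped for backwards-compat
messages = [
  { type: :system, content: s_prompt },
  { type: :user, content: "title: #{post_title}\n#{post_raw}" },
]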

View File

@@ -24,6 +24,18 @@ module DiscourseAi
           32_000
         end

+        def translate
+          translated = super
+
+          return translated unless llm_model.lookup_custom_param("disable_system_prompt")
+
+          system_and_user_msgs = translated.shift(2)
+          user_msg = system_and_user_msgs.last
+          user_msg[:content] = [system_and_user_msgs.first[:content], user_msg[:content]].join("\n")
+
+          translated.unshift(user_msg)
+        end
+
         private

         def system_msg(msg)
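
The override mutates the translated array in place: Array#shift(2) removes and returns the leading system/user pair, and the merged user message goes back on the front with unshift. A plain-Ruby rehearsal (illustration only):

translated = [
  { role: "system", content: "sys" },
  { role: "user", content: "first" },
  { role: "user", content: "second" },
]

system_msg, user_msg = translated.shift(2) # removes the first two messages
user_msg[:content] = [system_msg[:content], user_msg[:content]].join("\n")
translated.unshift(user_msg)
# => [{ role: "user", content: "sys\nfirst" }, { role: "user", content: "second" }]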

View File

@@ -0,0 +1,46 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::Dialects::OpenAiCompatible do
+  context "when system prompts are disabled" do
+    it "merges the system prompt into the first message" do
+      system_msg = "This is a system message"
+      user_msg = "user message"
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          system_msg,
+          messages: [{ type: :user, content: user_msg }],
+        )
+
+      model = Fabricate(:vllm_model, provider_params: { disable_system_prompt: true })
+
+      translated_messages = described_class.new(prompt, model).translate
+
+      expect(translated_messages.length).to eq(1)
+      expect(translated_messages).to contain_exactly(
+        { role: "user", content: [system_msg, user_msg].join("\n") },
+      )
+    end
+  end
+
+  context "when system prompts are enabled" do
+    it "includes system and user messages separately" do
+      system_msg = "This is a system message"
+      user_msg = "user message"
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          system_msg,
+          messages: [{ type: :user, content: user_msg }],
+        )
+
+      model = Fabricate(:vllm_model, provider_params: { disable_system_prompt: false })
+
+      translated_messages = described_class.new(prompt, model).translate
+
+      expect(translated_messages.length).to eq(2)
+      expect(translated_messages).to contain_exactly(
+        { role: "system", content: system_msg },
+        { role: "user", content: user_msg },
+      )
+    end
+  end
+end

View File

@@ -136,6 +136,24 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
         expect(created_model.lookup_custom_param("region")).to eq("us-east-1")
         expect(created_model.lookup_custom_param("access_key_id")).to eq("test")
       end
+
+      it "supports boolean values" do
+        post "/admin/plugins/discourse-ai/ai-llms.json",
+             params: {
+               ai_llm:
+                 valid_attrs.merge(
+                   provider: "vllm",
+                   provider_params: {
+                     disable_system_prompt: true,
+                   },
+                 ),
+             }
+
+        created_model = LlmModel.last
+
+        expect(response.status).to eq(201)
+        expect(created_model.lookup_custom_param("disable_system_prompt")).to eq(true)
+      end
     end
   end
end end