diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index b0893c91..a4348df3 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -89,6 +89,9 @@ en: max_post_tokens: label: "Max Post Tokens" description: "The maximum number of tokens to scan using LLM triage" + stop_sequences: + label: "Stop Sequences" + description: "Instruct the model to halt token generation when arriving at one of these values" search_for_text: label: "Search for text" description: "If the following text appears in the LLM reply, apply these actions" diff --git a/discourse_automation/llm_triage.rb b/discourse_automation/llm_triage.rb index 8990c8ed..a0bd21ff 100644 --- a/discourse_automation/llm_triage.rb +++ b/discourse_automation/llm_triage.rb @@ -12,6 +12,7 @@ if defined?(DiscourseAutomation) field :system_prompt, component: :message, required: false field :search_for_text, component: :text, required: true field :max_post_tokens, component: :text + field :stop_sequences, component: :text_list, required: false field :model, component: :choices, required: true, @@ -55,6 +56,8 @@ if defined?(DiscourseAutomation) max_post_tokens = nil if max_post_tokens <= 0 + stop_sequences = fields.dig("stop_sequences", "value") + if post.topic.private_message? include_personal_messages = fields.dig("include_personal_messages", "value") next if !include_personal_messages @@ -88,6 +91,7 @@ if defined?(DiscourseAutomation) flag_post: flag_post, flag_type: flag_type.to_s.to_sym, max_post_tokens: max_post_tokens, + stop_sequences: stop_sequences, automation: self.automation, ) rescue => e diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb index 5c58b9d2..d8bffa73 100644 --- a/lib/automation/llm_triage.rb +++ b/lib/automation/llm_triage.rb @@ -16,7 +16,8 @@ module DiscourseAi flag_post: nil, flag_type: nil, automation: nil, - max_post_tokens: nil + max_post_tokens: nil, + stop_sequences: nil ) if category_id.blank? && tags.blank? && canned_reply.blank? && hide_topic.blank? && flag_post.blank? @@ -42,6 +43,7 @@ module DiscourseAi temperature: 0, max_tokens: 700, # ~500 words user: Discourse.system_user, + stop_sequences: stop_sequences, feature_name: "llm_triage", feature_context: { automation_id: automation&.id, diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index c62e0bdd..27934818 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -17,7 +17,7 @@ module DiscourseAi model_params end - attr_reader :responses, :completions, :dialect + attr_reader :responses, :completions, :dialect, :model_params def prompt_messages dialect.prompt.messages @@ -26,12 +26,13 @@ module DiscourseAi def perform_completion!( dialect, _user, - _model_params, + model_params, feature_name: nil, feature_context: nil, partial_tool_calls: false ) @dialect = dialect + @model_params = model_params response = responses[completions] if response.nil? raise CANNED_RESPONSE_ERROR, diff --git a/spec/lib/modules/automation/llm_triage_spec.rb b/spec/lib/modules/automation/llm_triage_spec.rb index a966cfb5..7e434603 100644 --- a/spec/lib/modules/automation/llm_triage_spec.rb +++ b/spec/lib/modules/automation/llm_triage_spec.rb @@ -180,4 +180,22 @@ describe DiscourseAi::Automation::LlmTriage do expect(triage_prompt.messages.last[:upload_ids]).to contain_exactly(post_upload.id) end end + + it "includes stop_sequences in the completion call" do + sequences = %w[GOOD BAD] + + DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do |spy| + triage( + post: post, + model: "custom:#{llm_model.id}", + system_prompt: "test %%POST%%", + search_for_text: "bad", + flag_post: true, + automation: nil, + stop_sequences: sequences, + ) + + expect(spy.model_params[:stop_sequences]).to contain_exactly(*sequences) + end + end end