From 253e0b7b39c5217e6e278f34f63a18cc610ce811 Mon Sep 17 00:00:00 2001 From: Rafael dos Santos Silva Date: Thu, 11 Apr 2024 09:50:46 -0300 Subject: [PATCH] FEATURE: Mixtral/Mistral/Haiku Automation Support (#571) Adds new models to automation, and makes LLM output parsing more robust. --- config/locales/client.en.yml | 3 +++ discourse_automation/llm_triage.rb | 5 ----- lib/automation.rb | 17 +++++++++++++++++ lib/automation/llm_triage.rb | 2 +- spec/lib/modules/automation/llm_triage_spec.rb | 16 ++++++++++++++++ 5 files changed, 37 insertions(+), 6 deletions(-) diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index bad17fc3..967f4476 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -14,6 +14,9 @@ en: gemini_pro: Gemini Pro claude_3_opus: Claude 3 Opus claude_3_sonnet: Claude 3 Sonnet + claude_3_haiku: Claude 3 Haiku + mixtral_8x7b_instruct_v0_1: Mixtral 8x7B Instruct V0.1 + mistral_7b_instruct_v0_2: Mistral 7B Instruct V0.2 scriptables: llm_report: fields: diff --git a/discourse_automation/llm_triage.rb b/discourse_automation/llm_triage.rb index 6b54f502..b0b58629 100644 --- a/discourse_automation/llm_triage.rb +++ b/discourse_automation/llm_triage.rb @@ -47,11 +47,6 @@ if defined?(DiscourseAutomation) search_for_text = fields["search_for_text"]["value"] model = fields["model"]["value"] - if !%w[gpt-4 gpt-3-5-turbo claude-2].include?(model) - Rails.logger.warn("llm_triage: model #{model} is not supported") - next - end - category_id = fields.dig("category", "value") tags = fields.dig("tags", "value") hide_topic = fields.dig("hide_topic", "value") diff --git a/lib/automation.rb b/lib/automation.rb index 02efdce7..15426ef8 100644 --- a/lib/automation.rb +++ b/lib/automation.rb @@ -10,6 +10,15 @@ module DiscourseAi { id: "claude-2", name: "discourse_automation.ai_models.claude_2" }, { id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" }, { id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" }, + { id: "claude-3-haiku", name: "discourse_automation.ai_models.claude_3_haiku" }, + { + id: "mistralai/Mixtral-8x7B-Instruct-v0.1", + name: "discourse_automation.ai_models.mixtral_8x7b_instruct_v0_1", + }, + { + id: "mistralai/Mistral-7B-Instruct-v0.2", + name: "discourse_automation.ai_models.mistral_7b_instruct_v0_2", + }, ] def self.translate_model(model) @@ -24,6 +33,14 @@ module DiscourseAi end end + if model.start_with?("mistral") + if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model) + return "vllm:#{model}" + else + return "hugging_face:#{model}" + end + end + raise "Unknown model #{model}" end end diff --git a/lib/automation/llm_triage.rb b/lib/automation/llm_triage.rb index 065fa837..4bdcc257 100644 --- a/lib/automation/llm_triage.rb +++ b/lib/automation/llm_triage.rb @@ -43,7 +43,7 @@ module DiscourseAi user: Discourse.system_user, ) - if result.strip == search_for_text.strip + if result.present? && result.strip.downcase.include?(search_for_text) user = User.find_by_username(canned_reply_user) if canned_reply_user.present? user = user || Discourse.system_user if canned_reply.present? diff --git a/spec/lib/modules/automation/llm_triage_spec.rb b/spec/lib/modules/automation/llm_triage_spec.rb index 1f981ae2..619b1ad2 100644 --- a/spec/lib/modules/automation/llm_triage_spec.rb +++ b/spec/lib/modules/automation/llm_triage_spec.rb @@ -84,4 +84,20 @@ describe DiscourseAi::Automation::LlmTriage do expect(reviewable.target).to eq(post) end + + it "can handle garbled output from LLM" do + DiscourseAi::Completions::Llm.with_prepared_responses(["Bad.\n\nYo"]) do + triage( + post: post, + model: "gpt-4", + system_prompt: "test %%POST%%", + search_for_text: "bad", + flag_post: true, + ) + end + + reviewable = ReviewablePost.last + + expect(reviewable&.target).to eq(post) + end end