FEATURE: Mixtral/Mistral/Haiku Automation Support (#571)

Adds new models to automation, and makes LLM output parsing more robust.
This commit is contained in:
Rafael dos Santos Silva 2024-04-11 09:50:46 -03:00 committed by GitHub
parent 23d12c8927
commit 253e0b7b39
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37 additions and 6 deletions

View File

@ -14,6 +14,9 @@ en:
gemini_pro: Gemini Pro gemini_pro: Gemini Pro
claude_3_opus: Claude 3 Opus claude_3_opus: Claude 3 Opus
claude_3_sonnet: Claude 3 Sonnet claude_3_sonnet: Claude 3 Sonnet
claude_3_haiku: Claude 3 Haiku
mixtral_8x7b_instruct_v0_1: Mixtral 8x7B Instruct V0.1
mistral_7b_instruct_v0_2: Mistral 7B Instruct V0.2
scriptables: scriptables:
llm_report: llm_report:
fields: fields:

View File

@ -47,11 +47,6 @@ if defined?(DiscourseAutomation)
search_for_text = fields["search_for_text"]["value"] search_for_text = fields["search_for_text"]["value"]
model = fields["model"]["value"] model = fields["model"]["value"]
if !%w[gpt-4 gpt-3-5-turbo claude-2].include?(model)
Rails.logger.warn("llm_triage: model #{model} is not supported")
next
end
category_id = fields.dig("category", "value") category_id = fields.dig("category", "value")
tags = fields.dig("tags", "value") tags = fields.dig("tags", "value")
hide_topic = fields.dig("hide_topic", "value") hide_topic = fields.dig("hide_topic", "value")

View File

@ -10,6 +10,15 @@ module DiscourseAi
{ id: "claude-2", name: "discourse_automation.ai_models.claude_2" }, { id: "claude-2", name: "discourse_automation.ai_models.claude_2" },
{ id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" }, { id: "claude-3-sonnet", name: "discourse_automation.ai_models.claude_3_sonnet" },
{ id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" }, { id: "claude-3-opus", name: "discourse_automation.ai_models.claude_3_opus" },
{ id: "claude-3-haiku", name: "discourse_automation.ai_models.claude_3_haiku" },
{
id: "mistralai/Mixtral-8x7B-Instruct-v0.1",
name: "discourse_automation.ai_models.mixtral_8x7b_instruct_v0_1",
},
{
id: "mistralai/Mistral-7B-Instruct-v0.2",
name: "discourse_automation.ai_models.mistral_7b_instruct_v0_2",
},
] ]
def self.translate_model(model) def self.translate_model(model)
@ -24,6 +33,14 @@ module DiscourseAi
end end
end end
if model.start_with?("mistral")
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model)
return "vllm:#{model}"
else
return "hugging_face:#{model}"
end
end
raise "Unknown model #{model}" raise "Unknown model #{model}"
end end
end end

View File

@ -43,7 +43,7 @@ module DiscourseAi
user: Discourse.system_user, user: Discourse.system_user,
) )
if result.strip == search_for_text.strip if result.present? && result.strip.downcase.include?(search_for_text)
user = User.find_by_username(canned_reply_user) if canned_reply_user.present? user = User.find_by_username(canned_reply_user) if canned_reply_user.present?
user = user || Discourse.system_user user = user || Discourse.system_user
if canned_reply.present? if canned_reply.present?

View File

@ -84,4 +84,20 @@ describe DiscourseAi::Automation::LlmTriage do
expect(reviewable.target).to eq(post) expect(reviewable.target).to eq(post)
end end
it "can handle garbled output from LLM" do
DiscourseAi::Completions::Llm.with_prepared_responses(["Bad.\n\nYo"]) do
triage(
post: post,
model: "gpt-4",
system_prompt: "test %%POST%%",
search_for_text: "bad",
flag_post: true,
)
end
reviewable = ReviewablePost.last
expect(reviewable&.target).to eq(post)
end
end end