FIX: Include provider in automation models (#446)

This commit is contained in:
Roman Rizzi 2024-01-29 18:07:29 -03:00 committed by GitHub
parent 0634b85a81
commit bae71eb047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 30 additions and 8 deletions

View File

@ -30,7 +30,7 @@ module DiscourseAi
result = nil
llm = DiscourseAi::Completions::Llm.proxy(model)
llm = DiscourseAi::Completions::Llm.proxy(translate_model(model))
result =
llm.generate(
@ -67,6 +67,17 @@ module DiscourseAi
post.topic.update!(visible: false) if hide_topic
end
end
def self.translate_model(model)
return "google:gemini-pro" if model == "gemini-pro"
return "open_ai:#{model}" if model != "claude-2"
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
"aws_bedrock:claude-2"
else
"anthropic:claude-2"
end
end
end
end
end

View File

@ -60,7 +60,7 @@ module DiscourseAi
I18n.t("discourse_automation.llm_report.title")
end
@model = model
@llm = DiscourseAi::Completions::Llm.proxy(model)
@llm = DiscourseAi::Completions::Llm.proxy(translate_model(model))
@category_ids = category_ids
@tags = tags
@allow_secure_categories = allow_secure_categories
@ -176,6 +176,17 @@ module DiscourseAi
end
end
end
def translate_model(model)
return "google:gemini-pro" if model == "gemini-pro"
return "open_ai:#{model}" if model != "claude-2"
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
"aws_bedrock:claude-2"
else
"anthropic:claude-2"
end
end
end
end
end

View File

@ -10,7 +10,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
triage(
post: post,
model: "fake:fake",
model: "gpt-4",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -24,7 +24,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "fake:fake",
model: "gpt-4",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -40,7 +40,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "fake:fake",
model: "gpt-4",
category_id: category.id,
system_prompt: "test %%POST%%",
search_for_text: "bad",
@ -55,7 +55,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage(
post: post,
model: "fake:fake",
model: "gpt-4",
system_prompt: "test %%POST%%",
search_for_text: "bad",
canned_reply: "test canned reply 123",

View File

@ -22,7 +22,7 @@ module DiscourseAi
sender_username: user.username,
receivers: ["fake@discourse.com"],
title: "test report %DATE%",
model: "fake:fake",
model: "gpt-4",
category_ids: nil,
tags: nil,
allow_secure_categories: false,
@ -48,7 +48,7 @@ module DiscourseAi
sender_username: user.username,
receivers: [receiver.username],
title: "test report",
model: "fake:fake",
model: "gpt-4",
category_ids: nil,
tags: nil,
allow_secure_categories: false,