FIX: Include provider in automation models (#446)

This commit is contained in:
Roman Rizzi 2024-01-29 18:07:29 -03:00 committed by GitHub
parent 0634b85a81
commit bae71eb047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 30 additions and 8 deletions

View File

@ -30,7 +30,7 @@ module DiscourseAi
result = nil result = nil
llm = DiscourseAi::Completions::Llm.proxy(model) llm = DiscourseAi::Completions::Llm.proxy(translate_model(model))
result = result =
llm.generate( llm.generate(
@ -67,6 +67,17 @@ module DiscourseAi
post.topic.update!(visible: false) if hide_topic post.topic.update!(visible: false) if hide_topic
end end
end end
def self.translate_model(model)
return "google:gemini-pro" if model == "gemini-pro"
return "open_ai:#{model}" if model != "claude-2"
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
"aws_bedrock:claude-2"
else
"anthropic:claude-2"
end
end
end end
end end
end end

View File

@ -60,7 +60,7 @@ module DiscourseAi
I18n.t("discourse_automation.llm_report.title") I18n.t("discourse_automation.llm_report.title")
end end
@model = model @model = model
@llm = DiscourseAi::Completions::Llm.proxy(model) @llm = DiscourseAi::Completions::Llm.proxy(translate_model(model))
@category_ids = category_ids @category_ids = category_ids
@tags = tags @tags = tags
@allow_secure_categories = allow_secure_categories @allow_secure_categories = allow_secure_categories
@ -176,6 +176,17 @@ module DiscourseAi
end end
end end
end end
def translate_model(model)
return "google:gemini-pro" if model == "gemini-pro"
return "open_ai:#{model}" if model != "claude-2"
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
"aws_bedrock:claude-2"
else
"anthropic:claude-2"
end
end
end end
end end
end end

View File

@ -10,7 +10,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["good"]) do
triage( triage(
post: post, post: post,
model: "fake:fake", model: "gpt-4",
hide_topic: true, hide_topic: true,
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
@ -24,7 +24,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage( triage(
post: post, post: post,
model: "fake:fake", model: "gpt-4",
hide_topic: true, hide_topic: true,
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
@ -40,7 +40,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage( triage(
post: post, post: post,
model: "fake:fake", model: "gpt-4",
category_id: category.id, category_id: category.id,
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
@ -55,7 +55,7 @@ describe DiscourseAi::Automation::LlmTriage do
DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["bad"]) do
triage( triage(
post: post, post: post,
model: "fake:fake", model: "gpt-4",
system_prompt: "test %%POST%%", system_prompt: "test %%POST%%",
search_for_text: "bad", search_for_text: "bad",
canned_reply: "test canned reply 123", canned_reply: "test canned reply 123",

View File

@ -22,7 +22,7 @@ module DiscourseAi
sender_username: user.username, sender_username: user.username,
receivers: ["fake@discourse.com"], receivers: ["fake@discourse.com"],
title: "test report %DATE%", title: "test report %DATE%",
model: "fake:fake", model: "gpt-4",
category_ids: nil, category_ids: nil,
tags: nil, tags: nil,
allow_secure_categories: false, allow_secure_categories: false,
@ -48,7 +48,7 @@ module DiscourseAi
sender_username: user.username, sender_username: user.username,
receivers: [receiver.username], receivers: [receiver.username],
title: "test report", title: "test report",
model: "fake:fake", model: "gpt-4",
category_ids: nil, category_ids: nil,
tags: nil, tags: nil,
allow_secure_categories: false, allow_secure_categories: false,