FEATURE: Claude based scanning and OpenAI retries (#243)

llm_triage already advertised Claude 2 as a supported triage model; this commit implements that support.

OpenAI rate-limits requests frequently, so this introduces exponential
backoff (3 retry attempts, sleeping 3, 9, and 27 seconds).

Also reduces the temperature of the classifiers so they behave consistently.
This commit is contained in:
Sam 2023-10-05 09:00:45 +11:00 committed by GitHub
parent 84cc369552
commit d87adcebea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 115 additions and 15 deletions

View File

@ -27,14 +27,28 @@ if defined?(DiscourseAutomation)
raise ArgumentError, "llm_triage: system_prompt does not contain %%POST%% placeholder"
end
result = nil
if model == "claude-2"
# allowing double + 10 tokens
# technically maybe just token count is fine, but this will allow for more creative bad responses
result =
DiscourseAi::Inference::AnthropicCompletions.perform!(
filled_system_prompt,
model,
temperature: 0,
max_tokens:
DiscourseAi::Tokenizer::AnthropicTokenizer.tokenize(search_for_text).length * 2 + 10,
).dig(:completion)
else
result =
DiscourseAi::Inference::OpenAiCompletions.perform!(
[{ :role => "system", "content" => filled_system_prompt }],
model,
temperature: 0.7,
top_p: 0.9,
max_tokens: 40,
temperature: 0,
max_tokens:
DiscourseAi::Tokenizer::OpenAiTokenizer.tokenize(search_for_text).length * 2 + 10,
).dig(:choices, 0, :message, :content)
end
if result.strip == search_for_text.strip
user = User.find_by_username(canned_reply_user) if canned_reply_user.present?
@ -118,7 +132,7 @@ if defined?(DiscourseAutomation)
search_for_text = fields["search_for_text"]["value"]
model = fields["model"]["value"]
if !%w[gpt-4 gpt-3-5-turbo].include?(model)
if !%w[gpt-4 gpt-3-5-turbo claude-2].include?(model)
Rails.logger.warn("llm_triage: model #{model} is not supported")
next
end

View File

@ -4,6 +4,10 @@ module ::DiscourseAi
module Inference
class OpenAiCompletions
TIMEOUT = 60
DEFAULT_RETRIES = 3
DEFAULT_RETRY_TIMEOUT_SECONDS = 3
RETRY_TIMEOUT_BACKOFF_MULTIPLIER = 3
CompletionFailed = Class.new(StandardError)
def self.perform!(
@ -13,7 +17,10 @@ module ::DiscourseAi
top_p: nil,
max_tokens: nil,
functions: nil,
user_id: nil
user_id: nil,
retries: DEFAULT_RETRIES,
retry_timeout: DEFAULT_RETRY_TIMEOUT_SECONDS,
&blk
)
log = nil
response_data = +""
@ -62,11 +69,29 @@ module ::DiscourseAi
request.body = request_body
http.request(request) do |response|
if response.code.to_i != 200
if retries > 0 && response.code.to_i == 429
sleep(retry_timeout)
retries -= 1
retry_timeout *= RETRY_TIMEOUT_BACKOFF_MULTIPLIER
return(
perform!(
messages,
model,
temperature: temperature,
top_p: top_p,
max_tokens: max_tokens,
functions: functions,
user_id: user_id,
retries: retries,
retry_timeout: retry_timeout,
&blk
)
)
elsif response.code.to_i != 200
Rails.logger.error(
"OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
)
raise CompletionFailed
raise CompletionFailed, "status: #{response.code.to_i} - body: #{response.body}"
end
log =
@ -76,7 +101,7 @@ module ::DiscourseAi
user_id: user_id,
)
if !block_given?
if !blk
response_body = response.read_body
parsed_response = JSON.parse(response_body, symbolize_names: true)
@ -121,7 +146,7 @@ module ::DiscourseAi
response_data << partial.dig(:choices, 0, :delta, :content).to_s
response_data << partial.dig(:choices, 0, :delta, :function_call).to_s
yield partial, cancel
blk.call(partial, cancel)
end
end
rescue IOError

View File

@ -26,7 +26,24 @@ describe DiscourseAutomation::LlmTriage do
expect(post.topic.reload.visible).to eq(true)
end
it "can hide topics on triage" do
it "can hide topics on triage with claude" do
stub_request(:post, "https://api.anthropic.com/v1/complete").to_return(
status: 200,
body: { completion: "bad" }.to_json,
)
triage(
post: post,
model: "claude-2",
hide_topic: true,
system_prompt: "test %%POST%%",
search_for_text: "bad",
)
expect(post.topic.reload.visible).to eq(false)
end
it "can hide topics on triage with claude" do
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: { choices: [{ message: { content: "bad" } }] }.to_json,

View File

@ -159,6 +159,50 @@ describe DiscourseAi::Inference::OpenAiCompletions do
)
end
it "supports rate limits" do
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
[
{ status: 429, body: "", headers: {} },
{ status: 429, body: "", headers: {} },
{ status: 200, body: { choices: [message: { content: "ok" }] }.to_json, headers: {} },
],
)
completions =
DiscourseAi::Inference::OpenAiCompletions.perform!(
[{ role: "user", content: "hello" }],
"gpt-3.5-turbo",
temperature: 0.5,
top_p: 0.8,
max_tokens: 700,
retries: 3,
retry_timeout: 0,
)
expect(completions.dig(:choices, 0, :message, :content)).to eq("ok")
end
it "supports will raise once rate limit is met" do
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
[
{ status: 429, body: "", headers: {} },
{ status: 429, body: "", headers: {} },
{ status: 429, body: "", headers: {} },
],
)
expect do
DiscourseAi::Inference::OpenAiCompletions.perform!(
[{ role: "user", content: "hello" }],
"gpt-3.5-turbo",
temperature: 0.5,
top_p: 0.8,
max_tokens: 700,
retries: 3,
retry_timeout: 0,
)
end.to raise_error(DiscourseAi::Inference::OpenAiCompletions::CompletionFailed)
end
it "can complete a trivial prompt" do
response_text = "1. Serenity\\n2. Laughter\\n3. Adventure"
prompt = [role: "user", content: "write 3 words"]