FIX: improve token counting (#145)

Previously we were not counting functions correctly and were not
accounting for the minimum token count per message.

This corrects both issues and improves the internal documentation.
Sam 2023-08-22 08:36:41 +10:00 committed by GitHub
parent ea5a443588
commit 78f61914c8
2 changed files with 22 additions and 11 deletions
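
The key change: every chat message is charged a fixed framing overhead on top of its content tokens. A minimal sketch of that accounting, assuming a tokenize helper that returns an array of tokens (the constant and method names below are illustrative, not the plugin's API):

    # Sketch only: OpenAI's chat format costs roughly 4 extra tokens
    # of framing per message on the models targeted here.
    EXTRA_TOKENS_PER_MESSAGE = 4

    def total_prompt_tokens(messages, tokenize:)
      messages.sum { |raw| tokenize.call(raw.to_s).length + EXTRA_TOKENS_PER_MESSAGE }
    end

    # total_prompt_tokens(["hello", "world"], tokenize: ->(s) { s.split })
    # => (1 + 4) + (1 + 4) = 10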


@@ -182,13 +182,17 @@ module DiscourseAi
         Discourse.warn_exception(e, message: "ai-bot: Reply failed")
       end

+      def extra_tokens_per_message
+        0
+      end
+
       def bot_prompt_with_topic_context(post, prompt: "topic")
         messages = []
         conversation = conversation_context(post)

         rendered_system_prompt = system_prompt(post)

-        total_prompt_tokens = tokenize(rendered_system_prompt).length
+        total_prompt_tokens = tokenize(rendered_system_prompt).length + extra_tokens_per_message

         messages =
           conversation.reduce([]) do |memo, (raw, username, function)|
@@ -196,18 +200,20 @@ module DiscourseAi
             tokens = tokenize(raw.to_s)

-            while !raw.blank? && tokens.length + total_prompt_tokens > prompt_limit
+            while !raw.blank? &&
+                    tokens.length + total_prompt_tokens + extra_tokens_per_message > prompt_limit
               raw = raw[0..-100] || ""
               tokens = tokenize(raw.to_s)
             end

             next(memo) if raw.blank?

-            total_prompt_tokens += tokens.length
+            total_prompt_tokens += tokens.length + extra_tokens_per_message
             memo.unshift(build_message(username, raw, function: !!function))
           end

         messages.unshift(build_message(bot_user.username, rendered_system_prompt, system: true))

         messages
       end
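
The loop above trims an over-long message from the tail (raw[0..-100] drops the last 99 characters per pass) until the message fits the remaining budget, now including the per-message overhead. The same check in isolation, as a hedged sketch (the helper name is ours, not the plugin's):

    # Sketch: does this message fit once its framing overhead is charged?
    # extra is 0 for the base bot and 4 for the OpenAI bot, per this commit.
    def fits_budget?(raw, used_tokens, prompt_limit, extra, tokenize:)
      tokenize.call(raw.to_s).length + used_tokens + extra <= prompt_limit
    end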


@@ -13,16 +13,16 @@ module DiscourseAi
       end

       def prompt_limit
-        # note GPT counts both reply and request tokens in limits...
-        # also allow for an extra 500 or so spare tokens
-        #
-        # 2500 are the max reply tokens
-        # Then we have 450 or so for the full function suite
-        # 100 additional for growth around function calls
+        # note this is about 100 tokens over, OpenAI have a more optimal representation
+        @function_size ||= tokenize(available_functions.to_json).length
+
+        # provide a buffer of 50 tokens in case our counting is off
+        buffer = @function_size + reply_params[:max_tokens] + 50
+
         if bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID
-          8192 - 3050
+          8192 - buffer
         else
-          16_384 - 3050
+          16_384 - buffer
         end
       end
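
With illustrative numbers only: if the serialized function suite tokenizes to about 450 tokens (the figure the old comment cited) and max_tokens is 2500, the dynamic buffer lands close to the old hard-coded 3050:

    # Illustrative arithmetic; @function_size depends on the functions enabled.
    buffer = 450 + 2500 + 50    # => 3000
    8192 - buffer               # => 5192 prompt tokens for GPT-4
    16_384 - buffer             # => 13_384 prompt tokens for the 16k model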
@@ -32,6 +32,11 @@ module DiscourseAi
         { temperature: 0.4, top_p: 0.9, max_tokens: 2500 }
       end

+      def extra_tokens_per_message
+        # open ai defines about 4 tokens per message of overhead
+        4
+      end
+
       def submit_prompt(
         prompt,
         prefer_low_cost: false,
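
Taken together, the two files use a simple override hook: the base bot charges no per-message overhead, and the OpenAI bot overrides the hook with OpenAI's documented cost. A condensed sketch of the pattern (the real classes live under DiscourseAi::AiBot and carry far more behaviour):

    class Bot
      def extra_tokens_per_message
        0 # no framing overhead assumed by default
      end
    end

    class OpenAiBot < Bot
      def extra_tokens_per_message
        4 # OpenAI chat format: ~4 tokens of framing per message
      end
    end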