Mirror of https://github.com/discourse/discourse-ai.git (synced 2025-08-03 03:43:26 +00:00)
This PR adds support for disabling further tool calls by setting tool_choice to :none across all supported LLM providers:

- OpenAI: uses the "none" tool_choice parameter
- Anthropic: uses {type: "none"} and adds a prefill message to prevent confusion
- Gemini: sets the function_calling_config mode to "NONE"
- AWS Bedrock: doesn't natively support tool disabling, so adds a prefill message

We previously disabled tool calls by simply removing the tool definitions, but this caused errors with some providers. This implementation uses the supported method for each provider and falls back to a prefill message for Bedrock.

Co-authored-by: Natalie Tay <natalie.tay@gmail.com>

* remove stray puts
* cleaner chain breaker for last tool call (works in thinking); remove unused code
* improve test

---------

Co-authored-by: Natalie Tay <natalie.tay@gmail.com>
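For context, here is a minimal sketch of how a caller might exercise the new option. The Prompt and Llm calls follow the plugin's completions API as used elsewhere in the codebase, but tool_definitions and llm_model are placeholders and the exact signatures may differ:

    # Illustrative sketch only; tool_definitions and llm_model are placeholders.
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful assistant",
        messages: [{ type: :user, content: "Summarise this topic" }],
      )
    prompt.tools = tool_definitions # tools offered earlier in the exchange
    prompt.tool_choice = :none # ask the provider not to make further tool calls

    llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")
    reply = llm.generate(prompt, user: Discourse.system_user)

Each dialect then maps :none onto its provider's native mechanism, or onto the prefill fallback described above.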
234 lines
6.7 KiB
Ruby
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Dialect
        class << self
          def can_translate?(llm_model)
            raise NotImplemented
          end

          def all_dialects
            [
              DiscourseAi::Completions::Dialects::ChatGpt,
              DiscourseAi::Completions::Dialects::Gemini,
              DiscourseAi::Completions::Dialects::Claude,
              DiscourseAi::Completions::Dialects::Command,
              DiscourseAi::Completions::Dialects::Ollama,
              DiscourseAi::Completions::Dialects::Mistral,
              DiscourseAi::Completions::Dialects::Nova,
              DiscourseAi::Completions::Dialects::OpenAiCompatible,
            ]
          end
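
          # Returns the first dialect able to translate the given model; in test
          # and development the Fake dialect gets first pick so it can be stubbed.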
          def dialect_for(llm_model)
            dialects = []

            if Rails.env.test? || Rails.env.development?
              dialects = [DiscourseAi::Completions::Dialects::Fake]
            end

            dialects = dialects.concat(all_dialects)

            dialect = dialects.find { |d| d.can_translate?(llm_model) }
            raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect

            dialect
          end
        end
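
        # generic_prompt is the provider-agnostic prompt object (messages, tools,
        # tool_choice); subclasses translate it into their provider's wire format.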
        def initialize(generic_prompt, llm_model, opts: {})
          @prompt = generic_prompt
          @opts = opts
          @llm_model = llm_model
        end

        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
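
        # Subclasses return true when the provider has first-class tool calling;
        # otherwise tools are emulated via the XML tools dialect (see tools_dialect).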
        def native_tool_support?
          false
        end

        def vision_support?
          llm_model.vision_enabled?
        end

        def tools
          @tools ||= tools_dialect.translated_tools
        end

        def tool_choice
          prompt.tool_choice
        end
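
        # Prefill/instruction text used to break the tool-call chain on providers
        # without a native way to force tool_choice :none (e.g. AWS Bedrock).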
        def self.no_more_tool_calls_text
          # note, Anthropic must never prefill with an ending whitespace
          "I WILL NOT USE TOOLS IN THIS REPLY, user expressed they wanted to stop using tool calls.\nHere is the best, complete, answer I can come up with given the information I have."
        end

        def self.no_more_tool_calls_text_user
          "DO NOT USE TOOLS IN YOUR REPLY. Return the best answer you can given the information I supplied you."
        end

        def no_more_tool_calls_text
          self.class.no_more_tool_calls_text
        end

        def no_more_tool_calls_text_user
          self.class.no_more_tool_calls_text_user
        end
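
        # Translates the generic prompt into this dialect's message format,
        # trimming history to fit within max_prompt_tokens first.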
        def translate
          messages = trim_messages(prompt.messages)
          last_message = messages.last
          inject_done_on_last_tool_call = false
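
          # When tools are emulated and the caller asked for no more tool calls,
          # flag the final tool result so the tools dialect can inject its "done"
          # marker and stop the model from calling tools again.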
          if !native_tool_support? && last_message && last_message[:type].to_sym == :tool &&
               prompt.tool_choice == :none
            inject_done_on_last_tool_call = true
          end

          translated =
            messages
              .map do |msg|
                case msg[:type].to_sym
                when :system
                  system_msg(msg)
                when :user
                  user_msg(msg)
                when :model
                  model_msg(msg)
                when :tool
                  if inject_done_on_last_tool_call && msg == last_message
                    tools_dialect.inject_done { tool_msg(msg) }
                  else
                    tool_msg(msg)
                  end
                when :tool_call
                  tool_call_msg(msg)
                else
                  raise ArgumentError, "Unknown message type: #{msg[:type]}"
                end
              end
              .compact

          translated
        end

        def conversation_context
          raise NotImplemented
        end

        def max_prompt_tokens
          raise NotImplemented
        end

        attr_reader :prompt

        private

        attr_reader :opts, :llm_model
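
        # Drops or truncates history so the prompt fits max_prompt_tokens: a leading
        # system message is kept (truncated to at most 60% of the budget), then
        # messages are retained newest-first, trimming overlong content from the
        # end until it fits.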
        def trim_messages(messages)
          prompt_limit = max_prompt_tokens
          current_token_count = 0
          message_step_size = (prompt_limit / 25).to_i * -1

          trimmed_messages = []

          range = (0..-1)
          if messages.dig(0, :type) == :system
            max_system_tokens = prompt_limit * 0.6
            system_message = messages[0]
            system_size = calculate_message_token(system_message)

            if system_size > max_system_tokens
              system_message[:content] = llm_model.tokenizer_class.truncate(
                system_message[:content],
                max_system_tokens,
              )
            end

            trimmed_messages << system_message
            current_token_count += calculate_message_token(system_message)
            range = (1..-1)
          end

          reversed_trimmed_msgs = []

          messages[range].reverse.each do |msg|
            break if current_token_count >= prompt_limit

            message_tokens = calculate_message_token(msg)

            dupped_msg = msg.dup

            # Don't trim tool call metadata.
            if msg[:type] == :tool_call
              break if current_token_count + message_tokens + per_message_overhead > prompt_limit

              current_token_count += message_tokens + per_message_overhead
              reversed_trimmed_msgs << dupped_msg
              next
            end

            # Trimming content to make sure we respect token limit.
            while dupped_msg[:content].present? &&
                    message_tokens + current_token_count + per_message_overhead > prompt_limit
              dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
              message_tokens = calculate_message_token(dupped_msg)
            end

            next if dupped_msg[:content].blank?

            current_token_count += message_tokens + per_message_overhead

            reversed_trimmed_msgs << dupped_msg
          end

          # Drop a leading tool result whose matching tool_call was trimmed away.
          reversed_trimmed_msgs.pop if reversed_trimmed_msgs.last&.dig(:type) == :tool

          trimmed_messages.concat(reversed_trimmed_msgs.reverse)
        end

        def per_message_overhead
          0
        end

        def calculate_message_token(msg)
          llm_model.tokenizer_class.size(msg[:content].to_s)
        end
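
        # Default tools implementation: renders tool definitions, calls and
        # results as XML for dialects without native tool support.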
        def tools_dialect
          @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
        end

        def system_msg(msg)
          raise NotImplemented
        end

        def model_msg(msg)
          raise NotImplemented
        end

        def user_msg(msg)
          raise NotImplemented
        end
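
        # Emulated tool traffic rides on ordinary chat turns: tool calls become
        # model messages and tool results become user messages.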
        def tool_call_msg(msg)
          new_content = tools_dialect.from_raw_tool_call(msg)
          msg = msg.merge(content: new_content)
          model_msg(msg)
        end

        def tool_msg(msg)
          new_content = tools_dialect.from_raw_tool(msg)
          msg = msg.merge(content: new_content)
          user_msg(msg)
        end
      end
    end
  end
end