Sam 545500b329
FEATURE: allows forced LLM tool use (#818)
* FEATURE: allows forced LLM tool use

Sometimes we need to force LLMs to use tools; for example, in RAG-like
use cases we may want to force an unconditional search.

The new framework allows your backend to force tool usage.

Front end commit to follow

* UI for forcing tools now works, but it does not react right

* fix bugs

* fix tests, this is now ready for review
2024-10-05 09:46:57 +10:00

106 lines
3.2 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
  module Completions
    # An LLM prompt: an ordered list of message hashes plus optional tool
    # definitions and a tool_choice setting.
    #
    # Every message is shape-checked on entry (#validate_message) and every
    # pair of consecutive messages is checked for a legal turn order
    # (#validate_turn), both at construction time and on #push.
    class Prompt
      # Raised when a message may not follow the message before it
      # (e.g. :model directly after :model, or :tool without a :tool_call).
      INVALID_TURN = Class.new(StandardError)

      # Every type a message may have.
      MESSAGE_TYPES = %i[system user model tool tool_call].freeze
      # Every key a message hash may carry.
      MESSAGE_KEYS = %i[type content id name upload_ids].freeze
      # Types allowed for a message that follows another message
      # (a :system message may only come first).
      TURN_TYPES = %i[tool tool_call model user].freeze
      private_constant :MESSAGE_TYPES, :MESSAGE_KEYS, :TURN_TYPES

      attr_reader :messages
      attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice

      # @param system_message_text [String, nil] when given, becomes the first
      #   message, with type :system
      # @param messages [Array<Hash>] initial messages, appended after the
      #   system message (the hashes are stored as-is, not copied)
      # @param tools [Array] tool definitions; stored as-is
      # @param topic_id stored as-is
      # @param post_id stored as-is
      # @param max_pixels [Integer, nil] passed to UploadEncoder.encode by
      #   #encoded_uploads; defaults to 1_048_576
      # @param tool_choice stored as-is; NOTE(review): used to force tool use —
      #   confirm the expected value format with the code that reads it
      # @raise [ArgumentError] if messages or tools is not an Array, or a
      #   message is malformed
      # @raise [INVALID_TURN] if two consecutive messages are in an illegal order
      def initialize(
        system_message_text = nil,
        messages: [],
        tools: [],
        topic_id: nil,
        post_id: nil,
        max_pixels: nil,
        tool_choice: nil
      )
        raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
        raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)

        @max_pixels = max_pixels || 1_048_576
        @topic_id = topic_id
        @post_id = post_id

        @messages = []
        @messages << { type: :system, content: system_message_text } if system_message_text
        @messages.concat(messages)

        @messages.each { |message| validate_message(message) }
        @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }

        @tools = tools
        @tool_choice = tool_choice
      end

      # Appends a message after validating it and its order relative to the
      # current last message. :system messages are ignored (nothing is added).
      #
      # name and id are converted to strings; upload_ids is stored only when given.
      #
      # @return [Array<Hash>, nil] the messages array, or nil for a :system push
      # @raise [ArgumentError] if the message is malformed
      # @raise [INVALID_TURN] if the message may not follow the current last one
      def push(type:, content:, id: nil, name: nil, upload_ids: nil)
        return if type == :system

        new_message = { type: type, content: content }
        new_message[:name] = name.to_s if name
        new_message[:id] = id.to_s if id
        new_message[:upload_ids] = upload_ids if upload_ids

        validate_message(new_message)
        # messages.last is nil for an empty prompt; validate_turn handles that.
        validate_turn(messages.last, new_message)

        messages << new_message
      end

      # @return [Boolean] whether any tools were supplied
      def has_tools?
        tools.present?
      end

      # Base64-encodes a message's uploads at the dimensions allowed by
      # max_pixels.
      #
      # @param message [Hash] a message hash
      # @return [Array] [] when the message has no upload_ids; otherwise the
      #   result of UploadEncoder.encode
      def encoded_uploads(message)
        return [] if message[:upload_ids].blank?
        UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
      end

      private

      # Checks a single message's type, keys, upload_ids and content.
      #
      # @raise [ArgumentError] describing the first problem found
      def validate_message(message)
        if !MESSAGE_TYPES.include?(message[:type])
          raise ArgumentError, "message type must be one of #{MESSAGE_TYPES}"
        end

        if (invalid_keys = message.keys - MESSAGE_KEYS).any?
          raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
        end

        upload_ids = message[:upload_ids]
        # nil means "no uploads"; any other value must be an Array of ids.
        if !upload_ids.nil? && !upload_ids.is_a?(Array)
          raise ArgumentError, "upload_ids must be an array of ids"
        end

        if upload_ids.present? && message[:type] != :user
          raise ArgumentError, "upload_ids are only supported for users"
        end

        raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
      end

      # Checks that new_turn may directly follow last_turn:
      # - new_turn may not be :system;
      # - :tool, :tool_call and :model may not directly follow :system;
      # - :tool must directly follow :tool_call;
      # - :model may not directly follow :model.
      #
      # last_turn is nil when pushing onto an empty prompt. It then has no
      # type, so only a :tool turn (which needs a preceding :tool_call) fails.
      #
      # @raise [INVALID_TURN] when any rule is violated
      def validate_turn(last_turn, new_turn)
        new_type = new_turn[:type]
        last_type = last_turn&.dig(:type)

        raise INVALID_TURN if !TURN_TYPES.include?(new_type)
        raise INVALID_TURN if last_type == :system && %i[tool tool_call model].include?(new_type)
        raise INVALID_TURN if new_type == :tool && last_type != :tool_call
        raise INVALID_TURN if new_type == :model && last_type == :model
      end
    end
  end
end