106 lines
3.2 KiB
Ruby
106 lines
3.2 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
  module Completions
    # An ordered list of conversation messages together with the tool
    # definitions and context attributes that accompany them.
    #
    # Every message is a Hash with symbol keys. +:type+ and +:content+ are
    # required; +:id+, +:name+ and +:upload_ids+ are optional. Messages are
    # validated individually and as a sequence of turns, both when the prompt
    # is built and each time #push appends a message.
    class Prompt
      # Raised when a message may not follow the message before it.
      INVALID_TURN = Class.new(StandardError)

      # Value used for +max_pixels+ when the caller does not supply one.
      DEFAULT_MAX_PIXELS = 1_048_576

      # Every type a message may carry.
      MESSAGE_TYPES = %i[system user model tool tool_call].freeze
      # Every key a message hash may contain.
      MESSAGE_KEYS = %i[type content id name upload_ids].freeze
      # Types allowed for a message that follows another message.
      TURN_TYPES = %i[tool tool_call model user].freeze
      # Types that may not come directly after the system message.
      TYPES_REJECTED_AFTER_SYSTEM = %i[tool tool_call model].freeze
      private_constant :MESSAGE_TYPES, :MESSAGE_KEYS, :TURN_TYPES, :TYPES_REJECTED_AFTER_SYSTEM

      attr_reader :messages
      attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice

      # @param system_message_text [String, nil] when given, becomes the first
      #   message with type +:system+
      # @param messages [Array<Hash>] messages appended after the system message
      # @param tools [Array] tool definitions, stored as given
      # @param topic_id [Object, nil] stored as given
      # @param post_id [Object, nil] stored as given
      # @param max_pixels [Integer, nil] forwarded to UploadEncoder by
      #   #encoded_uploads; falls back to DEFAULT_MAX_PIXELS
      # @param tool_choice [Object, nil] stored as given
      # @raise [ArgumentError] if +messages+ or +tools+ is not an Array, or a
      #   message is malformed
      # @raise [INVALID_TURN] if two consecutive messages form an invalid turn
      def initialize(
        system_message_text = nil,
        messages: [],
        tools: [],
        topic_id: nil,
        post_id: nil,
        max_pixels: nil,
        tool_choice: nil
      )
        raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
        raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)

        @max_pixels = max_pixels || DEFAULT_MAX_PIXELS
        @topic_id = topic_id
        @post_id = post_id
        @tools = tools
        @tool_choice = tool_choice

        # Build a fresh array so the caller's +messages+ is never mutated.
        @messages = []
        @messages << { type: :system, content: system_message_text } if system_message_text
        @messages.concat(messages)

        @messages.each { |message| validate_message(message) }
        @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
      end

      # Validates a new message and appends it to the conversation.
      #
      # System messages are ignored here: the call returns nil without
      # appending anything.
      #
      # @param type [Symbol] message type
      # @param content [String] message content
      # @param id [#to_s, nil] stored as a String when given
      # @param name [#to_s, nil] stored as a String when given
      # @param upload_ids [Array, nil] only accepted on +:user+ messages
      # @return [Array<Hash>, nil] the messages array, or nil for +:system+
      # @raise [ArgumentError] if the message is malformed
      # @raise [INVALID_TURN] if the message may not follow the current last one
      def push(type:, content:, id: nil, name: nil, upload_ids: nil)
        return if type == :system

        new_message = { type: type, content: content }
        new_message[:name] = name.to_s if name
        new_message[:id] = id.to_s if id
        new_message[:upload_ids] = upload_ids if upload_ids

        validate_message(new_message)
        # messages.last is nil for an empty conversation; validate_turn copes.
        validate_turn(messages.last, new_message)

        messages << new_message
      end

      # @return [Boolean] whether any tools are attached to this prompt
      def has_tools?
        tools.present?
      end

      # Helper method to get base64 encoded uploads at the correct dimensions.
      #
      # @param message [Hash] a message that may carry +:upload_ids+
      # @return [Array] an empty array when the message has no upload ids,
      #   otherwise whatever UploadEncoder.encode returns
      def encoded_uploads(message)
        return [] if message[:upload_ids].blank?
        UploadEncoder.encode(upload_ids: message[:upload_ids], max_pixels: max_pixels)
      end

      private

      # Checks a single message hash in isolation.
      #
      # @raise [ArgumentError] on an unknown type, unknown keys, non-Array
      #   upload_ids, upload_ids on a non-user message, or non-String content
      def validate_message(message)
        if !MESSAGE_TYPES.include?(message[:type])
          raise ArgumentError, "message type must be one of #{MESSAGE_TYPES}"
        end

        if (invalid_keys = message.keys - MESSAGE_KEYS).any?
          raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
        end

        upload_ids = message[:upload_ids]

        # Inspect the upload_ids value itself; comparing the message type to
        # :upload_ids could never match, because the type was validated above.
        if !upload_ids.nil? && !upload_ids.is_a?(Array)
          raise ArgumentError, "upload_ids must be an array of ids"
        end

        # upload_ids is nil or an Array here, so a non-empty test is enough.
        if upload_ids && !upload_ids.empty? && message[:type] != :user
          raise ArgumentError, "upload_ids are only supported for users"
        end

        raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
      end

      # Checks that +new_turn+ may follow +last_turn+.
      #
      # @param last_turn [Hash, nil] the preceding message; nil when the
      #   conversation is still empty
      # @param new_turn [Hash] the message being added
      # @raise [INVALID_TURN] if the new type is not a turn type, if a tool,
      #   tool_call or model message directly follows the system message, if a
      #   tool message does not directly follow a tool_call, or if two model
      #   messages are adjacent
      def validate_turn(last_turn, new_turn)
        new_type = new_turn[:type]
        # No previous message means no previous type, not a crash on nil.
        last_type = last_turn && last_turn[:type]

        raise INVALID_TURN if !TURN_TYPES.include?(new_type)

        if last_type == :system && TYPES_REJECTED_AFTER_SYSTEM.include?(new_type)
          raise INVALID_TURN
        end

        raise INVALID_TURN if new_type == :tool && last_type != :tool_call
        raise INVALID_TURN if new_type == :model && last_type == :model
      end
    end
  end
end
|