2024-01-12 12:36:44 -05:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
  module Completions
    # An ordered list of chat messages (plus optional tool definitions) that is
    # validated as it is built.
    #
    # Each message is a Hash with a :type (one of :system, :user, :model,
    # :tool, :tool_call), a String :content and, optionally, :id and :name.
    # Turn ordering is checked: e.g. a :tool message must directly follow a
    # :tool_call, and two :model messages may not be adjacent.
    class Prompt
      # Raised when a message would produce an invalid sequence of turns.
      INVALID_TURN = Class.new(StandardError)

      # Every message type accepted by #validate_message.
      VALID_MESSAGE_TYPES = %i[system user model tool tool_call].freeze
      # Message types allowed as a turn after another message (a :system
      # message can only be the first entry).
      VALID_TURN_TYPES = %i[tool tool_call model user].freeze
      # Turn types that may not come directly after the system message.
      DISALLOWED_AFTER_SYSTEM = %i[tool tool_call model].freeze
      # The only keys a message Hash may contain.
      VALID_MESSAGE_KEYS = %i[type content id name].freeze
      private_constant :VALID_MESSAGE_TYPES,
                       :VALID_TURN_TYPES,
                       :DISALLOWED_AFTER_SYSTEM,
                       :VALID_MESSAGE_KEYS

      attr_reader :messages

      attr_accessor :tools, :topic_id, :post_id

      # @param system_message_text [String, nil] when given, becomes the first
      #   message with type :system
      # @param messages [Array<Hash>] messages appended after the system message
      # @param tools [Array] tool definitions attached to the prompt
      # @param skip_validations [Boolean] when true, message and turn
      #   validation are disabled for this prompt
      # @param topic_id [Object, nil] stored as-is
      # @param post_id [Object, nil] stored as-is
      # @raise [ArgumentError] if messages/tools are not Arrays or a message is malformed
      # @raise [INVALID_TURN] if two consecutive messages form an invalid turn
      def initialize(
        system_message_text = nil,
        messages: [],
        tools: [],
        skip_validations: false,
        topic_id: nil,
        post_id: nil
      )
        raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
        raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)

        @topic_id = topic_id
        @post_id = post_id
        @skip_validations = skip_validations

        @messages = []
        @messages << { type: :system, content: system_message_text } if system_message_text
        @messages.concat(messages)

        @messages.each { |message| validate_message(message) }
        # Only adjacent pairs are checked here; the first message has no predecessor.
        @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }

        @tools = tools
      end

      # Validates a message and appends it to the prompt.
      #
      # System messages can only be supplied through the constructor; pushing
      # one is silently ignored.
      #
      # @param type [Symbol] message type
      # @param content [String] message body
      # @param id [#to_s, nil] stored as a String when given
      # @param name [#to_s, nil] stored as a String when given
      # @return [Array<Hash>, nil] the messages array, or nil when type is :system
      # @raise [ArgumentError] if the message is malformed
      # @raise [INVALID_TURN] if the message cannot follow the current last message
      def push(type:, content:, id: nil, name: nil)
        return if type == :system

        new_message = { type: type, content: content }
        new_message[:name] = name.to_s if name
        new_message[:id] = id.to_s if id

        validate_message(new_message)
        # messages.last is nil when the prompt is still empty; validate_turn handles that.
        validate_turn(messages.last, new_message)

        messages << new_message
      end

      # @return [Boolean] whether any tools are attached (relies on ActiveSupport's #present?)
      def has_tools?
        tools.present?
      end

      private

      # Checks a single message's type, keys and content.
      #
      # @raise [ArgumentError] on an unknown type, unexpected keys, or non-String content
      def validate_message(message)
        return if @skip_validations

        if !VALID_MESSAGE_TYPES.include?(message[:type])
          raise ArgumentError, "message type must be one of #{VALID_MESSAGE_TYPES}"
        end

        if (invalid_keys = message.keys - VALID_MESSAGE_KEYS).any?
          raise ArgumentError, "message contains invalid keys: #{invalid_keys}"
        end

        raise ArgumentError, "message content must be a string" if !message[:content].is_a?(String)
      end

      # Checks that new_turn may directly follow last_turn.
      #
      # @param last_turn [Hash, nil] previous message; nil when the prompt is empty
      # @param new_turn [Hash] message being added
      # @raise [INVALID_TURN] when the ordering is not allowed
      def validate_turn(last_turn, new_turn)
        return if @skip_validations

        new_type = new_turn[:type]
        # With no previous message there is no previous type; this avoids
        # calling [] on nil when pushing onto an empty prompt.
        last_type = last_turn && last_turn[:type]

        if !VALID_TURN_TYPES.include?(new_type)
          raise INVALID_TURN, "invalid turn type: #{new_type.inspect}"
        end

        if last_type == :system && DISALLOWED_AFTER_SYSTEM.include?(new_type)
          raise INVALID_TURN, "#{new_type} message cannot directly follow the system message"
        end

        if new_type == :tool && last_type != :tool_call
          raise INVALID_TURN, "tool message must directly follow a tool_call message"
        end

        if new_type == :model && last_type == :model
          raise INVALID_TURN, "consecutive model messages are not allowed"
        end
      end
    end
  end
end
|