2023-12-18 16:06:01 -05:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
  module Completions
    module Dialects
      # Base class for LLM prompt dialects.
      #
      # A dialect turns a generic prompt into the text/structure a specific model
      # family expects. Subclasses implement `.can_translate?`, `.tokenizer`,
      # `#translate`, `#conversation_context` and `#max_prompt_tokens`; this class
      # provides dialect lookup, the shared XML tool-calling helpers and the
      # token-budget trimming used when building prompts.
      class Dialect
        class << self
          # Whether this dialect handles the given model.
          #
          # @param _model_name [String]
          # @return [Boolean] truthy when this dialect can translate prompts for the model.
          # @raise [NotImplementedError] unless overridden by a subclass.
          def can_translate?(_model_name)
            raise NotImplementedError, "#{name} must implement .can_translate?"
          end

          # Picks the dialect class for a model.
          #
          # Candidates are tried in order and the first whose `.can_translate?`
          # returns truthy wins. The Fake dialect is only a candidate in the test
          # and development Rails environments.
          #
          # @param model_name [String]
          # @return [Class] the matching Dialect subclass.
          # @raise [DiscourseAi::Completions::Llm::UNKNOWN_MODEL] when no dialect matches.
          def dialect_for(model_name)
            dialects = [
              DiscourseAi::Completions::Dialects::Claude,
              DiscourseAi::Completions::Dialects::Llama2Classic,
              DiscourseAi::Completions::Dialects::ChatGpt,
              DiscourseAi::Completions::Dialects::OrcaStyle,
              DiscourseAi::Completions::Dialects::Gemini,
              DiscourseAi::Completions::Dialects::Mixtral,
              DiscourseAi::Completions::Dialects::ClaudeMessages,
            ]

            if Rails.env.test? || Rails.env.development?
              dialects << DiscourseAi::Completions::Dialects::Fake
            end

            dialect = dialects.find { |d| d.can_translate?(model_name) }
            raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL unless dialect

            dialect
          end

          # Tokenizer used to measure prompt size. It must respond to `size(text)`,
          # which #calculate_message_token calls.
          #
          # @raise [NotImplementedError] unless overridden by a subclass.
          def tokenizer
            raise NotImplementedError, "#{name} must implement .tokenizer"
          end

          # Fixed instructions describing the XML <function_calls>/<invoke> calling
          # convention to the model. #build_tools_prompt appends the <tools> list
          # right after the trailing "Here are the tools available:" line.
          #
          # @return [String]
          def tool_preamble
            <<~TEXT
              In this environment you have access to a set of tools you can use to answer the user's question.
              You may call them like this.

              <function_calls>
              <invoke>
              <tool_name>$TOOL_NAME</tool_name>
              <parameters>
              <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
              ...
              </parameters>
              </invoke>
              </function_calls>

              If a parameter type is an array, return a JSON array of values. For example:
              [1,"two",3.0]

              Always wrap <invoke> calls in <function_calls> tags.
              You may call multiple function via <invoke> in a single <function_calls> block.

              Here are the tools available:
            TEXT
          end
        end

        # @param generic_prompt [Object] the prompt to translate; must respond to `#tools`.
        # @param model_name [String]
        # @param opts [Hash] dialect-specific options (stored, not interpreted here).
        def initialize(generic_prompt, model_name, opts: {})
          @prompt = generic_prompt
          @model_name = model_name
          @opts = opts
        end

        # Converts the generic prompt into this dialect's format.
        #
        # @raise [NotImplementedError] unless overridden by a subclass.
        def translate
          raise NotImplementedError, "#{self.class.name} must implement #translate"
        end

        # Renders a tool result message as a <function_results> XML snippet.
        #
        # message[:id] becomes the <tool_name> and message[:content] is embedded
        # verbatim (no escaping) inside <json>.
        #
        # @param message [Hash] with :id and :content keys.
        # @return [String] the XML, without surrounding whitespace.
        def tool_result_to_xml(message)
          (<<~TEXT).strip
            <function_results>
            <result>
            <tool_name>#{message[:id]}</tool_name>
            <json>
            #{message[:content]}
            </json>
            </result>
            </function_results>
          TEXT
        end

        # Renders a tool call message as a <function_calls> XML snippet.
        #
        # message[:content] is parsed as JSON with a "name" key and an optional
        # "arguments" object. Each argument becomes a <key>value</key> element
        # (values are interpolated unescaped).
        #
        # @param message [Hash] with a JSON string under :content.
        # @return [String] the XML, without surrounding whitespace.
        # @raise [JSON::ParserError] when :content is not valid JSON.
        def tool_call_to_xml(message)
          parsed = JSON.parse(message[:content], symbolize_names: true)
          parameters = +""

          if parsed[:arguments]
            parameters << "<parameters>\n"
            parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
            parameters << "</parameters>\n"
          end

          (<<~TEXT).strip
            <function_calls>
            <invoke>
            <tool_name>#{parsed[:name]}</tool_name>
            #{parameters}</invoke>
            </function_calls>
          TEXT
        end

        # Renders every tool in `prompt.tools` as a <tool_description> block.
        #
        # Each tool hash supplies :name, :description and optional :parameters. Each
        # parameter supplies :name, :type, :description, :required and an optional
        # :enum, which is rendered as comma-joined <options>.
        #
        # @return [String] concatenated <tool_description> blocks ("" when there are no tools).
        def tools
          tools = +""

          prompt.tools.each do |function|
            parameters = +""

            if function[:parameters].present?
              function[:parameters].each do |parameter|
                parameters << <<~PARAMETER
                  <parameter>
                  <name>#{parameter[:name]}</name>
                  <type>#{parameter[:type]}</type>
                  <description>#{parameter[:description]}</description>
                  <required>#{parameter[:required]}</required>
                PARAMETER
                if parameter[:enum]
                  parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
                end
                parameters << "</parameter>\n"
              end
            end

            tools << <<~TOOLS
              <tool_description>
              <tool_name>#{function[:name]}</tool_name>
              <description>#{function[:description]}</description>
              <parameters>
              #{parameters}</parameters>
              </tool_description>
            TOOLS
          end

          tools
        end

        # @raise [NotImplementedError] unless overridden by a subclass.
        def conversation_context
          raise NotImplementedError, "#{self.class.name} must implement #conversation_context"
        end

        # Token budget that #trim_messages fits the conversation into.
        #
        # @raise [NotImplementedError] unless overridden by a subclass.
        def max_prompt_tokens
          raise NotImplementedError, "#{self.class.name} must implement #max_prompt_tokens"
        end

        attr_reader :prompt

        private

        attr_reader :model_name, :opts

        # Fits `messages` into #max_prompt_tokens, keeping the most recent ones.
        #
        # - A leading :system message is always kept, untrimmed, and counted first.
        # - The remaining messages are walked newest to oldest until the budget is used up.
        # - A :tool_call message is kept whole or not at all. One that does not fit stops
        #   the walk, so everything older is discarded.
        # - Any other message has its :content shortened from the end until it fits.
        #   Messages whose content ends up blank are skipped.
        # - If the oldest kept message is a :tool result (its :tool_call was dropped),
        #   it is removed as well.
        #
        # @param messages [Array<Hash>] messages with :type and :content keys.
        # @return [Array<Hash>] the kept messages in their original order. Non-system
        #   entries are shallow copies; the input hashes are never modified.
        def trim_messages(messages)
          prompt_limit = max_prompt_tokens
          current_token_count = 0

          # Each trimming pass cuts roughly prompt_limit / 25 characters. Note that
          # `content[0..-n]` drops only n - 1 characters, so n must be at least 2:
          # with n == 1 (or 0) the content never shrinks and the loop below would
          # spin forever for budgets under 50 tokens.
          chars_step = [(prompt_limit / 25).to_i, 2].max

          trimmed_messages = []
          range = (0..-1)

          if messages.dig(0, :type) == :system
            system_message = messages[0]
            trimmed_messages << system_message
            current_token_count += calculate_message_token(system_message)
            range = (1..-1)
          end

          reversed_trimmed_msgs = []

          messages[range].reverse_each do |msg|
            break if current_token_count >= prompt_limit

            message_tokens = calculate_message_token(msg)
            dupped_msg = msg.dup

            # Don't trim tool call metadata.
            if msg[:type] == :tool_call
              break if current_token_count + message_tokens + per_message_overhead > prompt_limit

              current_token_count += message_tokens + per_message_overhead
              reversed_trimmed_msgs << dupped_msg
              next
            end

            # Trimming content to make sure we respect token limit.
            while dupped_msg[:content].present? &&
                  message_tokens + current_token_count + per_message_overhead > prompt_limit
              dupped_msg[:content] = dupped_msg[:content][0..-chars_step] || ""
              message_tokens = calculate_message_token(dupped_msg)
            end

            next if dupped_msg[:content].blank?

            current_token_count += message_tokens + per_message_overhead
            reversed_trimmed_msgs << dupped_msg
          end

          # A :tool result whose :tool_call was trimmed away would be orphaned.
          reversed_trimmed_msgs.pop if reversed_trimmed_msgs.last&.dig(:type) == :tool

          trimmed_messages.concat(reversed_trimmed_msgs.reverse)
        end

        # Extra tokens charged for every non-system message kept by #trim_messages.
        # It is 0 here; subclasses presumably override it when their format adds
        # per-message framing.
        def per_message_overhead
          0
        end

        # Token count of msg[:content] (nil counts as "") according to the
        # dialect class's tokenizer.
        def calculate_message_token(msg)
          self.class.tokenizer.size(msg[:content].to_s)
        end

        # Tool instructions plus the <tools> list, or "" when the prompt has no tools.
        def build_tools_prompt
          return "" if prompt.tools.blank?

          (<<~TEXT).strip
            #{self.class.tool_preamble}
            <tools>
            #{tools}</tools>
          TEXT
        end
      end
    end
  end
end
|