# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
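      # A fake completions endpoint that replays a predefined list of canned
      # responses instead of calling a real LLM provider, typically used in
      # tests where deterministic output is needed.
      #
      # Illustrative usage (a sketch; `dialect` is whatever dialect object the
      # caller already has, and nil stands in for the user here):
      #
      #   endpoint = CannedResponse.new(["first reply", "second reply"])
      #   endpoint.perform_completion!(dialect, nil, {}) { |partial, _cancel| print partial }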
      class CannedResponse
        CANNED_RESPONSE_ERROR = Class.new(StandardError)

        def initialize(responses)
          @responses = responses
          @completions = 0
          @dialect = nil
        end

        def normalize_model_params(model_params)
          # max_tokens, temperature, stop_sequences are already supported
          model_params
        end

        attr_reader :responses, :completions, :dialect, :model_params
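        # Messages from the dialect captured by the most recent
        # perform_completion! call, handy for asserting what would have been
        # sent to a real provider.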
        def prompt_messages
          dialect.prompt.messages
        end
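        # Replays the next canned response. StandardError entries are re-raised
        # as-is, and once the list is exhausted CANNED_RESPONSE_ERROR is raised.
        # When a block is given, String responses are streamed one character at
        # a time while tool calls are yielded whole.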
        def perform_completion!(
          dialect,
          _user,
          model_params,
          feature_name: nil,
          feature_context: nil,
          partial_tool_calls: false
        )
          @dialect = dialect
          @model_params = model_params
          response = responses[completions]
          if response.nil?
            raise CANNED_RESPONSE_ERROR,
                  "The number of completions you requested exceeds the number of canned responses"
          end

          raise response if response.is_a?(StandardError)

          @completions += 1
          if block_given?
            cancelled = false
            cancel_fn = lambda { cancelled = true }

            # We buffer and return tool invocations in one go.
            as_array = response.is_a?(Array) ? response : [response]
            as_array.each do |response|
              if is_tool?(response)
                yield(response, cancel_fn)
              else
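                # Plain strings are streamed character by character to simulate
                # partial output from a streaming provider.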
                response.each_char do |char|
                  break if cancelled
                  yield(char, cancel_fn)
                end
              end
            end
          end

          response = response.first if response.is_a?(Array) && response.length == 1
          response
        end
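        # No real provider is involved, so the OpenAI tokenizer is returned as
        # a stand-in for token counting.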
        def tokenizer
          DiscourseAi::Tokenizer::OpenAiTokenizer
        end

        private

        def is_tool?(response)
          response.is_a?(DiscourseAi::Completions::ToolCall)
        end
      end
    end
  end
end