2023-11-23 10:58:54 -05:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
# A facade that abstracts multiple LLMs behind a single interface.
|
|
|
|
#
|
|
|
|
# Internally, it consists of the combination of a dialect and an endpoint.
|
|
|
|
# After receiving a prompt using our generic format, it translates it to
|
|
|
|
# the target model and routes the completion request through the correct gateway.
|
|
|
|
#
|
|
|
|
# Use the .proxy method to instantiate an object.
|
|
|
|
# It chooses the best dialect and endpoint for the model you want to interact with.
|
|
|
|
#
|
|
|
|
# Tests of modules that perform LLM calls can use .with_prepared_responses to return canned responses
|
|
|
|
# instead of relying on WebMock stubs like we did in the past.
|
|
|
|
#
|
|
|
|
module DiscourseAi
|
|
|
|
module Completions
|
2023-11-28 23:17:46 -05:00
|
|
|
class Llm
|
2023-11-23 10:58:54 -05:00
|
|
|
# Error class for unrecognized models. NOTE(review): nothing in this class raises it;
# presumably dialect/endpoint lookup code elsewhere does — confirm against callers.
UNKNOWN_MODEL = Class.new(StandardError)
|
|
|
|
|
|
|
|
# Test helper: installs a canned-response endpoint for the duration of the block.
#
# While the block runs, .proxy hands out instances backed by the canned endpoint
# instead of building a real gateway.
#
# @param responses - Passed as-is to DiscourseAi::Completions::Endpoints::CannedResponse.new.
#
# @yield [canned_response] The CannedResponse endpoint that was just installed.
#
# @returns The block's return value.
def self.with_prepared_responses(responses)
  # Remember whatever was installed before us so nested calls can be unwound.
  previous = @canned_response
  @canned_response = DiscourseAi::Completions::Endpoints::CannedResponse.new(responses)

  yield(@canned_response)
ensure
  # Don't leak the prepared response if there's an exception. Restoring the
  # previous value (nil at the outermost level) instead of blindly clearing it
  # keeps an enclosing with_prepared_responses block working after a nested one ends.
  @canned_response = previous
end
|
|
|
|
|
|
|
|
# Builds an Llm for +model_name+ by pairing the dialect class chosen for the model
# with an endpoint.
#
# If a canned response is currently installed (see .with_prepared_responses), it is
# used as the endpoint; otherwise the endpoint class selected for the model is
# instantiated with the model name and the dialect's tokenizer.
#
# @param model_name - Name of the model to interact with.
#
# @returns A new instance wired to the selected dialect class and endpoint.
def self.proxy(model_name)
  dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)

  endpoint =
    @canned_response ||
      begin
        endpoint_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name)
        endpoint_klass.new(model_name, dialect.tokenizer)
      end

  new(dialect, endpoint, model_name)
end
|
|
|
|
|
2023-12-18 16:06:01 -05:00
|
|
|
# Stores the collaborators that #generate combines.
#
# @param dialect_klass - Dialect class; instantiated per call with the prompt and model name.
# @param gateway - Endpoint object that receives perform_completion!.
# @param model_name - Name of the target model.
def initialize(dialect_klass, gateway, model_name)
  @dialect_klass, @gateway, @model_name = dialect_klass, gateway, model_name
end
|
|
|
|
|
2023-12-18 16:06:01 -05:00
|
|
|
# Exposes #tokenizer on instances by forwarding the call to the dialect class.
delegate :tokenizer, to: :dialect_klass
|
2023-11-23 10:58:54 -05:00
|
|
|
|
2024-01-12 12:36:44 -05:00
|
|
|
# Translates the prompt with this instance's dialect and routes the completion
# request through its gateway.
#
# @param prompt { DiscourseAi::Completions::Prompt | String | Array } - Our generic prompt object.
#   A String is wrapped as a single user message under the system message "You are a helpful bot".
#   An Array is used as the messages of a new Prompt.
# @param temperature { Optional } - Model parameter; omitted from the request when nil.
# @param max_tokens { Optional } - Model parameter; omitted from the request when nil.
# @param stop_sequences { Optional } - Model parameter; omitted from the request when nil.
# @param user { User } - User requesting the completion; forwarded to the gateway.
#
# @param &partial_read_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
#
# @returns { String } - Completion result, as returned by the gateway's perform_completion!.
#
# @raise { ArgumentError } - When prompt is neither a String, an Array, nor a Prompt.
#
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool,
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
#
# <function_calls>
#   <invoke>
#     <tool_name>get_weather</tool_name>
#     <tool_id>get_weather</tool_id>
#     <parameters>
#       <location>Sydney</location>
#       <unit>c</unit>
#     </parameters>
#   </invoke>
# </function_calls>
#
def generate(
  prompt,
  temperature: nil,
  max_tokens: nil,
  stop_sequences: nil,
  user:,
  &partial_read_blk
)
  # Only forward the parameters the caller actually set.
  model_params = {
    temperature: temperature,
    max_tokens: max_tokens,
    stop_sequences: stop_sequences,
  }.compact

  # Normalize the convenience forms (String / Array) into a Prompt object.
  prompt =
    case prompt
    when String
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful bot",
        messages: [{ type: :user, content: prompt }],
      )
    when Array
      DiscourseAi::Completions::Prompt.new(messages: prompt)
    else
      prompt
    end

  unless prompt.is_a?(DiscourseAi::Completions::Prompt)
    raise ArgumentError, "Prompt must be either a string, array, or Prompt object"
  end

  dialect = dialect_klass.new(prompt, model_name, opts: model_params)
  gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end
|
|
|
|
|
2024-01-04 08:44:07 -05:00
|
|
|
# Asks the dialect for its prompt token budget for this model.
#
# A throwaway dialect is built around a Prompt created from an empty string,
# because max_prompt_tokens is read from a dialect instance.
#
# @returns Whatever the dialect instance reports from max_prompt_tokens.
def max_prompt_tokens
  blank_prompt = DiscourseAi::Completions::Prompt.new("")
  dialect_klass.new(blank_prompt, model_name).max_prompt_tokens
end
|
|
|
|
|
|
|
|
# Public reader for the model name given to the constructor.
attr_reader :model_name
|
|
|
|
|
2023-11-23 10:58:54 -05:00
|
|
|
# Everything declared below this marker is private to the instance.
private
|
|
|
|
|
2024-01-04 08:44:07 -05:00
|
|
|
# Readers for the dialect class and gateway given to the constructor; used by
# #generate, #max_prompt_tokens and the tokenizer delegation.
attr_reader :dialect_klass, :gateway
|
2023-11-23 10:58:54 -05:00
|
|
|
end
|
|
|
|
end
|
|
|
|
end
|