From e0bf6adb5b7370a854568ba2e0cfa1908af85d04 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Mon, 18 Dec 2023 18:06:01 -0300 Subject: [PATCH] DEV: Tool support for the LLM service. (#366) This PR adds tool support to available LLMs. We'll buffer tool invocations and return them instead of making users of this service parse the response. It also adds support for conversation context in the generic prompt. It includes bot messages, user messages, and tool invocations, which we'll trim to make sure it doesn't exceed the prompt limit, then translate them to the correct dialect. Finally, It adds some buffering when reading chunks to handle cases when streaming is extremely slow.:M --- lib/completions/dialects/chat_gpt.rb | 94 +++++++-- lib/completions/dialects/claude.rb | 60 +++++- lib/completions/dialects/dialect.rb | 160 +++++++++++++++ lib/completions/dialects/gemini.rb | 77 +++++++- lib/completions/dialects/llama2_classic.rb | 67 +++++-- lib/completions/dialects/orca_style.rb | 61 ++++-- lib/completions/endpoints/anthropic.rb | 2 +- lib/completions/endpoints/aws_bedrock.rb | 2 +- lib/completions/endpoints/base.rb | 83 +++++++- lib/completions/endpoints/canned_response.rb | 17 +- lib/completions/endpoints/gemini.rb | 72 +++++-- lib/completions/endpoints/hugging_face.rb | 5 +- lib/completions/endpoints/open_ai.rb | 56 +++++- lib/completions/llm.rb | 72 +++++-- .../lib/completions/dialects/chat_gpt_spec.rb | 66 ++++++- spec/lib/completions/dialects/claude_spec.rb | 125 +++++++++++- spec/lib/completions/dialects/gemini_spec.rb | 94 ++++++++- .../dialects/llama2_classic_spec.rb | 143 +++++++++++++- .../completions/dialects/orca_style_spec.rb | 183 +++++++++++++++--- .../completions/endpoints/anthropic_spec.rb | 28 ++- .../completions/endpoints/aws_bedrock_spec.rb | 57 ++---- .../endpoints/endpoint_examples.rb | 183 ++++++++++++++---- spec/lib/completions/endpoints/gemini_spec.rb | 72 +++++-- .../endpoints/hugging_face_spec.rb | 49 +++-- .../lib/completions/endpoints/open_ai_spec.rb | 77 ++++++-- spec/lib/completions/llm_spec.rb | 2 +- .../ai_helper/assistant_controller_spec.rb | 2 +- 27 files changed, 1625 insertions(+), 284 deletions(-) create mode 100644 lib/completions/dialects/dialect.rb diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 82595225..dda3e8a2 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -3,31 +3,99 @@ module DiscourseAi module Completions module Dialects - class ChatGpt - def self.can_translate?(model_name) - %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name) + class ChatGpt < Dialect + class << self + def can_translate?(model_name) + %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::OpenAiTokenizer + end end - def translate(generic_prompt) + def translate open_ai_prompt = [ - { - role: "system", - content: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"), - }, + { role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") }, ] - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| open_ai_prompt << { role: "user", content: example_pair.first } open_ai_prompt << { role: "assistant", content: example_pair.second } end end - open_ai_prompt << { role: "user", content: generic_prompt[:input] } + open_ai_prompt.concat!(conversation_context) if 
prompt[:conversation_context] + + open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input] + + open_ai_prompt end - def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer + def tools + return if prompt[:tools].blank? + + prompt[:tools].map { |t| { type: "function", tool: t } } + end + + def conversation_context + return [] if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context.reverse.map do |context| + translated = context.slice(:content) + translated[:role] = context[:type] + + if context[:name] + if translated[:role] == "tool" + translated[:tool_call_id] = context[:name] + else + translated[:name] = context[:name] + end + end + + translated + end + end + + def max_prompt_tokens + # provide a buffer of 120 tokens - our function counting is not + # 100% accurate and getting numbers to align exactly is very hard + buffer = (opts[:max_tokens_to_sample] || 2500) + 50 + + if tools.present? + # note this is about 100 tokens over, OpenAI have a more optimal representation + @function_size ||= self.class.tokenizer.size(tools.to_json.to_s) + buffer += @function_size + end + + model_max_tokens - buffer + end + + private + + def per_message_overhead + # open ai defines about 4 tokens per message of overhead + 4 + end + + def calculate_message_token(context) + self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) + end + + def model_max_tokens + case model_name + when "gpt-3.5-turbo", "gpt-3.5-turbo-16k" + 16_384 + when "gpt-4" + 8192 + when "gpt-4-32k" + 32_768 + else + 8192 + end end end end diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 07438985..f10c4c49 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -3,25 +3,65 @@ module DiscourseAi module Completions module Dialects - class Claude - def self.can_translate?(model_name) - %w[claude-instant-1 claude-2].include?(model_name) + class Claude < Dialect + class << self + def can_translate?(model_name) + %w[claude-instant-1 claude-2].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::AnthropicTokenizer + end end - def translate(generic_prompt) - claude_prompt = +"Human: #{generic_prompt[:insts]}\n" + def translate + claude_prompt = +"Human: #{prompt[:insts]}\n" - claude_prompt << build_examples(generic_prompt[:examples]) if generic_prompt[:examples] + claude_prompt << build_tools_prompt if prompt[:tools] - claude_prompt << "#{generic_prompt[:input]}\n" + claude_prompt << build_examples(prompt[:examples]) if prompt[:examples] - claude_prompt << "#{generic_prompt[:post_insts]}\n" if generic_prompt[:post_insts] + claude_prompt << conversation_context if prompt[:conversation_context] + + claude_prompt << "#{prompt[:input]}\n" + + claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts] claude_prompt << "Assistant:\n" end - def tokenizer - DiscourseAi::Tokenizer::AnthropicTokenizer + def max_prompt_tokens + 50_000 + end + + def conversation_context + return "" if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context + .reverse + .reduce(+"") do |memo, context| + memo << (context[:type] == "user" ? 
"Human:" : "Assistant:") + + if context[:type] == "tool" + memo << <<~TEXT + + + + #{context[:name]} + + #{context[:content]} + + + + TEXT + else + memo << " " << context[:content] << "\n" + end + + memo + end end private diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb new file mode 100644 index 00000000..5e6d6b97 --- /dev/null +++ b/lib/completions/dialects/dialect.rb @@ -0,0 +1,160 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class Dialect + class << self + def can_translate?(_model_name) + raise NotImplemented + end + + def dialect_for(model_name) + dialects = [ + DiscourseAi::Completions::Dialects::Claude, + DiscourseAi::Completions::Dialects::Llama2Classic, + DiscourseAi::Completions::Dialects::ChatGpt, + DiscourseAi::Completions::Dialects::OrcaStyle, + DiscourseAi::Completions::Dialects::Gemini, + ] + dialects.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |d| + d.can_translate?(model_name) + end + end + + def tokenizer + raise NotImplemented + end + end + + def initialize(generic_prompt, model_name, opts: {}) + @prompt = generic_prompt + @model_name = model_name + @opts = opts + end + + def translate + raise NotImplemented + end + + def tools + tools = +"" + + prompt[:tools].each do |function| + parameters = +"" + if function[:parameters].present? + function[:parameters].each do |parameter| + parameters << <<~PARAMETER + + #{parameter[:name]} + #{parameter[:type]} + #{parameter[:description]} + #{parameter[:required]} + PARAMETER + if parameter[:enum] + parameters << "#{parameter[:enum].join(",")}\n" + end + parameters << "\n" + end + end + + tools << <<~TOOLS + + #{function[:name]} + #{function[:description]} + + #{parameters} + + TOOLS + end + + tools + end + + def conversation_context + raise NotImplemented + end + + def max_prompt_tokens + raise NotImplemented + end + + private + + attr_reader :prompt, :model_name, :opts + + def trim_context(conversation_context) + prompt_limit = max_prompt_tokens + current_token_count = calculate_token_count_without_context + + conversation_context.reduce([]) do |memo, context| + break(memo) if current_token_count >= prompt_limit + + dupped_context = context.dup + + message_tokens = calculate_message_token(dupped_context) + + # Trimming content to make sure we respect token limit. + while dupped_context[:content].present? && + message_tokens + current_token_count + per_message_overhead > prompt_limit + dupped_context[:content] = dupped_context[:content][0..-100] || "" + message_tokens = calculate_message_token(dupped_context) + end + + next(memo) if dupped_context[:content].blank? + + current_token_count += message_tokens + per_message_overhead + + memo << dupped_context + end + end + + def calculate_token_count_without_context + tokenizer = self.class.tokenizer + + examples_count = + prompt[:examples].to_a.sum do |pair| + tokenizer.size(pair.join) + (per_message_overhead * 2) + end + input_count = tokenizer.size(prompt[:input].to_s) + per_message_overhead + + examples_count + input_count + + prompt + .except(:conversation_context, :tools, :examples, :input) + .sum { |_, v| tokenizer.size(v) + per_message_overhead } + end + + def per_message_overhead + 0 + end + + def calculate_message_token(context) + self.class.tokenizer.size(context[:content].to_s) + end + + def build_tools_prompt + return "" if prompt[:tools].blank? + + <<~TEXT + In this environment you have access to a set of tools you can use to answer the user's question. 
+ You may call them like this. Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... + + + + + Here are the tools available: + + + #{tools} + TEXT + end + end + end + end +end diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 49fa716e..b6b938bf 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -3,34 +3,91 @@ module DiscourseAi module Completions module Dialects - class Gemini - def self.can_translate?(model_name) - %w[gemini-pro].include?(model_name) + class Gemini < Dialect + class << self + def can_translate?(model_name) + %w[gemini-pro].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer + end end - def translate(generic_prompt) + def translate gemini_prompt = [ { role: "user", parts: { - text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"), + text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"), }, }, { role: "model", parts: { text: "Ok." } }, ] - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| gemini_prompt << { role: "user", parts: { text: example_pair.first } } gemini_prompt << { role: "model", parts: { text: example_pair.second } } end end - gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } } + gemini_prompt.concat!(conversation_context) if prompt[:conversation_context] + + gemini_prompt << { role: "user", parts: { text: prompt[:input] } } end - def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer + def tools + return if prompt[:tools].blank? + + translated_tools = + prompt[:tools].map do |t| + required_fields = [] + tool = t.dup + + tool[:parameters] = t[:parameters].map do |p| + required_fields << p[:name] if p[:required] + + p.except(:required) + end + + tool.merge(required: required_fields) + end + + [{ function_declarations: translated_tools }] + end + + def conversation_context + return [] if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context.reverse.map do |context| + translated = {} + translated[:role] = (context[:type] == "user" ? 
"user" : "model") + + part = {} + + if context[:type] == "tool" + part["functionResponse"] = { name: context[:name], content: context[:content] } + else + part[:text] = context[:content] + end + + translated[:parts] = [part] + + translated + end + end + + def max_prompt_tokens + 16_384 # 50% of model tokens + end + + protected + + def calculate_message_token(context) + self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) end end end diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb index 542e5f57..0470a61c 100644 --- a/lib/completions/dialects/llama2_classic.rb +++ b/lib/completions/dialects/llama2_classic.rb @@ -3,27 +3,72 @@ module DiscourseAi module Completions module Dialects - class Llama2Classic - def self.can_translate?(model_name) - %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name) + class Llama2Classic < Dialect + class << self + def can_translate?(model_name) + %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::Llama2Tokenizer + end end - def translate(generic_prompt) - llama2_prompt = - +"[INST]<>#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}<>[/INST]\n" + def translate + llama2_prompt = +<<~TEXT + [INST] + <> + #{prompt[:insts]} + #{build_tools_prompt}#{prompt[:post_insts]} + <> + [/INST] + TEXT - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| llama2_prompt << "[INST]#{example_pair.first}[/INST]\n" llama2_prompt << "#{example_pair.second}\n" end end - llama2_prompt << "[INST]#{generic_prompt[:input]}[/INST]\n" + llama2_prompt << conversation_context if prompt[:conversation_context].present? + + llama2_prompt << "[INST]#{prompt[:input]}[/INST]\n" end - def tokenizer - DiscourseAi::Tokenizer::Llama2Tokenizer + def conversation_context + return "" if prompt[:conversation_context].blank? 
+ + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context + .reverse + .reduce(+"") do |memo, context| + if context[:type] == "tool" + memo << <<~TEXT + [INST] + + + #{context[:name]} + + #{context[:content]} + + + + [/INST] + TEXT + elsif context[:type] == "assistant" + memo << "[INST]" << context[:content] << "[/INST]\n" + else + memo << context[:content] << "\n" + end + + memo + end + end + + def max_prompt_tokens + SiteSetting.ai_hugging_face_token_limit end end end diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb index 3aa11609..fd76f3b5 100644 --- a/lib/completions/dialects/orca_style.rb +++ b/lib/completions/dialects/orca_style.rb @@ -3,29 +3,68 @@ module DiscourseAi module Completions module Dialects - class OrcaStyle - def self.can_translate?(model_name) - %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name) + class OrcaStyle < Dialect + class << self + def can_translate?(model_name) + %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::Llama2Tokenizer + end end - def translate(generic_prompt) - orca_style_prompt = - +"### System:\n#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}\n" + def translate + orca_style_prompt = +<<~TEXT + ### System: + #{prompt[:insts]} + #{build_tools_prompt}#{prompt[:post_insts]} + TEXT - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| orca_style_prompt << "### User:\n#{example_pair.first}\n" orca_style_prompt << "### Assistant:\n#{example_pair.second}\n" end end - orca_style_prompt << "### User:\n#{generic_prompt[:input]}\n" + orca_style_prompt << "### User:\n#{prompt[:input]}\n" orca_style_prompt << "### Assistant:\n" end - def tokenizer - DiscourseAi::Tokenizer::Llama2Tokenizer + def conversation_context + return "" if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context + .reverse + .reduce(+"") do |memo, context| + memo << (context[:type] == "user" ? 
"### User:" : "### Assistant:") + + if context[:type] == "tool" + memo << <<~TEXT + + + + #{context[:name]} + + #{context[:content]} + + + + TEXT + else + memo << " " << context[:content] << "\n" + end + + memo + end + end + + def max_prompt_tokens + SiteSetting.ai_hugging_face_token_limit end end end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index 5216d4e7..3846d7e4 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -22,7 +22,7 @@ module DiscourseAi @uri ||= URI("https://api.anthropic.com/v1/complete") end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, _dialect) default_options .merge(model_params) .merge(prompt: prompt) diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 0e8cf68f..902a375c 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -37,7 +37,7 @@ module DiscourseAi URI(api_url) end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, _dialect) default_options.merge(prompt: prompt).merge(model_params) end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 61846e23..ef92a162 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -30,9 +30,11 @@ module DiscourseAi @tokenizer = tokenizer end - def perform_completion!(prompt, user, model_params = {}) + def perform_completion!(dialect, user, model_params = {}) @streaming_mode = block_given? + prompt = dialect.translate + Net::HTTP.start( model_uri.host, model_uri.port, @@ -43,7 +45,10 @@ module DiscourseAi ) do |http| response_data = +"" response_raw = +"" - request_body = prepare_payload(prompt, model_params).to_json + + # Needed to response token calculations. Cannot rely on response_data due to function buffering. + partials_raw = +"" + request_body = prepare_payload(prompt, model_params, dialect).to_json request = prepare_request(request_body) @@ -66,6 +71,15 @@ module DiscourseAi if !@streaming_mode response_raw = response.read_body response_data = extract_completion_from(response_raw) + partials_raw = response_data.to_s + + if has_tool?("", response_data) + function_buffer = build_buffer # Nokogiri document + function_buffer = add_to_buffer(function_buffer, "", response_data) + + response_data = +function_buffer.at("function_calls").to_s + response_data << "\n" + end return response_data end @@ -75,6 +89,7 @@ module DiscourseAi cancel = lambda { cancelled = true } leftover = "" + function_buffer = build_buffer # Nokogiri document response.read_body do |chunk| if cancelled @@ -85,6 +100,12 @@ module DiscourseAi decoded_chunk = decode(chunk) response_raw << decoded_chunk + # Buffering for extremely slow streaming. + if (leftover + decoded_chunk).length < "data: [DONE]".length + leftover += decoded_chunk + next + end + partials_from(leftover + decoded_chunk).each do |raw_partial| next if cancelled next if raw_partial.blank? @@ -93,11 +114,27 @@ module DiscourseAi partial = extract_completion_from(raw_partial) next if partial.nil? 
leftover = "" - response_data << partial - yield partial, cancel if partial + if has_tool?(response_data, partial) + function_buffer = add_to_buffer(function_buffer, response_data, partial) + + if buffering_finished?(dialect.tools, function_buffer) + invocation = +function_buffer.at("function_calls").to_s + invocation << "\n" + + partials_raw << partial.to_s + response_data << invocation + + yield invocation, cancel + end + else + partials_raw << partial + response_data << partial + + yield partial, cancel if partial + end rescue JSON::ParserError - leftover = raw_partial + leftover += decoded_chunk end end end @@ -109,7 +146,7 @@ module DiscourseAi ensure if log log.raw_response_payload = response_raw - log.response_tokens = tokenizer.size(response_data) + log.response_tokens = tokenizer.size(partials_raw) log.save! if Rails.env.development? @@ -165,6 +202,40 @@ module DiscourseAi def extract_prompt_for_tokenizer(prompt) prompt end + + def build_buffer + Nokogiri::HTML5.fragment(<<~TEXT) + + + + + + + + TEXT + end + + def has_tool?(response, partial) + (response + partial).include?("") + end + + def add_to_buffer(function_buffer, response_data, partial) + new_buffer = Nokogiri::HTML5.fragment(response_data + partial) + if tool_name = new_buffer.at("tool_name").text + if new_buffer.at("tool_id").nil? + tool_id_node = + Nokogiri::HTML5::DocumentFragment.parse("\n#{tool_name}") + + new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node) + end + end + + new_buffer + end + + def buffering_finished?(_available_functions, buffer) + buffer.to_s.include?("") + end end end end diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index 83300fcc..56d2b913 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -31,9 +31,14 @@ module DiscourseAi cancelled = false cancel_fn = lambda { cancelled = true } - response.each_char do |char| - break if cancelled - yield(char, cancel_fn) + # We buffer and return tool invocations in one go. + if is_tool?(response) + yield(response, cancel_fn) + else + response.each_char do |char| + break if cancelled + yield(char, cancel_fn) + end end else response @@ -43,6 +48,12 @@ module DiscourseAi def tokenizer DiscourseAi::Tokenizer::OpenAiTokenizer end + + private + + def is_tool?(response) + Nokogiri::HTML5.fragment(response).at("function_calls").present? + end end end end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 7738085c..8fa8a2d3 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -25,8 +25,11 @@ module DiscourseAi URI(url) end - def prepare_payload(prompt, model_params) - default_options.merge(model_params).merge(contents: prompt) + def prepare_payload(prompt, model_params, dialect) + default_options + .merge(model_params) + .merge(contents: prompt) + .tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? } end def prepare_request(payload) @@ -36,25 +39,72 @@ module DiscourseAi end def extract_completion_from(response_raw) - if @streaming_mode - parsed = response_raw - else - parsed = JSON.parse(response_raw, symbolize_names: true) - end + parsed = JSON.parse(response_raw, symbolize_names: true) - completion = dig_text(parsed).to_s + response_h = parsed.dig(:candidates, 0, :content, :parts, 0) + + has_function_call = response_h.dig(:functionCall).present? + has_function_call ? 
response_h[:functionCall] : response_h.dig(:text) end def partials_from(decoded_chunk) - JSON.parse(decoded_chunk, symbolize_names: true) + decoded_chunk + .split("\n") + .map do |line| + if line == "," + nil + elsif line.starts_with?("[") + line[1..-1] + elsif line.ends_with?("]") + line[0..-1] + else + line + end + end + .compact_blank end def extract_prompt_for_tokenizer(prompt) prompt.to_s end - def dig_text(response) - response.dig(:candidates, 0, :content, :parts, 0, :text) + def has_tool?(_response_data, partial) + partial.is_a?(Hash) && partial.has_key?(:name) # Has function name + end + + def add_to_buffer(function_buffer, _response_data, partial) + if partial[:name].present? + function_buffer.at("tool_name").content = partial[:name] + function_buffer.at("tool_id").content = partial[:name] + end + + if partial[:args] + argument_fragments = + partial[:args].reduce(+"") do |memo, (arg_name, value)| + memo << "\n<#{arg_name}>#{value}" + end + argument_fragments << "\n" + + function_buffer.at("parameters").children = + Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) + end + + function_buffer + end + + def buffering_finished?(available_functions, buffer) + tool_name = buffer.at("tool_name")&.text + return false if tool_name.blank? + + signature = + available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name } + + signature[:parameters].reduce(true) do |memo, param| + param_present = buffer.at(param[:name]).present? + next(memo) if param_present || !signature[:required].include?(param[:name]) + + memo && param_present + end end end end diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb index d5d10cea..819cec47 100644 --- a/lib/completions/endpoints/hugging_face.rb +++ b/lib/completions/endpoints/hugging_face.rb @@ -11,7 +11,7 @@ module DiscourseAi end def default_options - { parameters: { repetition_penalty: 1.1, temperature: 0.7 } } + { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } } end def provider_id @@ -24,7 +24,7 @@ module DiscourseAi URI(SiteSetting.ai_hugging_face_api_url) end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, _dialect) default_options .merge(inputs: prompt) .tap do |payload| @@ -33,7 +33,6 @@ module DiscourseAi token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt) - payload[:parameters][:return_full_text] = false payload[:stream] = true if @streaming_mode end diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 65b01314..b0664cd0 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -37,11 +37,14 @@ module DiscourseAi URI(url) end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, dialect) default_options .merge(model_params) .merge(messages: prompt) - .tap { |payload| payload[:stream] = true if @streaming_mode } + .tap do |payload| + payload[:stream] = true if @streaming_mode + payload[:tools] = dialect.tools if dialect.tools.present? 
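+ # dialect.tools is nil when the prompt defines no functions, so plain completions keep the original payload.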
+ end end def prepare_request(payload) @@ -62,15 +65,12 @@ module DiscourseAi end def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - ( - if @streaming_mode - parsed.dig(:choices, 0, :delta, :content) - else - parsed.dig(:choices, 0, :message, :content) - end - ).to_s + response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + + has_function_call = response_h.dig(:tool_calls).present? + has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content) end def partials_from(decoded_chunk) @@ -86,6 +86,42 @@ module DiscourseAi def extract_prompt_for_tokenizer(prompt) prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") end + + def has_tool?(_response_data, partial) + partial.is_a?(Hash) && partial.has_key?(:name) # Has function name + end + + def add_to_buffer(function_buffer, _response_data, partial) + function_buffer.at("tool_name").content = partial[:name] if partial[:name].present? + function_buffer.at("tool_id").content = partial[:id] if partial[:id].present? + + if partial[:arguments] + argument_fragments = + partial[:arguments].reduce(+"") do |memo, (arg_name, value)| + memo << "\n<#{arg_name}>#{value}" + end + argument_fragments << "\n" + + function_buffer.at("parameters").children = + Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) + end + + function_buffer + end + + def buffering_finished?(available_functions, buffer) + tool_name = buffer.at("tool_name")&.text + return false if tool_name.blank? + + signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool] + + signature[:parameters].reduce(true) do |memo, param| + param_present = buffer.at(param[:name]).present? + next(memo) if param_present && !param[:required] + + memo && param_present + end + end end end end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index f2cbc0e2..16e1c96b 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -24,35 +24,26 @@ module DiscourseAi end def self.proxy(model_name) - dialects = [ - DiscourseAi::Completions::Dialects::Claude, - DiscourseAi::Completions::Dialects::Llama2Classic, - DiscourseAi::Completions::Dialects::ChatGpt, - DiscourseAi::Completions::Dialects::OrcaStyle, - DiscourseAi::Completions::Dialects::Gemini, - ] + dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name) - dialect = - dialects.detect(-> { raise UNKNOWN_MODEL }) { |d| d.can_translate?(model_name) }.new - - return new(dialect, @canned_response, model_name) if @canned_response + return new(dialect_klass, @canned_response, model_name) if @canned_response gateway = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new( model_name, - dialect.tokenizer, + dialect_klass.tokenizer, ) - new(dialect, gateway, model_name) + new(dialect_klass, gateway, model_name) end - def initialize(dialect, gateway, model_name) - @dialect = dialect + def initialize(dialect_klass, gateway, model_name) + @dialect_klass = dialect_klass @gateway = gateway @model_name = model_name end - delegate :tokenizer, to: :dialect + delegate :tokenizer, to: :dialect_klass # @param generic_prompt { Hash } - Prompt using our generic format. # We use the following keys from the hash: @@ -60,23 +51,64 @@ module DiscourseAi # - input: String containing user input # - examples (optional): Array of arrays with examples of input and responses. 
Each array is a input/response pair like [[example1, response1], [example2, response2]]. # - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt. + # - conversation_context (optional): Array of hashes to provide context about an ongoing conversation with the model. + # We translate the array in reverse order, meaning the first element would be the most recent message in the conversation. + # Example: + # + # [ + # { type: "user", name: "user1", content: "This is a new message by a user" }, + # { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + # { type: "tool", name: "tool_id", content: "I'm a tool result" }, + # ] + # + # - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example: + # + # { + # name: "get_weather", + # description: "Get the weather in a city", + # parameters: [ + # { name: "location", type: "string", description: "the city name", required: true }, + # { + # name: "unit", + # type: "string", + # description: "the unit of measurement celcius c or fahrenheit f", + # enum: %w[c f], + # required: true, + # }, + # ], + # } # # @param user { User } - User requesting the summary. # # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. # # @returns { String } - Completion result. + # + # When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool, + # even if you passed a partial_read_blk block. Invocations are strings that look like this: + # + # + # + # get_weather + # get_weather + # + # Sydney + # c + # + # + # + # def completion!(generic_prompt, user, &partial_read_blk) - prompt = dialect.translate(generic_prompt) - model_params = generic_prompt.dig(:params, model_name) || {} - gateway.perform_completion!(prompt, user, model_params, &partial_read_blk) + dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params) + + gateway.perform_completion!(dialect, user, model_params, &partial_read_blk) end private - attr_reader :dialect, :gateway, :model_name + attr_reader :dialect_klass, :gateway, :model_name end end end diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb index 27b56b49..2792dbf6 100644 --- a/spec/lib/completions/dialects/chat_gpt_spec.rb +++ b/spec/lib/completions/dialects/chat_gpt_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "gpt-4") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do TEXT post_insts: "Please put the translation between tags and separate each title with a comma.", + tools: [tool], } end @@ -35,7 +53,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do { role: "user", content: prompt[:input] }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to 
contain_exactly(*open_ai_version) end @@ -55,9 +73,51 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do { role: "user", content: prompt[:input] }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to contain_exactly(*open_ai_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context).to eq( + [ + { role: "tool", content: context.last[:content], tool_call_id: context.last[:name] }, + { role: "assistant", content: context.second[:content] }, + { role: "user", content: context.first[:content], name: context.first[:name] }, + ], + ) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1000 + + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.last[:content].length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "returns a list of available tools" do + open_ai_tool_f = { type: "function", tool: tool } + + expect(subject.tools).to contain_exactly(open_ai_tool_f) + end + end end diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb index d26dd570..1107814b 100644 --- a/spec/lib/completions/dialects/claude_spec.rb +++ b/spec/lib/completions/dialects/claude_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Claude do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "claude-2") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -37,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(anthropic_version) end @@ -60,9 +77,111 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate + + expect(translated).to eq(anthropic_version) + end + + it "include tools inside the prompt" do + prompt[:tools] = [tool] + + anthropic_version = <<~TEXT + Human: #{prompt[:insts]} + In this environment you have access to a set of tools you can use to answer the user's question. + You may call them like this. Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... 
+ + + + + Here are the tools available: + + + #{dialect.tools} + #{prompt[:input]} + #{prompt[:post_insts]} + Assistant: + TEXT + + translated = dialect.translate expect(translated).to eq(anthropic_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + expected = <<~TEXT + Assistant: + + + tool_id + + #{context.last[:content]} + + + + Assistant: #{context.second[:content]} + Human: #{context.first[:content]} + TEXT + + translated_context = dialect.conversation_context + + expect(translated_context).to eq(expected) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 10_000 + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "translates tools to the tool syntax" do + prompt[:tools] = [tool] + + translated_tool = <<~TEXT + + get_weather + Get the weather in a city + + + location + string + the city name + true + + + unit + string + the unit of measurement celcius c or fahrenheit f + true + c,f + + + + TEXT + + expect(dialect.tools).to eq(translated_tool) + end + end end diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb index 84aec2d3..0a5266b7 100644 --- a/spec/lib/completions/dialects/gemini_spec.rb +++ b/spec/lib/completions/dialects/gemini_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Gemini do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "gemini-pro") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do TEXT post_insts: "Please put the translation between tags and separate each title with a comma.", + tools: [tool], } end @@ -36,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do { role: "user", parts: { text: prompt[:input] } }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(gemini_version) end @@ -57,9 +75,79 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do { role: "user", parts: { text: prompt[:input] } }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to contain_exactly(*gemini_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + 
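+ # Tool results are expected to map to a functionResponse part; other messages become plain text parts.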
translated_context = dialect.conversation_context + + expect(translated_context).to eq( + [ + { + role: "model", + parts: [ + { + "functionResponse" => { + name: context.last[:name], + content: context.last[:content], + }, + }, + ], + }, + { role: "model", parts: [{ text: context.second[:content] }] }, + { role: "user", parts: [{ text: context.first[:content] }] }, + ], + ) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1000 + + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.last.dig(:parts, 0, :text).length).to be < + context.last[:content].length + end + end + + describe "#tools" do + it "returns a list of available tools" do + gemini_tools = { + function_declarations: [ + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name" }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + }, + ], + required: %w[location unit], + }, + ], + } + + expect(subject.tools).to contain_exactly(gemini_tools) + end + end end diff --git a/spec/lib/completions/dialects/llama2_classic_spec.rb b/spec/lib/completions/dialects/llama2_classic_spec.rb index 2b1d93a2..81d088e6 100644 --- a/spec/lib/completions/dialects/llama2_classic_spec.rb +++ b/spec/lib/completions/dialects/llama2_classic_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "Llama2-chat-hf") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -31,11 +48,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do describe "#translate" do it "translates a prompt written in our generic format to the Llama2 format" do llama2_classic_version = <<~TEXT - [INST]<>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<>[/INST] + [INST] + <> + #{prompt[:insts]} + #{prompt[:post_insts]} + <> + [/INST] [INST]#{prompt[:input]}[/INST] TEXT - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(llama2_classic_version) end @@ -49,15 +71,126 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do ] llama2_classic_version = <<~TEXT - [INST]<>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<>[/INST] + [INST] + <> + #{prompt[:insts]} + #{prompt[:post_insts]} + <> + [/INST] [INST]#{prompt[:examples][0][0]}[/INST] #{prompt[:examples][0][1]} [INST]#{prompt[:input]}[/INST] TEXT - translated = dialect.translate(prompt) + translated = dialect.translate + + expect(translated).to eq(llama2_classic_version) + end + + it "include tools inside the prompt" do + prompt[:tools] = [tool] + + llama2_classic_version = <<~TEXT + [INST] + <> + #{prompt[:insts]} + In this environment you have access to a set of tools you can use to answer the user's question. + You may call them like this. 
Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... + + + + + Here are the tools available: + + + #{dialect.tools} + #{prompt[:post_insts]} + <> + [/INST] + [INST]#{prompt[:input]}[/INST] + TEXT + + translated = dialect.translate expect(translated).to eq(llama2_classic_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + expected = <<~TEXT + [INST] + + + tool_id + + #{context.last[:content]} + + + + [/INST] + [INST]#{context.second[:content]}[/INST] + #{context.first[:content]} + TEXT + + translated_context = dialect.conversation_context + + expect(translated_context).to eq(expected) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1_000 + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "translates functions to the tool syntax" do + prompt[:tools] = [tool] + + translated_tool = <<~TEXT + + get_weather + Get the weather in a city + + + location + string + the city name + true + + + unit + string + the unit of measurement celcius c or fahrenheit f + true + c,f + + + + TEXT + + expect(dialect.tools).to eq(translated_tool) + end + end end diff --git a/spec/lib/completions/dialects/orca_style_spec.rb b/spec/lib/completions/dialects/orca_style_spec.rb index 411a84a8..d27dc9d3 100644 --- a/spec/lib/completions/dialects/orca_style_spec.rb +++ b/spec/lib/completions/dialects/orca_style_spec.rb @@ -1,44 +1,62 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "StableBeluga2") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end + + let(:prompt) do + { + insts: <<~TEXT, + I want you to act as a title generator for written pieces. I will provide you with a text, + and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words, + and ensure that the meaning is maintained. Replies will utilize the language type of the topic. + TEXT + input: <<~TEXT, + Here is the text, inside XML tags: + + To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends, + discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer + defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry. 
+ + Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires, + a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and + slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he + dies so that a scene may be repeated. + + TEXT + post_insts: + "Please put the translation between tags and separate each title with a comma.", + } + end describe "#translate" do - let(:prompt) do - { - insts: <<~TEXT, - I want you to act as a title generator for written pieces. I will provide you with a text, - and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words, - and ensure that the meaning is maintained. Replies will utilize the language type of the topic. - TEXT - input: <<~TEXT, - Here is the text, inside XML tags: - - To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends, - discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer - defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry. - - Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires, - a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and - slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he - dies so that a scene may be repeated. - - TEXT - post_insts: - "Please put the translation between tags and separate each title with a comma.", - } - end - it "translates a prompt written in our generic format to the Open AI format" do orca_style_version = <<~TEXT ### System: - #{[prompt[:insts], prompt[:post_insts]].join("\n")} + #{prompt[:insts]} + #{prompt[:post_insts]} ### User: #{prompt[:input]} ### Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(orca_style_version) end @@ -53,7 +71,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do orca_style_version = <<~TEXT ### System: - #{[prompt[:insts], prompt[:post_insts]].join("\n")} + #{prompt[:insts]} + #{prompt[:post_insts]} ### User: #{prompt[:examples][0][0]} ### Assistant: @@ -63,9 +82,113 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do ### Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate + + expect(translated).to eq(orca_style_version) + end + + it "include tools inside the prompt" do + prompt[:tools] = [tool] + + orca_style_version = <<~TEXT + ### System: + #{prompt[:insts]} + In this environment you have access to a set of tools you can use to answer the user's question. + You may call them like this. Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... 
+ + + + + Here are the tools available: + + + #{dialect.tools} + #{prompt[:post_insts]} + ### User: + #{prompt[:input]} + ### Assistant: + TEXT + + translated = dialect.translate expect(translated).to eq(orca_style_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + expected = <<~TEXT + ### Assistant: + + + tool_id + + #{context.last[:content]} + + + + ### Assistant: #{context.second[:content]} + ### User: #{context.first[:content]} + TEXT + + translated_context = dialect.conversation_context + + expect(translated_context).to eq(expected) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1_000 + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "translates tools to the tool syntax" do + prompt[:tools] = [tool] + + translated_tool = <<~TEXT + + get_weather + Get the weather in a city + + + location + string + the city name + true + + + unit + string + the unit of measurement celcius c or fahrenheit f + true + c,f + + + + TEXT + + expect(dialect.tools).to eq(translated_tool) + end + end end diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index d0309e2f..0a57ad29 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) } let(:model_name) { "claude-2" } - let(:prompt) { "Human: write 3 words\n\n" } + let(:generic_prompt) { { insts: "write 3 words" } } + let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) } + let(:prompt) { dialect.translate } let(:request_body) { model.default_options.merge(prompt: prompt).to_json } let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json } @@ -23,10 +25,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do } end - def stub_response(prompt, response_text) + def stub_response(prompt, response_text, tool_call: false) WebMock .stub_request(:post, "https://api.anthropic.com/v1/complete") - .with(body: model.default_options.merge(prompt: prompt).to_json) + .with(body: request_body) .to_return(status: 200, body: JSON.dump(response(response_text))) end @@ -42,7 +44,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do }.to_json end - def stub_streamed_response(prompt, deltas) + def stub_streamed_response(prompt, deltas, tool_call: false) chunks = deltas.each_with_index.map do |_, index| if index == (deltas.length - 1) @@ -52,13 +54,27 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do end end - chunks = chunks.join("\n\n") + chunks = chunks.join("\n\n").split("") WebMock .stub_request(:post, "https://api.anthropic.com/v1/complete") - .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json) + .with(body: stream_request_body) 
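+ # `chunks` was split into single characters above, so the stub streams the response one byte at a
+ # time and exercises the new buffering for extremely slow streams in Endpoints::Base.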
.to_return(status: 200, body: chunks) end + let(:tool_deltas) { [" + + get_weather + + Sydney + c + + + + REPLY + + let(:tool_call) { invocation } + it_behaves_like "an endpoint that can communicate with a completion service" end diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index 5c0cb8cc..2c866898 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -9,10 +9,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do let(:model_name) { "claude-2" } let(:bedrock_name) { "claude-v2" } - let(:prompt) { "Human: write 3 words\n\n" } + let(:generic_prompt) { { insts: "write 3 words" } } + let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) } + let(:prompt) { dialect.translate } let(:request_body) { model.default_options.merge(prompt: prompt).to_json } - let(:stream_request_body) { model.default_options.merge(prompt: prompt).to_json } + let(:stream_request_body) { request_body } before do SiteSetting.ai_bedrock_access_key_id = "123456" @@ -20,39 +22,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do SiteSetting.ai_bedrock_region = "us-east-1" end - # Copied from https://github.com/bblimke/webmock/issues/629 - # Workaround for stubbing a streamed response - before do - mocked_http = - Class.new(Net::HTTP) do - def request(*) - super do |response| - response.instance_eval do - def read_body(*, &block) - if block_given? - @body.each(&block) - else - super - end - end - end - - yield response if block_given? - - response - end - end - end - - @original_net_http = Net.send(:remove_const, :HTTP) - Net.send(:const_set, :HTTP, mocked_http) - end - - after do - Net.send(:remove_const, :HTTP) - Net.send(:const_set, :HTTP, @original_net_http) - end - def response(content) { completion: content, @@ -65,7 +34,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do } end - def stub_response(prompt, response_text) + def stub_response(prompt, response_text, tool_call: false) WebMock .stub_request( :post, @@ -102,7 +71,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do encoder.encode(message) end - def stub_streamed_response(prompt, deltas) + def stub_streamed_response(prompt, deltas, tool_call: false) chunks = deltas.each_with_index.map do |_, index| if index == (deltas.length - 1) @@ -121,5 +90,19 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do .to_return(status: 200, body: chunks) end + let(:tool_deltas) { [" + + get_weather + + Sydney + c + + + + REPLY + + let(:tool_call) { invocation } + it_behaves_like "an endpoint that can communicate with a completion service" end diff --git a/spec/lib/completions/endpoints/endpoint_examples.rb b/spec/lib/completions/endpoints/endpoint_examples.rb index 6ca86070..0de7ef73 100644 --- a/spec/lib/completions/endpoints/endpoint_examples.rb +++ b/spec/lib/completions/endpoints/endpoint_examples.rb @@ -1,70 +1,177 @@ # frozen_string_literal: true RSpec.shared_examples "an endpoint that can communicate with a completion service" do + # Copied from https://github.com/bblimke/webmock/issues/629 + # Workaround for stubbing a streamed response + before do + mocked_http = + Class.new(Net::HTTP) do + def request(*) + super do |response| + response.instance_eval do + def read_body(*, &block) + if block_given? + @body.each(&block) + else + super + end + end + end + + yield response if block_given? 
+
+            response
+          end
+        end
+      end
+
+    @original_net_http = Net.send(:remove_const, :HTTP)
+    Net.send(:const_set, :HTTP, mocked_http)
+  end
+
+  after do
+    Net.send(:remove_const, :HTTP)
+    Net.send(:const_set, :HTTP, @original_net_http)
+  end
+
   describe "#perform_completion!" do
     fab!(:user) { Fabricate(:user) }
 
-    let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
+    let(:tool) do
+      {
+        name: "get_weather",
+        description: "Get the weather in a city",
+        parameters: [
+          { name: "location", type: "string", description: "the city name", required: true },
+          {
+            name: "unit",
+            type: "string",
+            description: "the unit of measurement celcius c or fahrenheit f",
+            enum: %w[c f],
+            required: true,
+          },
+        ],
+      }
+    end
+
+    let(:invocation) { <<~TEXT }
+      <function_calls>
+      <invoke>
+      <tool_name>get_weather</tool_name>
+      <tool_id>get_weather</tool_id>
+      <parameters>
+      <location>Sydney</location>
+      <unit>c</unit>
+      </parameters>
+      </invoke>
+      </function_calls>
+    TEXT
 
     context "when using regular mode" do
-      before { stub_response(prompt, response_text) }
+      context "with simple prompts" do
+        let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
 
-      it "can complete a trivial prompt" do
-        completion_response = model.perform_completion!(prompt, user)
+        before { stub_response(prompt, response_text) }
 
-        expect(completion_response).to eq(response_text)
+        it "can complete a trivial prompt" do
+          completion_response = model.perform_completion!(dialect, user)
+
+          expect(completion_response).to eq(response_text)
+        end
+
+        it "creates an audit log for the request" do
+          model.perform_completion!(dialect, user)
+
+          expect(AiApiAuditLog.count).to eq(1)
+          log = AiApiAuditLog.first
+
+          response_body = response(response_text).to_json
+
+          expect(log.provider_id).to eq(model.provider_id)
+          expect(log.user_id).to eq(user.id)
+          expect(log.raw_request_payload).to eq(request_body)
+          expect(log.raw_response_payload).to eq(response_body)
+          expect(log.request_tokens).to eq(model.prompt_size(prompt))
+          expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
+        end
       end
 
-      it "creates an audit log for the request" do
-        model.perform_completion!(prompt, user)
+      context "with functions" do
+        let(:generic_prompt) do
+          {
+            insts: "You can tell me the weather",
+            input: "Return the weather in Sydney",
+            tools: [tool],
+          }
+        end
 
-        expect(AiApiAuditLog.count).to eq(1)
-        log = AiApiAuditLog.first
+        before { stub_response(prompt, tool_call, tool_call: true) }
 
-        response_body = response(response_text).to_json
+        it "returns a function invocation" do
+          completion_response = model.perform_completion!(dialect, user)
 
-        expect(log.provider_id).to eq(model.provider_id)
-        expect(log.user_id).to eq(user.id)
-        expect(log.raw_request_payload).to eq(request_body)
-        expect(log.raw_response_payload).to eq(response_body)
-        expect(log.request_tokens).to eq(model.prompt_size(prompt))
-        expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
+          expect(completion_response).to eq(invocation)
+        end
       end
     end
 
     context "when using stream mode" do
-      let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
+      context "with simple prompts" do
+        let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
 
-      before { stub_streamed_response(prompt, deltas) }
+        before { stub_streamed_response(prompt, deltas) }
 
-      it "can complete a trivial prompt" do
-        completion_response = +""
+        it "can complete a trivial prompt" do
+          completion_response = +""
 
-        model.perform_completion!(prompt, user) do |partial, cancel|
-          completion_response << partial
-          cancel.call if completion_response.split(" ").length == 2
+          model.perform_completion!(dialect, user) do |partial, cancel|
+            completion_response << partial
+            cancel.call if completion_response.split(" ").length == 2
+          end
+
+          expect(completion_response).to eq(deltas[0...-1].join)
+        end
 
-        expect(completion_response).to eq(deltas[0...-1].join)
+        it "creates an audit log and updates it on each read." do
+          completion_response = +""
+
+          model.perform_completion!(dialect, user) do |partial, cancel|
+            completion_response << partial
+            cancel.call if completion_response.split(" ").length == 2
+          end
+
+          expect(AiApiAuditLog.count).to eq(1)
+          log = AiApiAuditLog.first
+
+          expect(log.provider_id).to eq(model.provider_id)
+          expect(log.user_id).to eq(user.id)
+          expect(log.raw_request_payload).to eq(stream_request_body)
+          expect(log.raw_response_payload).to be_present
+          expect(log.request_tokens).to eq(model.prompt_size(prompt))
+          expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
+        end
       end
 
-      it "creates an audit log and updates is on each read." do
-        completion_response = +""
-
-        model.perform_completion!(prompt, user) do |partial, cancel|
-          completion_response << partial
-          cancel.call if completion_response.split(" ").length == 2
+      context "with functions" do
+        let(:generic_prompt) do
+          {
+            insts: "You can tell me the weather",
+            input: "Return the weather in Sydney",
+            tools: [tool],
+          }
         end
 
-        expect(AiApiAuditLog.count).to eq(1)
-        log = AiApiAuditLog.first
+        before { stub_streamed_response(prompt, tool_deltas, tool_call: true) }
 
-        expect(log.provider_id).to eq(model.provider_id)
-        expect(log.user_id).to eq(user.id)
-        expect(log.raw_request_payload).to eq(stream_request_body)
-        expect(log.raw_response_payload).to be_present
-        expect(log.request_tokens).to eq(model.prompt_size(prompt))
-        expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
+        it "waits for the invocation to finish before calling the partial" do
+          buffered_partial = ""
+
+          model.perform_completion!(dialect, user) do |partial, cancel|
+            buffered_partial = partial if partial.include?("<function_calls>")
+          end
+
+          expect(buffered_partial).to eq(invocation)
+        end
      end
     end
   end
 end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index a9431dde..df855e73 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -6,22 +6,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
   subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
 
   let(:model_name) { "gemini-pro" }
-  let(:prompt) do
+  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+  let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
+  let(:prompt) { dialect.translate }
+
+  let(:tool_payload) do
+    {
+      name: "get_weather",
+      description: "Get the weather in a city",
+      parameters: [
+        { name: "location", type: "string", description: "the city name" },
+        {
+          name: "unit",
+          type: "string",
+          description: "the unit of measurement celcius c or fahrenheit f",
+          enum: %w[c f],
+        },
+      ],
+      required: %w[location unit],
+    }
+  end
+
+  let(:request_body) do
+    model
+      .default_options
+      .merge(contents: prompt)
+      .tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
+      .to_json
+  end
+  let(:stream_request_body) do
+    model
+      .default_options
+      .merge(contents: prompt)
+      .tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
+      .to_json
+  end
+
+  let(:tool_deltas) do
     [
-      { role: "system", content: "You are a helpful bot." },
-      { role: "user", content: "Write 3 words" },
+      { "functionCall" => { name: "get_weather", args: {} } },
+      { "functionCall" => { name: "get_weather", args: { location: "" } } },
+      { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
     ]
   end
 
-  let(:request_body) { model.default_options.merge(contents: prompt).to_json }
-  let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
+  let(:tool_call) do
+    { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
+  end
 
-  def response(content)
+  def response(content, tool_call: false)
     {
       candidates: [
         {
           content: {
-            parts: [{ text: content }],
+            parts: [(tool_call ? content : { text: content })],
             role: "model",
           },
           finishReason: "STOP",
@@ -45,22 +83,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     }
   end
 
-  def stub_response(prompt, response_text)
+  def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(
         :post,
         "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
       )
-      .with(body: { contents: prompt })
-      .to_return(status: 200, body: JSON.dump(response(response_text)))
+      .with(body: request_body)
+      .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
   end
 
-  def stream_line(delta, finish_reason: nil)
+  def stream_line(delta, finish_reason: nil, tool_call: false)
     {
       candidates: [
         {
           content: {
-            parts: [{ text: delta }],
+            parts: [(tool_call ? delta : { text: delta })],
             role: "model",
           },
           finishReason: finish_reason,
@@ -76,24 +114,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     }.to_json
   end
 
-  def stub_streamed_response(prompt, deltas)
+  def stub_streamed_response(prompt, deltas, tool_call: false)
     chunks =
       deltas.each_with_index.map do |_, index|
         if index == (deltas.length - 1)
-          stream_line(deltas[index], finish_reason: "STOP")
+          stream_line(deltas[index], finish_reason: "STOP", tool_call: tool_call)
         else
-          stream_line(deltas[index])
+          stream_line(deltas[index], tool_call: tool_call)
        end
       end
 
-    chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
+    chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
 
     WebMock
       .stub_request(
         :post,
         "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
       )
-      .with(body: model.default_options.merge(contents: prompt).to_json)
+      .with(body: stream_request_body)
       .to_return(status: 200, body: chunks)
   end
 
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index ba29893a..087ca1fc 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -6,10 +6,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
   subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
 
   let(:model_name) { "Llama2-*-chat-hf" }
-  let(:prompt) { <<~TEXT }
-    [INST]<<SYS>>You are a helpful bot.<</SYS>>[/INST]
-    [INST]Write 3 words[/INST]
-  TEXT
+  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+  let(:dialect) do
+    DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
+  end
+  let(:prompt) { dialect.translate }
 
   let(:request_body) do
     model
@@ -18,7 +19,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
       .tap do |payload|
         payload[:parameters][:max_new_tokens] =
          (SiteSetting.ai_hugging_face_token_limit || 4_000) - model.prompt_size(prompt)
-        payload[:parameters][:return_full_text] = false
       end
       .to_json
   end
@@ -30,7 +30,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
         payload[:parameters][:max_new_tokens] =
           (SiteSetting.ai_hugging_face_token_limit || 4_000) - model.prompt_size(prompt)
         payload[:stream] = true
-        payload[:parameters][:return_full_text] = false
       end
       .to_json
   end
@@ -41,14 +40,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
     [{ generated_text: content }]
   end
 
-  def stub_response(prompt, response_text)
+  def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
       .with(body: request_body)
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
 
-  def stream_line(delta, finish_reason: nil)
+  def stream_line(delta, deltas, finish_reason: nil)
     +"data: " << {
       token: {
         id: 29_889,
@@ -56,22 +55,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
         logprob: -0.08319092,
         special: !!finish_reason,
       },
-      generated_text: finish_reason ? response_text : nil,
+      generated_text: finish_reason ? deltas.join : nil,
       details: nil,
     }.to_json
   end
 
-  def stub_streamed_response(prompt, deltas)
+  def stub_streamed_response(prompt, deltas, tool_call: false)
     chunks =
       deltas.each_with_index.map do |_, index|
         if index == (deltas.length - 1)
-          stream_line(deltas[index], finish_reason: true)
+          stream_line(deltas[index], deltas, finish_reason: true)
         else
-          stream_line(deltas[index])
+          stream_line(deltas[index], deltas)
         end
       end
 
-    chunks = chunks.join("\n\n")
+    chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
 
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
@@ -79,5 +78,29 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
       .to_return(status: 200, body: chunks)
   end
 
+  let(:tool_deltas) { ["<function_calls>", <<~REPLY, <<~REPLY] }
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+    <function_calls>
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+
+  let(:tool_call) { invocation }
+
   it_behaves_like "an endpoint that can communicate with a completion service"
 end
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index fa22d461..929c087e 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -6,17 +6,53 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
   subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
 
   let(:model_name) { "gpt-3.5-turbo" }
-  let(:prompt) do
+  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+  let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
+  let(:prompt) { dialect.translate }
+
+  let(:tool_deltas) do
     [
-      { role: "system", content: "You are a helpful bot."
}, - { role: "user", content: "Write 3 words" }, + { id: "get_weather", name: "get_weather", arguments: {} }, + { id: "get_weather", name: "get_weather", arguments: { location: "" } }, + { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }, ] end - let(:request_body) { model.default_options.merge(messages: prompt).to_json } - let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json } + let(:tool_call) do + { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } } + end + + let(:request_body) do + model + .default_options + .merge(messages: prompt) + .tap do |b| + b[:tools] = generic_prompt[:tools].map do |t| + { type: "function", tool: t } + end if generic_prompt[:tools] + end + .to_json + end + let(:stream_request_body) do + model + .default_options + .merge(messages: prompt, stream: true) + .tap do |b| + b[:tools] = generic_prompt[:tools].map do |t| + { type: "function", tool: t } + end if generic_prompt[:tools] + end + .to_json + end + + def response(content, tool_call: false) + message_content = + if tool_call + { tool_calls: [{ function: content }] } + else + { content: content } + end - def response(content) { id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", object: "chat.completion", @@ -28,45 +64,52 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do total_tokens: 499, }, choices: [ - { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 }, + { message: { role: "assistant" }.merge(message_content), finish_reason: "stop", index: 0 }, ], } end - def stub_response(prompt, response_text) + def stub_response(prompt, response_text, tool_call: false) WebMock .stub_request(:post, "https://api.openai.com/v1/chat/completions") - .with(body: { model: model_name, messages: prompt }) - .to_return(status: 200, body: JSON.dump(response(response_text))) + .with(body: request_body) + .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call))) end - def stream_line(delta, finish_reason: nil) + def stream_line(delta, finish_reason: nil, tool_call: false) + message_content = + if tool_call + { tool_calls: [{ function: delta }] } + else + { content: delta } + end + +"data: " << { id: "chatcmpl-#{SecureRandom.hex}", object: "chat.completion.chunk", created: 1_681_283_881, model: "gpt-3.5-turbo-0301", - choices: [{ delta: { content: delta } }], + choices: [{ delta: message_content }], finish_reason: finish_reason, index: 0, }.to_json end - def stub_streamed_response(prompt, deltas) + def stub_streamed_response(prompt, deltas, tool_call: false) chunks = deltas.each_with_index.map do |_, index| if index == (deltas.length - 1) - stream_line(deltas[index], finish_reason: "stop_sequence") + stream_line(deltas[index], finish_reason: "stop_sequence", tool_call: tool_call) else - stream_line(deltas[index]) + stream_line(deltas[index], tool_call: tool_call) end end - chunks = chunks.join("\n\n") + chunks = (chunks.join("\n\n") << "data: [DONE]").split("") WebMock .stub_request(:post, "https://api.openai.com/v1/chat/completions") - .with(body: model.default_options.merge(messages: prompt, stream: true).to_json) + .with(body: stream_request_body) .to_return(status: 200, body: chunks) end diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb index 66f53060..3a67ff22 100644 --- a/spec/lib/completions/llm_spec.rb +++ b/spec/lib/completions/llm_spec.rb @@ -3,7 +3,7 @@ RSpec.describe DiscourseAi::Completions::Llm do 
subject(:llm) do described_class.new( - DiscourseAi::Completions::Dialects::OrcaStyle.new, + DiscourseAi::Completions::Dialects::OrcaStyle, canned_response, "Upstage-Llama-2-*-instruct-v2", ) diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb index 78939334..8b115a27 100644 --- a/spec/requests/ai_helper/assistant_controller_spec.rb +++ b/spec/requests/ai_helper/assistant_controller_spec.rb @@ -103,7 +103,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do expect(response.status).to eq(200) expect(response.parsed_body["suggestions"].first).to eq(translated_text) expect(response.parsed_body["diff"]).to eq(expected_diff) - expect(spy.prompt.last[:content]).to eq(expected_input) + expect(spy.prompt.translate.last[:content]).to eq(expected_input) end end end