diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 82595225..dda3e8a2 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -3,31 +3,99 @@ module DiscourseAi module Completions module Dialects - class ChatGpt - def self.can_translate?(model_name) - %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name) + class ChatGpt < Dialect + class << self + def can_translate?(model_name) + %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::OpenAiTokenizer + end end - def translate(generic_prompt) + def translate open_ai_prompt = [ - { - role: "system", - content: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"), - }, + { role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") }, ] - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| open_ai_prompt << { role: "user", content: example_pair.first } open_ai_prompt << { role: "assistant", content: example_pair.second } end end - open_ai_prompt << { role: "user", content: generic_prompt[:input] } + open_ai_prompt.concat!(conversation_context) if prompt[:conversation_context] + + open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input] + + open_ai_prompt end - def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer + def tools + return if prompt[:tools].blank? + + prompt[:tools].map { |t| { type: "function", tool: t } } + end + + def conversation_context + return [] if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context.reverse.map do |context| + translated = context.slice(:content) + translated[:role] = context[:type] + + if context[:name] + if translated[:role] == "tool" + translated[:tool_call_id] = context[:name] + else + translated[:name] = context[:name] + end + end + + translated + end + end + + def max_prompt_tokens + # provide a buffer of 120 tokens - our function counting is not + # 100% accurate and getting numbers to align exactly is very hard + buffer = (opts[:max_tokens_to_sample] || 2500) + 50 + + if tools.present? 
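+            # Memoize the token count of the serialized tools; it is the same for every call on this dialect instance.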
+ # note this is about 100 tokens over, OpenAI have a more optimal representation + @function_size ||= self.class.tokenizer.size(tools.to_json.to_s) + buffer += @function_size + end + + model_max_tokens - buffer + end + + private + + def per_message_overhead + # open ai defines about 4 tokens per message of overhead + 4 + end + + def calculate_message_token(context) + self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) + end + + def model_max_tokens + case model_name + when "gpt-3.5-turbo", "gpt-3.5-turbo-16k" + 16_384 + when "gpt-4" + 8192 + when "gpt-4-32k" + 32_768 + else + 8192 + end end end end diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 07438985..f10c4c49 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -3,25 +3,65 @@ module DiscourseAi module Completions module Dialects - class Claude - def self.can_translate?(model_name) - %w[claude-instant-1 claude-2].include?(model_name) + class Claude < Dialect + class << self + def can_translate?(model_name) + %w[claude-instant-1 claude-2].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::AnthropicTokenizer + end end - def translate(generic_prompt) - claude_prompt = +"Human: #{generic_prompt[:insts]}\n" + def translate + claude_prompt = +"Human: #{prompt[:insts]}\n" - claude_prompt << build_examples(generic_prompt[:examples]) if generic_prompt[:examples] + claude_prompt << build_tools_prompt if prompt[:tools] - claude_prompt << "#{generic_prompt[:input]}\n" + claude_prompt << build_examples(prompt[:examples]) if prompt[:examples] - claude_prompt << "#{generic_prompt[:post_insts]}\n" if generic_prompt[:post_insts] + claude_prompt << conversation_context if prompt[:conversation_context] + + claude_prompt << "#{prompt[:input]}\n" + + claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts] claude_prompt << "Assistant:\n" end - def tokenizer - DiscourseAi::Tokenizer::AnthropicTokenizer + def max_prompt_tokens + 50_000 + end + + def conversation_context + return "" if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context + .reverse + .reduce(+"") do |memo, context| + memo << (context[:type] == "user" ? 
"Human:" : "Assistant:") + + if context[:type] == "tool" + memo << <<~TEXT + + + + #{context[:name]} + + #{context[:content]} + + + + TEXT + else + memo << " " << context[:content] << "\n" + end + + memo + end end private diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb new file mode 100644 index 00000000..5e6d6b97 --- /dev/null +++ b/lib/completions/dialects/dialect.rb @@ -0,0 +1,160 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class Dialect + class << self + def can_translate?(_model_name) + raise NotImplemented + end + + def dialect_for(model_name) + dialects = [ + DiscourseAi::Completions::Dialects::Claude, + DiscourseAi::Completions::Dialects::Llama2Classic, + DiscourseAi::Completions::Dialects::ChatGpt, + DiscourseAi::Completions::Dialects::OrcaStyle, + DiscourseAi::Completions::Dialects::Gemini, + ] + dialects.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |d| + d.can_translate?(model_name) + end + end + + def tokenizer + raise NotImplemented + end + end + + def initialize(generic_prompt, model_name, opts: {}) + @prompt = generic_prompt + @model_name = model_name + @opts = opts + end + + def translate + raise NotImplemented + end + + def tools + tools = +"" + + prompt[:tools].each do |function| + parameters = +"" + if function[:parameters].present? + function[:parameters].each do |parameter| + parameters << <<~PARAMETER + + #{parameter[:name]} + #{parameter[:type]} + #{parameter[:description]} + #{parameter[:required]} + PARAMETER + if parameter[:enum] + parameters << "#{parameter[:enum].join(",")}\n" + end + parameters << "\n" + end + end + + tools << <<~TOOLS + + #{function[:name]} + #{function[:description]} + + #{parameters} + + TOOLS + end + + tools + end + + def conversation_context + raise NotImplemented + end + + def max_prompt_tokens + raise NotImplemented + end + + private + + attr_reader :prompt, :model_name, :opts + + def trim_context(conversation_context) + prompt_limit = max_prompt_tokens + current_token_count = calculate_token_count_without_context + + conversation_context.reduce([]) do |memo, context| + break(memo) if current_token_count >= prompt_limit + + dupped_context = context.dup + + message_tokens = calculate_message_token(dupped_context) + + # Trimming content to make sure we respect token limit. + while dupped_context[:content].present? && + message_tokens + current_token_count + per_message_overhead > prompt_limit + dupped_context[:content] = dupped_context[:content][0..-100] || "" + message_tokens = calculate_message_token(dupped_context) + end + + next(memo) if dupped_context[:content].blank? + + current_token_count += message_tokens + per_message_overhead + + memo << dupped_context + end + end + + def calculate_token_count_without_context + tokenizer = self.class.tokenizer + + examples_count = + prompt[:examples].to_a.sum do |pair| + tokenizer.size(pair.join) + (per_message_overhead * 2) + end + input_count = tokenizer.size(prompt[:input].to_s) + per_message_overhead + + examples_count + input_count + + prompt + .except(:conversation_context, :tools, :examples, :input) + .sum { |_, v| tokenizer.size(v) + per_message_overhead } + end + + def per_message_overhead + 0 + end + + def calculate_message_token(context) + self.class.tokenizer.size(context[:content].to_s) + end + + def build_tools_prompt + return "" if prompt[:tools].blank? + + <<~TEXT + In this environment you have access to a set of tools you can use to answer the user's question. 
+ You may call them like this. Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... + + + + + Here are the tools available: + + + #{tools} + TEXT + end + end + end + end +end diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 49fa716e..b6b938bf 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -3,34 +3,91 @@ module DiscourseAi module Completions module Dialects - class Gemini - def self.can_translate?(model_name) - %w[gemini-pro].include?(model_name) + class Gemini < Dialect + class << self + def can_translate?(model_name) + %w[gemini-pro].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer + end end - def translate(generic_prompt) + def translate gemini_prompt = [ { role: "user", parts: { - text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"), + text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"), }, }, { role: "model", parts: { text: "Ok." } }, ] - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| gemini_prompt << { role: "user", parts: { text: example_pair.first } } gemini_prompt << { role: "model", parts: { text: example_pair.second } } end end - gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } } + gemini_prompt.concat!(conversation_context) if prompt[:conversation_context] + + gemini_prompt << { role: "user", parts: { text: prompt[:input] } } end - def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer + def tools + return if prompt[:tools].blank? + + translated_tools = + prompt[:tools].map do |t| + required_fields = [] + tool = t.dup + + tool[:parameters] = t[:parameters].map do |p| + required_fields << p[:name] if p[:required] + + p.except(:required) + end + + tool.merge(required: required_fields) + end + + [{ function_declarations: translated_tools }] + end + + def conversation_context + return [] if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context.reverse.map do |context| + translated = {} + translated[:role] = (context[:type] == "user" ? 
"user" : "model") + + part = {} + + if context[:type] == "tool" + part["functionResponse"] = { name: context[:name], content: context[:content] } + else + part[:text] = context[:content] + end + + translated[:parts] = [part] + + translated + end + end + + def max_prompt_tokens + 16_384 # 50% of model tokens + end + + protected + + def calculate_message_token(context) + self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) end end end diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb index 542e5f57..0470a61c 100644 --- a/lib/completions/dialects/llama2_classic.rb +++ b/lib/completions/dialects/llama2_classic.rb @@ -3,27 +3,72 @@ module DiscourseAi module Completions module Dialects - class Llama2Classic - def self.can_translate?(model_name) - %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name) + class Llama2Classic < Dialect + class << self + def can_translate?(model_name) + %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::Llama2Tokenizer + end end - def translate(generic_prompt) - llama2_prompt = - +"[INST]<>#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}<>[/INST]\n" + def translate + llama2_prompt = +<<~TEXT + [INST] + <> + #{prompt[:insts]} + #{build_tools_prompt}#{prompt[:post_insts]} + <> + [/INST] + TEXT - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| llama2_prompt << "[INST]#{example_pair.first}[/INST]\n" llama2_prompt << "#{example_pair.second}\n" end end - llama2_prompt << "[INST]#{generic_prompt[:input]}[/INST]\n" + llama2_prompt << conversation_context if prompt[:conversation_context].present? + + llama2_prompt << "[INST]#{prompt[:input]}[/INST]\n" end - def tokenizer - DiscourseAi::Tokenizer::Llama2Tokenizer + def conversation_context + return "" if prompt[:conversation_context].blank? 
+ + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context + .reverse + .reduce(+"") do |memo, context| + if context[:type] == "tool" + memo << <<~TEXT + [INST] + + + #{context[:name]} + + #{context[:content]} + + + + [/INST] + TEXT + elsif context[:type] == "assistant" + memo << "[INST]" << context[:content] << "[/INST]\n" + else + memo << context[:content] << "\n" + end + + memo + end + end + + def max_prompt_tokens + SiteSetting.ai_hugging_face_token_limit end end end diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb index 3aa11609..fd76f3b5 100644 --- a/lib/completions/dialects/orca_style.rb +++ b/lib/completions/dialects/orca_style.rb @@ -3,29 +3,68 @@ module DiscourseAi module Completions module Dialects - class OrcaStyle - def self.can_translate?(model_name) - %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name) + class OrcaStyle < Dialect + class << self + def can_translate?(model_name) + %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::Llama2Tokenizer + end end - def translate(generic_prompt) - orca_style_prompt = - +"### System:\n#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}\n" + def translate + orca_style_prompt = +<<~TEXT + ### System: + #{prompt[:insts]} + #{build_tools_prompt}#{prompt[:post_insts]} + TEXT - if generic_prompt[:examples] - generic_prompt[:examples].each do |example_pair| + if prompt[:examples] + prompt[:examples].each do |example_pair| orca_style_prompt << "### User:\n#{example_pair.first}\n" orca_style_prompt << "### Assistant:\n#{example_pair.second}\n" end end - orca_style_prompt << "### User:\n#{generic_prompt[:input]}\n" + orca_style_prompt << "### User:\n#{prompt[:input]}\n" orca_style_prompt << "### Assistant:\n" end - def tokenizer - DiscourseAi::Tokenizer::Llama2Tokenizer + def conversation_context + return "" if prompt[:conversation_context].blank? + + trimmed_context = trim_context(prompt[:conversation_context]) + + trimmed_context + .reverse + .reduce(+"") do |memo, context| + memo << (context[:type] == "user" ? 
"### User:" : "### Assistant:") + + if context[:type] == "tool" + memo << <<~TEXT + + + + #{context[:name]} + + #{context[:content]} + + + + TEXT + else + memo << " " << context[:content] << "\n" + end + + memo + end + end + + def max_prompt_tokens + SiteSetting.ai_hugging_face_token_limit end end end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index 5216d4e7..3846d7e4 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -22,7 +22,7 @@ module DiscourseAi @uri ||= URI("https://api.anthropic.com/v1/complete") end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, _dialect) default_options .merge(model_params) .merge(prompt: prompt) diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 0e8cf68f..902a375c 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -37,7 +37,7 @@ module DiscourseAi URI(api_url) end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, _dialect) default_options.merge(prompt: prompt).merge(model_params) end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 61846e23..ef92a162 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -30,9 +30,11 @@ module DiscourseAi @tokenizer = tokenizer end - def perform_completion!(prompt, user, model_params = {}) + def perform_completion!(dialect, user, model_params = {}) @streaming_mode = block_given? + prompt = dialect.translate + Net::HTTP.start( model_uri.host, model_uri.port, @@ -43,7 +45,10 @@ module DiscourseAi ) do |http| response_data = +"" response_raw = +"" - request_body = prepare_payload(prompt, model_params).to_json + + # Needed to response token calculations. Cannot rely on response_data due to function buffering. + partials_raw = +"" + request_body = prepare_payload(prompt, model_params, dialect).to_json request = prepare_request(request_body) @@ -66,6 +71,15 @@ module DiscourseAi if !@streaming_mode response_raw = response.read_body response_data = extract_completion_from(response_raw) + partials_raw = response_data.to_s + + if has_tool?("", response_data) + function_buffer = build_buffer # Nokogiri document + function_buffer = add_to_buffer(function_buffer, "", response_data) + + response_data = +function_buffer.at("function_calls").to_s + response_data << "\n" + end return response_data end @@ -75,6 +89,7 @@ module DiscourseAi cancel = lambda { cancelled = true } leftover = "" + function_buffer = build_buffer # Nokogiri document response.read_body do |chunk| if cancelled @@ -85,6 +100,12 @@ module DiscourseAi decoded_chunk = decode(chunk) response_raw << decoded_chunk + # Buffering for extremely slow streaming. + if (leftover + decoded_chunk).length < "data: [DONE]".length + leftover += decoded_chunk + next + end + partials_from(leftover + decoded_chunk).each do |raw_partial| next if cancelled next if raw_partial.blank? @@ -93,11 +114,27 @@ module DiscourseAi partial = extract_completion_from(raw_partial) next if partial.nil? 
leftover = "" - response_data << partial - yield partial, cancel if partial + if has_tool?(response_data, partial) + function_buffer = add_to_buffer(function_buffer, response_data, partial) + + if buffering_finished?(dialect.tools, function_buffer) + invocation = +function_buffer.at("function_calls").to_s + invocation << "\n" + + partials_raw << partial.to_s + response_data << invocation + + yield invocation, cancel + end + else + partials_raw << partial + response_data << partial + + yield partial, cancel if partial + end rescue JSON::ParserError - leftover = raw_partial + leftover += decoded_chunk end end end @@ -109,7 +146,7 @@ module DiscourseAi ensure if log log.raw_response_payload = response_raw - log.response_tokens = tokenizer.size(response_data) + log.response_tokens = tokenizer.size(partials_raw) log.save! if Rails.env.development? @@ -165,6 +202,40 @@ module DiscourseAi def extract_prompt_for_tokenizer(prompt) prompt end + + def build_buffer + Nokogiri::HTML5.fragment(<<~TEXT) + + + + + + + + TEXT + end + + def has_tool?(response, partial) + (response + partial).include?("") + end + + def add_to_buffer(function_buffer, response_data, partial) + new_buffer = Nokogiri::HTML5.fragment(response_data + partial) + if tool_name = new_buffer.at("tool_name").text + if new_buffer.at("tool_id").nil? + tool_id_node = + Nokogiri::HTML5::DocumentFragment.parse("\n#{tool_name}") + + new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node) + end + end + + new_buffer + end + + def buffering_finished?(_available_functions, buffer) + buffer.to_s.include?("") + end end end end diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index 83300fcc..56d2b913 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -31,9 +31,14 @@ module DiscourseAi cancelled = false cancel_fn = lambda { cancelled = true } - response.each_char do |char| - break if cancelled - yield(char, cancel_fn) + # We buffer and return tool invocations in one go. + if is_tool?(response) + yield(response, cancel_fn) + else + response.each_char do |char| + break if cancelled + yield(char, cancel_fn) + end end else response @@ -43,6 +48,12 @@ module DiscourseAi def tokenizer DiscourseAi::Tokenizer::OpenAiTokenizer end + + private + + def is_tool?(response) + Nokogiri::HTML5.fragment(response).at("function_calls").present? + end end end end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 7738085c..8fa8a2d3 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -25,8 +25,11 @@ module DiscourseAi URI(url) end - def prepare_payload(prompt, model_params) - default_options.merge(model_params).merge(contents: prompt) + def prepare_payload(prompt, model_params, dialect) + default_options + .merge(model_params) + .merge(contents: prompt) + .tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? } end def prepare_request(payload) @@ -36,25 +39,72 @@ module DiscourseAi end def extract_completion_from(response_raw) - if @streaming_mode - parsed = response_raw - else - parsed = JSON.parse(response_raw, symbolize_names: true) - end + parsed = JSON.parse(response_raw, symbolize_names: true) - completion = dig_text(parsed).to_s + response_h = parsed.dig(:candidates, 0, :content, :parts, 0) + + has_function_call = response_h.dig(:functionCall).present? + has_function_call ? 
response_h[:functionCall] : response_h.dig(:text) end def partials_from(decoded_chunk) - JSON.parse(decoded_chunk, symbolize_names: true) + decoded_chunk + .split("\n") + .map do |line| + if line == "," + nil + elsif line.starts_with?("[") + line[1..-1] + elsif line.ends_with?("]") + line[0..-1] + else + line + end + end + .compact_blank end def extract_prompt_for_tokenizer(prompt) prompt.to_s end - def dig_text(response) - response.dig(:candidates, 0, :content, :parts, 0, :text) + def has_tool?(_response_data, partial) + partial.is_a?(Hash) && partial.has_key?(:name) # Has function name + end + + def add_to_buffer(function_buffer, _response_data, partial) + if partial[:name].present? + function_buffer.at("tool_name").content = partial[:name] + function_buffer.at("tool_id").content = partial[:name] + end + + if partial[:args] + argument_fragments = + partial[:args].reduce(+"") do |memo, (arg_name, value)| + memo << "\n<#{arg_name}>#{value}" + end + argument_fragments << "\n" + + function_buffer.at("parameters").children = + Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) + end + + function_buffer + end + + def buffering_finished?(available_functions, buffer) + tool_name = buffer.at("tool_name")&.text + return false if tool_name.blank? + + signature = + available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name } + + signature[:parameters].reduce(true) do |memo, param| + param_present = buffer.at(param[:name]).present? + next(memo) if param_present || !signature[:required].include?(param[:name]) + + memo && param_present + end end end end diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb index d5d10cea..819cec47 100644 --- a/lib/completions/endpoints/hugging_face.rb +++ b/lib/completions/endpoints/hugging_face.rb @@ -11,7 +11,7 @@ module DiscourseAi end def default_options - { parameters: { repetition_penalty: 1.1, temperature: 0.7 } } + { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } } end def provider_id @@ -24,7 +24,7 @@ module DiscourseAi URI(SiteSetting.ai_hugging_face_api_url) end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, _dialect) default_options .merge(inputs: prompt) .tap do |payload| @@ -33,7 +33,6 @@ module DiscourseAi token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt) - payload[:parameters][:return_full_text] = false payload[:stream] = true if @streaming_mode end diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 65b01314..b0664cd0 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -37,11 +37,14 @@ module DiscourseAi URI(url) end - def prepare_payload(prompt, model_params) + def prepare_payload(prompt, model_params, dialect) default_options .merge(model_params) .merge(messages: prompt) - .tap { |payload| payload[:stream] = true if @streaming_mode } + .tap do |payload| + payload[:stream] = true if @streaming_mode + payload[:tools] = dialect.tools if dialect.tools.present? 
+ end end def prepare_request(payload) @@ -62,15 +65,12 @@ module DiscourseAi end def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - ( - if @streaming_mode - parsed.dig(:choices, 0, :delta, :content) - else - parsed.dig(:choices, 0, :message, :content) - end - ).to_s + response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + + has_function_call = response_h.dig(:tool_calls).present? + has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content) end def partials_from(decoded_chunk) @@ -86,6 +86,42 @@ module DiscourseAi def extract_prompt_for_tokenizer(prompt) prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") end + + def has_tool?(_response_data, partial) + partial.is_a?(Hash) && partial.has_key?(:name) # Has function name + end + + def add_to_buffer(function_buffer, _response_data, partial) + function_buffer.at("tool_name").content = partial[:name] if partial[:name].present? + function_buffer.at("tool_id").content = partial[:id] if partial[:id].present? + + if partial[:arguments] + argument_fragments = + partial[:arguments].reduce(+"") do |memo, (arg_name, value)| + memo << "\n<#{arg_name}>#{value}" + end + argument_fragments << "\n" + + function_buffer.at("parameters").children = + Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) + end + + function_buffer + end + + def buffering_finished?(available_functions, buffer) + tool_name = buffer.at("tool_name")&.text + return false if tool_name.blank? + + signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool] + + signature[:parameters].reduce(true) do |memo, param| + param_present = buffer.at(param[:name]).present? + next(memo) if param_present && !param[:required] + + memo && param_present + end + end end end end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index f2cbc0e2..16e1c96b 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -24,35 +24,26 @@ module DiscourseAi end def self.proxy(model_name) - dialects = [ - DiscourseAi::Completions::Dialects::Claude, - DiscourseAi::Completions::Dialects::Llama2Classic, - DiscourseAi::Completions::Dialects::ChatGpt, - DiscourseAi::Completions::Dialects::OrcaStyle, - DiscourseAi::Completions::Dialects::Gemini, - ] + dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name) - dialect = - dialects.detect(-> { raise UNKNOWN_MODEL }) { |d| d.can_translate?(model_name) }.new - - return new(dialect, @canned_response, model_name) if @canned_response + return new(dialect_klass, @canned_response, model_name) if @canned_response gateway = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new( model_name, - dialect.tokenizer, + dialect_klass.tokenizer, ) - new(dialect, gateway, model_name) + new(dialect_klass, gateway, model_name) end - def initialize(dialect, gateway, model_name) - @dialect = dialect + def initialize(dialect_klass, gateway, model_name) + @dialect_klass = dialect_klass @gateway = gateway @model_name = model_name end - delegate :tokenizer, to: :dialect + delegate :tokenizer, to: :dialect_klass # @param generic_prompt { Hash } - Prompt using our generic format. # We use the following keys from the hash: @@ -60,23 +51,64 @@ module DiscourseAi # - input: String containing user input # - examples (optional): Array of arrays with examples of input and responses. 
Each array is a input/response pair like [[example1, response1], [example2, response2]]. # - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt. + # - conversation_context (optional): Array of hashes to provide context about an ongoing conversation with the model. + # We translate the array in reverse order, meaning the first element would be the most recent message in the conversation. + # Example: + # + # [ + # { type: "user", name: "user1", content: "This is a new message by a user" }, + # { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + # { type: "tool", name: "tool_id", content: "I'm a tool result" }, + # ] + # + # - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example: + # + # { + # name: "get_weather", + # description: "Get the weather in a city", + # parameters: [ + # { name: "location", type: "string", description: "the city name", required: true }, + # { + # name: "unit", + # type: "string", + # description: "the unit of measurement celcius c or fahrenheit f", + # enum: %w[c f], + # required: true, + # }, + # ], + # } # # @param user { User } - User requesting the summary. # # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. # # @returns { String } - Completion result. + # + # When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool, + # even if you passed a partial_read_blk block. Invocations are strings that look like this: + # + # + # + # get_weather + # get_weather + # + # Sydney + # c + # + # + # + # def completion!(generic_prompt, user, &partial_read_blk) - prompt = dialect.translate(generic_prompt) - model_params = generic_prompt.dig(:params, model_name) || {} - gateway.perform_completion!(prompt, user, model_params, &partial_read_blk) + dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params) + + gateway.perform_completion!(dialect, user, model_params, &partial_read_blk) end private - attr_reader :dialect, :gateway, :model_name + attr_reader :dialect_klass, :gateway, :model_name end end end diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb index 27b56b49..2792dbf6 100644 --- a/spec/lib/completions/dialects/chat_gpt_spec.rb +++ b/spec/lib/completions/dialects/chat_gpt_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "gpt-4") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do TEXT post_insts: "Please put the translation between tags and separate each title with a comma.", + tools: [tool], } end @@ -35,7 +53,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do { role: "user", content: prompt[:input] }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to 
contain_exactly(*open_ai_version) end @@ -55,9 +73,51 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do { role: "user", content: prompt[:input] }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to contain_exactly(*open_ai_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context).to eq( + [ + { role: "tool", content: context.last[:content], tool_call_id: context.last[:name] }, + { role: "assistant", content: context.second[:content] }, + { role: "user", content: context.first[:content], name: context.first[:name] }, + ], + ) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1000 + + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.last[:content].length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "returns a list of available tools" do + open_ai_tool_f = { type: "function", tool: tool } + + expect(subject.tools).to contain_exactly(open_ai_tool_f) + end + end end diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb index d26dd570..1107814b 100644 --- a/spec/lib/completions/dialects/claude_spec.rb +++ b/spec/lib/completions/dialects/claude_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Claude do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "claude-2") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -37,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(anthropic_version) end @@ -60,9 +77,111 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate + + expect(translated).to eq(anthropic_version) + end + + it "include tools inside the prompt" do + prompt[:tools] = [tool] + + anthropic_version = <<~TEXT + Human: #{prompt[:insts]} + In this environment you have access to a set of tools you can use to answer the user's question. + You may call them like this. Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... 
+ + + + + Here are the tools available: + + + #{dialect.tools} + #{prompt[:input]} + #{prompt[:post_insts]} + Assistant: + TEXT + + translated = dialect.translate expect(translated).to eq(anthropic_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + expected = <<~TEXT + Assistant: + + + tool_id + + #{context.last[:content]} + + + + Assistant: #{context.second[:content]} + Human: #{context.first[:content]} + TEXT + + translated_context = dialect.conversation_context + + expect(translated_context).to eq(expected) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 10_000 + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "translates tools to the tool syntax" do + prompt[:tools] = [tool] + + translated_tool = <<~TEXT + + get_weather + Get the weather in a city + + + location + string + the city name + true + + + unit + string + the unit of measurement celcius c or fahrenheit f + true + c,f + + + + TEXT + + expect(dialect.tools).to eq(translated_tool) + end + end end diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb index 84aec2d3..0a5266b7 100644 --- a/spec/lib/completions/dialects/gemini_spec.rb +++ b/spec/lib/completions/dialects/gemini_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Gemini do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "gemini-pro") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do TEXT post_insts: "Please put the translation between tags and separate each title with a comma.", + tools: [tool], } end @@ -36,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do { role: "user", parts: { text: prompt[:input] } }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(gemini_version) end @@ -57,9 +75,79 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do { role: "user", parts: { text: prompt[:input] } }, ] - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to contain_exactly(*gemini_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + 
translated_context = dialect.conversation_context + + expect(translated_context).to eq( + [ + { + role: "model", + parts: [ + { + "functionResponse" => { + name: context.last[:name], + content: context.last[:content], + }, + }, + ], + }, + { role: "model", parts: [{ text: context.second[:content] }] }, + { role: "user", parts: [{ text: context.first[:content] }] }, + ], + ) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1000 + + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.last.dig(:parts, 0, :text).length).to be < + context.last[:content].length + end + end + + describe "#tools" do + it "returns a list of available tools" do + gemini_tools = { + function_declarations: [ + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name" }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + }, + ], + required: %w[location unit], + }, + ], + } + + expect(subject.tools).to contain_exactly(gemini_tools) + end + end end diff --git a/spec/lib/completions/dialects/llama2_classic_spec.rb b/spec/lib/completions/dialects/llama2_classic_spec.rb index 2b1d93a2..81d088e6 100644 --- a/spec/lib/completions/dialects/llama2_classic_spec.rb +++ b/spec/lib/completions/dialects/llama2_classic_spec.rb @@ -1,7 +1,24 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "Llama2-chat-hf") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end let(:prompt) do { @@ -31,11 +48,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do describe "#translate" do it "translates a prompt written in our generic format to the Llama2 format" do llama2_classic_version = <<~TEXT - [INST]<>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<>[/INST] + [INST] + <> + #{prompt[:insts]} + #{prompt[:post_insts]} + <> + [/INST] [INST]#{prompt[:input]}[/INST] TEXT - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(llama2_classic_version) end @@ -49,15 +71,126 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do ] llama2_classic_version = <<~TEXT - [INST]<>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<>[/INST] + [INST] + <> + #{prompt[:insts]} + #{prompt[:post_insts]} + <> + [/INST] [INST]#{prompt[:examples][0][0]}[/INST] #{prompt[:examples][0][1]} [INST]#{prompt[:input]}[/INST] TEXT - translated = dialect.translate(prompt) + translated = dialect.translate + + expect(translated).to eq(llama2_classic_version) + end + + it "include tools inside the prompt" do + prompt[:tools] = [tool] + + llama2_classic_version = <<~TEXT + [INST] + <> + #{prompt[:insts]} + In this environment you have access to a set of tools you can use to answer the user's question. + You may call them like this. 
Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... + + + + + Here are the tools available: + + + #{dialect.tools} + #{prompt[:post_insts]} + <> + [/INST] + [INST]#{prompt[:input]}[/INST] + TEXT + + translated = dialect.translate expect(translated).to eq(llama2_classic_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + expected = <<~TEXT + [INST] + + + tool_id + + #{context.last[:content]} + + + + [/INST] + [INST]#{context.second[:content]}[/INST] + #{context.first[:content]} + TEXT + + translated_context = dialect.conversation_context + + expect(translated_context).to eq(expected) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1_000 + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "translates functions to the tool syntax" do + prompt[:tools] = [tool] + + translated_tool = <<~TEXT + + get_weather + Get the weather in a city + + + location + string + the city name + true + + + unit + string + the unit of measurement celcius c or fahrenheit f + true + c,f + + + + TEXT + + expect(dialect.tools).to eq(translated_tool) + end + end end diff --git a/spec/lib/completions/dialects/orca_style_spec.rb b/spec/lib/completions/dialects/orca_style_spec.rb index 411a84a8..d27dc9d3 100644 --- a/spec/lib/completions/dialects/orca_style_spec.rb +++ b/spec/lib/completions/dialects/orca_style_spec.rb @@ -1,44 +1,62 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do - subject(:dialect) { described_class.new } + subject(:dialect) { described_class.new(prompt, "StableBeluga2") } + + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end + + let(:prompt) do + { + insts: <<~TEXT, + I want you to act as a title generator for written pieces. I will provide you with a text, + and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words, + and ensure that the meaning is maintained. Replies will utilize the language type of the topic. + TEXT + input: <<~TEXT, + Here is the text, inside XML tags: + + To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends, + discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer + defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry. 
+ + Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires, + a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and + slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he + dies so that a scene may be repeated. + + TEXT + post_insts: + "Please put the translation between tags and separate each title with a comma.", + } + end describe "#translate" do - let(:prompt) do - { - insts: <<~TEXT, - I want you to act as a title generator for written pieces. I will provide you with a text, - and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words, - and ensure that the meaning is maintained. Replies will utilize the language type of the topic. - TEXT - input: <<~TEXT, - Here is the text, inside XML tags: - - To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends, - discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer - defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry. - - Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires, - a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and - slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he - dies so that a scene may be repeated. - - TEXT - post_insts: - "Please put the translation between tags and separate each title with a comma.", - } - end - it "translates a prompt written in our generic format to the Open AI format" do orca_style_version = <<~TEXT ### System: - #{[prompt[:insts], prompt[:post_insts]].join("\n")} + #{prompt[:insts]} + #{prompt[:post_insts]} ### User: #{prompt[:input]} ### Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate expect(translated).to eq(orca_style_version) end @@ -53,7 +71,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do orca_style_version = <<~TEXT ### System: - #{[prompt[:insts], prompt[:post_insts]].join("\n")} + #{prompt[:insts]} + #{prompt[:post_insts]} ### User: #{prompt[:examples][0][0]} ### Assistant: @@ -63,9 +82,113 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do ### Assistant: TEXT - translated = dialect.translate(prompt) + translated = dialect.translate + + expect(translated).to eq(orca_style_version) + end + + it "include tools inside the prompt" do + prompt[:tools] = [tool] + + orca_style_version = <<~TEXT + ### System: + #{prompt[:insts]} + In this environment you have access to a set of tools you can use to answer the user's question. + You may call them like this. Only invoke one function at a time and wait for the results before invoking another function: + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... 
+ + + + + Here are the tools available: + + + #{dialect.tools} + #{prompt[:post_insts]} + ### User: + #{prompt[:input]} + ### Assistant: + TEXT + + translated = dialect.translate expect(translated).to eq(orca_style_version) end end + + describe "#conversation_context" do + let(:context) do + [ + { type: "user", name: "user1", content: "This is a new message by a user" }, + { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, + { type: "tool", name: "tool_id", content: "I'm a tool result" }, + ] + end + + it "adds conversation in reverse order (first == newer)" do + prompt[:conversation_context] = context + + expected = <<~TEXT + ### Assistant: + + + tool_id + + #{context.last[:content]} + + + + ### Assistant: #{context.second[:content]} + ### User: #{context.first[:content]} + TEXT + + translated_context = dialect.conversation_context + + expect(translated_context).to eq(expected) + end + + it "trims content if it's getting too long" do + context.last[:content] = context.last[:content] * 1_000 + prompt[:conversation_context] = context + + translated_context = dialect.conversation_context + + expect(translated_context.length).to be < context.last[:content].length + end + end + + describe "#tools" do + it "translates tools to the tool syntax" do + prompt[:tools] = [tool] + + translated_tool = <<~TEXT + + get_weather + Get the weather in a city + + + location + string + the city name + true + + + unit + string + the unit of measurement celcius c or fahrenheit f + true + c,f + + + + TEXT + + expect(dialect.tools).to eq(translated_tool) + end + end end diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index d0309e2f..0a57ad29 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) } let(:model_name) { "claude-2" } - let(:prompt) { "Human: write 3 words\n\n" } + let(:generic_prompt) { { insts: "write 3 words" } } + let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) } + let(:prompt) { dialect.translate } let(:request_body) { model.default_options.merge(prompt: prompt).to_json } let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json } @@ -23,10 +25,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do } end - def stub_response(prompt, response_text) + def stub_response(prompt, response_text, tool_call: false) WebMock .stub_request(:post, "https://api.anthropic.com/v1/complete") - .with(body: model.default_options.merge(prompt: prompt).to_json) + .with(body: request_body) .to_return(status: 200, body: JSON.dump(response(response_text))) end @@ -42,7 +44,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do }.to_json end - def stub_streamed_response(prompt, deltas) + def stub_streamed_response(prompt, deltas, tool_call: false) chunks = deltas.each_with_index.map do |_, index| if index == (deltas.length - 1) @@ -52,13 +54,27 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do end end - chunks = chunks.join("\n\n") + chunks = chunks.join("\n\n").split("") WebMock .stub_request(:post, "https://api.anthropic.com/v1/complete") - .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json) + .with(body: stream_request_body) 
.to_return(status: 200, body: chunks) end + let(:tool_deltas) { [" + + get_weather + + Sydney + c + + + + REPLY + + let(:tool_call) { invocation } + it_behaves_like "an endpoint that can communicate with a completion service" end diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index 5c0cb8cc..2c866898 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -9,10 +9,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do let(:model_name) { "claude-2" } let(:bedrock_name) { "claude-v2" } - let(:prompt) { "Human: write 3 words\n\n" } + let(:generic_prompt) { { insts: "write 3 words" } } + let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) } + let(:prompt) { dialect.translate } let(:request_body) { model.default_options.merge(prompt: prompt).to_json } - let(:stream_request_body) { model.default_options.merge(prompt: prompt).to_json } + let(:stream_request_body) { request_body } before do SiteSetting.ai_bedrock_access_key_id = "123456" @@ -20,39 +22,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do SiteSetting.ai_bedrock_region = "us-east-1" end - # Copied from https://github.com/bblimke/webmock/issues/629 - # Workaround for stubbing a streamed response - before do - mocked_http = - Class.new(Net::HTTP) do - def request(*) - super do |response| - response.instance_eval do - def read_body(*, &block) - if block_given? - @body.each(&block) - else - super - end - end - end - - yield response if block_given? - - response - end - end - end - - @original_net_http = Net.send(:remove_const, :HTTP) - Net.send(:const_set, :HTTP, mocked_http) - end - - after do - Net.send(:remove_const, :HTTP) - Net.send(:const_set, :HTTP, @original_net_http) - end - def response(content) { completion: content, @@ -65,7 +34,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do } end - def stub_response(prompt, response_text) + def stub_response(prompt, response_text, tool_call: false) WebMock .stub_request( :post, @@ -102,7 +71,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do encoder.encode(message) end - def stub_streamed_response(prompt, deltas) + def stub_streamed_response(prompt, deltas, tool_call: false) chunks = deltas.each_with_index.map do |_, index| if index == (deltas.length - 1) @@ -121,5 +90,19 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do .to_return(status: 200, body: chunks) end + let(:tool_deltas) { [" + + get_weather + + Sydney + c + + + + REPLY + + let(:tool_call) { invocation } + it_behaves_like "an endpoint that can communicate with a completion service" end diff --git a/spec/lib/completions/endpoints/endpoint_examples.rb b/spec/lib/completions/endpoints/endpoint_examples.rb index 6ca86070..0de7ef73 100644 --- a/spec/lib/completions/endpoints/endpoint_examples.rb +++ b/spec/lib/completions/endpoints/endpoint_examples.rb @@ -1,70 +1,177 @@ # frozen_string_literal: true RSpec.shared_examples "an endpoint that can communicate with a completion service" do + # Copied from https://github.com/bblimke/webmock/issues/629 + # Workaround for stubbing a streamed response + before do + mocked_http = + Class.new(Net::HTTP) do + def request(*) + super do |response| + response.instance_eval do + def read_body(*, &block) + if block_given? + @body.each(&block) + else + super + end + end + end + + yield response if block_given? 
+ + response + end + end + end + + @original_net_http = Net.send(:remove_const, :HTTP) + Net.send(:const_set, :HTTP, mocked_http) + end + + after do + Net.send(:remove_const, :HTTP) + Net.send(:const_set, :HTTP, @original_net_http) + end + describe "#perform_completion!" do fab!(:user) { Fabricate(:user) } - let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" } + let(:tool) do + { + name: "get_weather", + description: "Get the weather in a city", + parameters: [ + { name: "location", type: "string", description: "the city name", required: true }, + { + name: "unit", + type: "string", + description: "the unit of measurement celcius c or fahrenheit f", + enum: %w[c f], + required: true, + }, + ], + } + end + + let(:invocation) { <<~TEXT } + + + get_weather + get_weather + + Sydney + c + + + + TEXT context "when using regular mode" do - before { stub_response(prompt, response_text) } + context "with simple prompts" do + let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" } - it "can complete a trivial prompt" do - completion_response = model.perform_completion!(prompt, user) + before { stub_response(prompt, response_text) } - expect(completion_response).to eq(response_text) + it "can complete a trivial prompt" do + completion_response = model.perform_completion!(dialect, user) + + expect(completion_response).to eq(response_text) + end + + it "creates an audit log for the request" do + model.perform_completion!(dialect, user) + + expect(AiApiAuditLog.count).to eq(1) + log = AiApiAuditLog.first + + response_body = response(response_text).to_json + + expect(log.provider_id).to eq(model.provider_id) + expect(log.user_id).to eq(user.id) + expect(log.raw_request_payload).to eq(request_body) + expect(log.raw_response_payload).to eq(response_body) + expect(log.request_tokens).to eq(model.prompt_size(prompt)) + expect(log.response_tokens).to eq(model.tokenizer.size(response_text)) + end end - it "creates an audit log for the request" do - model.perform_completion!(prompt, user) + context "with functions" do + let(:generic_prompt) do + { + insts: "You can tell me the weather", + input: "Return the weather in Sydney", + tools: [tool], + } + end - expect(AiApiAuditLog.count).to eq(1) - log = AiApiAuditLog.first + before { stub_response(prompt, tool_call, tool_call: true) } - response_body = response(response_text).to_json + it "returns a function invocation" do + completion_response = model.perform_completion!(dialect, user) - expect(log.provider_id).to eq(model.provider_id) - expect(log.user_id).to eq(user.id) - expect(log.raw_request_payload).to eq(request_body) - expect(log.raw_response_payload).to eq(response_body) - expect(log.request_tokens).to eq(model.prompt_size(prompt)) - expect(log.response_tokens).to eq(model.tokenizer.size(response_text)) + expect(completion_response).to eq(invocation) + end end end context "when using stream mode" do - let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] } + context "with simple prompts" do + let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] } - before { stub_streamed_response(prompt, deltas) } + before { stub_streamed_response(prompt, deltas) } - it "can complete a trivial prompt" do - completion_response = +"" + it "can complete a trivial prompt" do + completion_response = +"" - model.perform_completion!(prompt, user) do |partial, cancel| - completion_response << partial - cancel.call if completion_response.split(" ").length == 2 + model.perform_completion!(dialect, user) do |partial, cancel| + completion_response << partial 
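+            # Cancel the stream after two words so the spec can assert on a partially-read response.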
+            cancel.call if completion_response.split(" ").length == 2
+          end
+
+          expect(completion_response).to eq(deltas[0...-1].join)
         end
 
-        expect(completion_response).to eq(deltas[0...-1].join)
+        it "creates an audit log and updates it on each read." do
+          completion_response = +""
+
+          model.perform_completion!(dialect, user) do |partial, cancel|
+            completion_response << partial
+            cancel.call if completion_response.split(" ").length == 2
+          end
+
+          expect(AiApiAuditLog.count).to eq(1)
+          log = AiApiAuditLog.first
+
+          expect(log.provider_id).to eq(model.provider_id)
+          expect(log.user_id).to eq(user.id)
+          expect(log.raw_request_payload).to eq(stream_request_body)
+          expect(log.raw_response_payload).to be_present
+          expect(log.request_tokens).to eq(model.prompt_size(prompt))
+          expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
+        end
       end
 
-      it "creates an audit log and updates it on each read." do
-        completion_response = +""
-
-        model.perform_completion!(prompt, user) do |partial, cancel|
-          completion_response << partial
-          cancel.call if completion_response.split(" ").length == 2
+      context "with functions" do
+        let(:generic_prompt) do
+          {
+            insts: "You can tell me the weather",
+            input: "Return the weather in Sydney",
+            tools: [tool],
+          }
         end
 
-        expect(AiApiAuditLog.count).to eq(1)
-        log = AiApiAuditLog.first
+        before { stub_streamed_response(prompt, tool_deltas, tool_call: true) }
 
-        expect(log.provider_id).to eq(model.provider_id)
-        expect(log.user_id).to eq(user.id)
-        expect(log.raw_request_payload).to eq(stream_request_body)
-        expect(log.raw_response_payload).to be_present
-        expect(log.request_tokens).to eq(model.prompt_size(prompt))
-        expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
+        it "waits for the invocation to finish before calling the partial" do
+          buffered_partial = ""
+
+          model.perform_completion!(dialect, user) do |partial, cancel|
+            buffered_partial = partial if partial.include?("</function_calls>")
+          end
+
+          expect(buffered_partial).to eq(invocation)
+        end
+      end
     end
   end
 end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index a9431dde..df855e73 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -6,22 +6,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
   subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
 
   let(:model_name) { "gemini-pro" }
-  let(:prompt) do
+  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+  let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
+  let(:prompt) { dialect.translate }
+
+  let(:tool_payload) do
+    {
+      name: "get_weather",
+      description: "Get the weather in a city",
+      parameters: [
+        { name: "location", type: "string", description: "the city name" },
+        {
+          name: "unit",
+          type: "string",
+          description: "the unit of measurement celsius c or fahrenheit f",
+          enum: %w[c f],
+        },
+      ],
+      required: %w[location unit],
+    }
+  end
+
+  let(:request_body) do
+    model
+      .default_options
+      .merge(contents: prompt)
+      .tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
+      .to_json
+  end
+  let(:stream_request_body) do
+    model
+      .default_options
+      .merge(contents: prompt)
+      .tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
+      .to_json
+  end
+
+  let(:tool_deltas) do
     [
-      { role: "system", content: "You are a helpful bot." },
-      { role: "user", content: "Write 3 words" },
+      { "functionCall" => { name: "get_weather", args: {} } },
+      { "functionCall" => { name: "get_weather", args: { location: "" } } },
+      { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
     ]
   end
 
-  let(:request_body) { model.default_options.merge(contents: prompt).to_json }
-  let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
+  let(:tool_call) do
+    { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
+  end
 
-  def response(content)
+  def response(content, tool_call: false)
     {
       candidates: [
         {
           content: {
-            parts: [{ text: content }],
+            parts: [(tool_call ? content : { text: content })],
             role: "model",
           },
           finishReason: "STOP",
@@ -45,22 +83,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     }
   end
 
-  def stub_response(prompt, response_text)
+  def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(
         :post,
         "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
       )
-      .with(body: { contents: prompt })
-      .to_return(status: 200, body: JSON.dump(response(response_text)))
+      .with(body: request_body)
+      .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
   end
 
-  def stream_line(delta, finish_reason: nil)
+  def stream_line(delta, finish_reason: nil, tool_call: false)
     {
       candidates: [
         {
           content: {
-            parts: [{ text: delta }],
+            parts: [(tool_call ? delta : { text: delta })],
             role: "model",
           },
           finishReason: finish_reason,
@@ -76,24 +114,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
     }.to_json
   end
 
-  def stub_streamed_response(prompt, deltas)
+  def stub_streamed_response(prompt, deltas, tool_call: false)
     chunks =
       deltas.each_with_index.map do |_, index|
         if index == (deltas.length - 1)
-          stream_line(deltas[index], finish_reason: "STOP")
+          stream_line(deltas[index], finish_reason: "STOP", tool_call: tool_call)
         else
-          stream_line(deltas[index])
+          stream_line(deltas[index], tool_call: tool_call)
         end
       end
 
-    chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
+    chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
 
     WebMock
       .stub_request(
         :post,
         "https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
      )
-      .with(body: model.default_options.merge(contents: prompt).to_json)
+      .with(body: stream_request_body)
      .to_return(status: 200, body: chunks)
   end
 
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index ba29893a..087ca1fc 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -6,10 +6,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
   subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
 
   let(:model_name) { "Llama2-*-chat-hf" }
-  let(:prompt) { <<~TEXT }
-    [INST]<<SYS>>You are a helpful bot.<</SYS>>[/INST]
-    [INST]Write 3 words[/INST]
-  TEXT
+  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+  let(:dialect) do
+    DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
+  end
+  let(:prompt) { dialect.translate }
 
   let(:request_body) do
     model
@@ -18,7 +19,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
       .tap do |payload|
         payload[:parameters][:max_new_tokens] =
           (SiteSetting.ai_hugging_face_token_limit || 4_000) - model.prompt_size(prompt)
-        payload[:parameters][:return_full_text] = false
       end
       .to_json
   end
@@ -30,7 +30,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
         payload[:parameters][:max_new_tokens] =
           (SiteSetting.ai_hugging_face_token_limit || 4_000) - model.prompt_size(prompt)
         payload[:stream] = true
-        payload[:parameters][:return_full_text] = false
       end
       .to_json
   end
@@ -41,14 +40,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
     [{ generated_text: content }]
   end
 
-  def stub_response(prompt, response_text)
+  def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
       .with(body: request_body)
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end
 
-  def stream_line(delta, finish_reason: nil)
+  def stream_line(delta, deltas, finish_reason: nil)
     +"data: " << {
       token: {
         id: 29_889,
@@ -56,22 +55,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
         logprob: -0.08319092,
         special: !!finish_reason,
       },
-      generated_text: finish_reason ? response_text : nil,
+      generated_text: finish_reason ? deltas.join : nil,
       details: nil,
     }.to_json
   end
 
-  def stub_streamed_response(prompt, deltas)
+  def stub_streamed_response(prompt, deltas, tool_call: false)
     chunks =
       deltas.each_with_index.map do |_, index|
         if index == (deltas.length - 1)
-          stream_line(deltas[index], finish_reason: true)
+          stream_line(deltas[index], deltas, finish_reason: true)
         else
-          stream_line(deltas[index])
+          stream_line(deltas[index], deltas)
         end
       end
 
-    chunks = chunks.join("\n\n")
+    chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
 
     WebMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
@@ -79,5 +78,29 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
       .to_return(status: 200, body: chunks)
   end
 
+  let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
+    _calls>
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+    <function_calls>
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+
+  let(:tool_call) { invocation }
+
   it_behaves_like "an endpoint that can communicate with a completion service"
 end
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index fa22d461..929c087e 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -6,17 +6,53 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
   subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
 
   let(:model_name) { "gpt-3.5-turbo" }
-  let(:prompt) do
+  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+  let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
+  let(:prompt) { dialect.translate }
+
+  let(:tool_deltas) do
     [
-      { role: "system", content: "You are a helpful bot."
}, - { role: "user", content: "Write 3 words" }, + { id: "get_weather", name: "get_weather", arguments: {} }, + { id: "get_weather", name: "get_weather", arguments: { location: "" } }, + { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }, ] end - let(:request_body) { model.default_options.merge(messages: prompt).to_json } - let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json } + let(:tool_call) do + { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } } + end + + let(:request_body) do + model + .default_options + .merge(messages: prompt) + .tap do |b| + b[:tools] = generic_prompt[:tools].map do |t| + { type: "function", tool: t } + end if generic_prompt[:tools] + end + .to_json + end + let(:stream_request_body) do + model + .default_options + .merge(messages: prompt, stream: true) + .tap do |b| + b[:tools] = generic_prompt[:tools].map do |t| + { type: "function", tool: t } + end if generic_prompt[:tools] + end + .to_json + end + + def response(content, tool_call: false) + message_content = + if tool_call + { tool_calls: [{ function: content }] } + else + { content: content } + end - def response(content) { id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", object: "chat.completion", @@ -28,45 +64,52 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do total_tokens: 499, }, choices: [ - { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 }, + { message: { role: "assistant" }.merge(message_content), finish_reason: "stop", index: 0 }, ], } end - def stub_response(prompt, response_text) + def stub_response(prompt, response_text, tool_call: false) WebMock .stub_request(:post, "https://api.openai.com/v1/chat/completions") - .with(body: { model: model_name, messages: prompt }) - .to_return(status: 200, body: JSON.dump(response(response_text))) + .with(body: request_body) + .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call))) end - def stream_line(delta, finish_reason: nil) + def stream_line(delta, finish_reason: nil, tool_call: false) + message_content = + if tool_call + { tool_calls: [{ function: delta }] } + else + { content: delta } + end + +"data: " << { id: "chatcmpl-#{SecureRandom.hex}", object: "chat.completion.chunk", created: 1_681_283_881, model: "gpt-3.5-turbo-0301", - choices: [{ delta: { content: delta } }], + choices: [{ delta: message_content }], finish_reason: finish_reason, index: 0, }.to_json end - def stub_streamed_response(prompt, deltas) + def stub_streamed_response(prompt, deltas, tool_call: false) chunks = deltas.each_with_index.map do |_, index| if index == (deltas.length - 1) - stream_line(deltas[index], finish_reason: "stop_sequence") + stream_line(deltas[index], finish_reason: "stop_sequence", tool_call: tool_call) else - stream_line(deltas[index]) + stream_line(deltas[index], tool_call: tool_call) end end - chunks = chunks.join("\n\n") + chunks = (chunks.join("\n\n") << "data: [DONE]").split("") WebMock .stub_request(:post, "https://api.openai.com/v1/chat/completions") - .with(body: model.default_options.merge(messages: prompt, stream: true).to_json) + .with(body: stream_request_body) .to_return(status: 200, body: chunks) end diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb index 66f53060..3a67ff22 100644 --- a/spec/lib/completions/llm_spec.rb +++ b/spec/lib/completions/llm_spec.rb @@ -3,7 +3,7 @@ RSpec.describe DiscourseAi::Completions::Llm do 
subject(:llm) do described_class.new( - DiscourseAi::Completions::Dialects::OrcaStyle.new, + DiscourseAi::Completions::Dialects::OrcaStyle, canned_response, "Upstage-Llama-2-*-instruct-v2", ) diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb index 78939334..8b115a27 100644 --- a/spec/requests/ai_helper/assistant_controller_spec.rb +++ b/spec/requests/ai_helper/assistant_controller_spec.rb @@ -103,7 +103,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do expect(response.status).to eq(200) expect(response.parsed_body["suggestions"].first).to eq(translated_text) expect(response.parsed_body["diff"]).to eq(expected_diff) - expect(spy.prompt.last[:content]).to eq(expected_input) + expect(spy.prompt.translate.last[:content]).to eq(expected_input) end end end
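For orientation, a minimal sketch of the calling convention these specs exercise, built only from the class names and signatures visible in the hunks above; `current_user` is a stand-in for any Discourse user record and is not part of the patch.

# Illustrative only: wrap a generic prompt in a dialect, then hand it to an endpoint,
# mirroring the setup the endpoint specs use.
generic_prompt = { insts: "You are a helpful bot.", input: "write 3 words" }

dialect = DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, "gpt-3.5-turbo")

endpoint =
  DiscourseAi::Completions::Endpoints::OpenAi.new(
    "gpt-3.5-turbo",
    DiscourseAi::Tokenizer::OpenAiTokenizer,
  )

# Blocking mode returns the whole completion (or an XML <function_calls>
# invocation when tools are supplied and the model decides to call one).
result = endpoint.perform_completion!(dialect, current_user)

# Streaming mode yields partial results plus a cancel proc.
endpoint.perform_completion!(dialect, current_user) do |partial, cancel|
  print partial
  cancel.call if partial.length > 100
end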