From 4f1a3effe08aa0ea234b8185144dcb0d703da2e6 Mon Sep 17 00:00:00 2001
From: Roman Rizzi
Date: Tue, 7 May 2024 10:02:16 -0300
Subject: [PATCH] REFACTOR: Migrate Vllm/TGI-served models to the OpenAI
 format. (#588)

Both endpoints provide OpenAI-compatible servers. The only difference is that
Vllm doesn't support passing tools as a separate parameter. Even when the tool
param is supported, it ultimately relies on the model's ability to handle
native functions, which is not the case with the models we have today.

As part of this change, we are dropping support for StableBeluga/Llama2
models. They don't have a chat_template, meaning the new API can't translate
them.

These changes let us remove some of our existing dialects and are a first step
in our plan to support any LLM by defining them as data-driven concepts.

I rewrote the "translate" method to use a template method and extracted the
tool support strategies into their own classes to simplify the code.
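In practice the split looks roughly like this (an illustrative sketch only --
the names mirror the classes in the diff below, but the real implementations
also handle trimming, tools, user ids, and images):

    class Dialect
      def initialize(messages)
        @messages = messages
      end

      # Template method: the base class owns the shared pipeline and
      # subclasses only supply the per-type message builders.
      def translate
        @messages.map { |msg| send("#{msg[:type]}_msg", msg) }.compact
      end
    end

    class ChatGpt < Dialect
      def system_msg(msg)
        { role: "system", content: msg[:content] }
      end

      def user_msg(msg)
        { role: "user", content: msg[:content] }
      end

      def model_msg(msg)
        { role: "assistant", content: msg[:content] }
      end
    end

    ChatGpt.new([{ type: :user, content: "hi" }]).translate
    # => [{ role: "user", content: "hi" }]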
Finally, these changes bring support for Ollama when running in dev mode.
It only works with Mistral for now, but that will change soon.
---
 app/models/ai_api_audit_log.rb                |   1 +
 config/settings.yml                           |   3 +
 lib/ai_bot/bot.rb                             |  13 +-
 lib/automation.rb                             |   4 +-
 lib/completions/dialects/chat_gpt.rb          | 113 +++++-------
 lib/completions/dialects/claude.rb            |  57 +++----
 lib/completions/dialects/command.rb           |  89 +++++------
 lib/completions/dialects/dialect.rb           | 149 +++++-------------
 lib/completions/dialects/fake.rb              |   8 +-
 lib/completions/dialects/gemini.rb            |  99 ++++++------
 lib/completions/dialects/llama2_classic.rb    |  68 --------
 lib/completions/dialects/mistral.rb           |  57 +++++++
 lib/completions/dialects/mixtral.rb           |  57 -------
 lib/completions/dialects/open_ai_tools.rb     |  62 ++++++++
 lib/completions/dialects/orca_style.rb        |  59 -------
 lib/completions/dialects/xml_tools.rb         | 125 +++++++++++++++
 lib/completions/endpoints/anthropic.rb        |   2 +-
 lib/completions/endpoints/aws_bedrock.rb      |   2 +-
 lib/completions/endpoints/base.rb             |  10 +-
 lib/completions/endpoints/hugging_face.rb     |  48 +++---
 lib/completions/endpoints/ollama.rb           |  89 +++++++++++
 lib/completions/endpoints/open_ai.rb          |   4 -
 lib/completions/endpoints/vllm.rb             |  36 ++---
 lib/completions/llm.rb                        |  22 +--
 lib/summarization/entry_point.rb              |   8 -
 spec/lib/completions/dialects/dialect_spec.rb |  41 -----
 .../dialects/llama2_classic_spec.rb           |  62 --------
 spec/lib/completions/dialects/mixtral_spec.rb |  66 --------
 .../completions/dialects/orca_style_spec.rb   |  71 ---------
 .../endpoints/hugging_face_spec.rb            |  57 ++++---
 spec/lib/completions/endpoints/vllm_spec.rb   |  24 +--
 spec/lib/completions/llm_spec.rb              |   4 +-
 32 files changed, 665 insertions(+), 845 deletions(-)
 delete mode 100644 lib/completions/dialects/llama2_classic.rb
 create mode 100644 lib/completions/dialects/mistral.rb
 delete mode 100644 lib/completions/dialects/mixtral.rb
 create mode 100644 lib/completions/dialects/open_ai_tools.rb
 delete mode 100644 lib/completions/dialects/orca_style.rb
 create mode 100644 lib/completions/dialects/xml_tools.rb
 create mode 100644 lib/completions/endpoints/ollama.rb
 delete mode 100644 spec/lib/completions/dialects/llama2_classic_spec.rb
 delete mode 100644 spec/lib/completions/dialects/mixtral_spec.rb
 delete mode 100644 spec/lib/completions/dialects/orca_style_spec.rb

diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 7001bc08..f426925a 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -11,6 +11,7 @@ class AiApiAuditLog < ActiveRecord::Base
     Gemini = 4
     Vllm = 5
     Cohere = 6
+    Ollama = 7
   end
 end

diff --git a/config/settings.yml b/config/settings.yml
index 0dbcf624..cff41ce5 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -185,6 +185,9 @@ discourse_ai:
   ai_strict_token_counting:
     default: false
     hidden: true
+  ai_ollama_endpoint:
+    hidden: true
+    default: ""

   composer_ai_helper_enabled:
     default: false

diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index d935b336..3e3c932e 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -164,12 +164,15 @@ module DiscourseAi
       when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
         "open_ai:gpt-3.5-turbo-16k"
       when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
-        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
-             "mistralai/Mixtral-8x7B-Instruct-v0.1",
-           )
-          "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
+        mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
+          "vllm:#{mixtral_model}"
+        elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
+              mixtral_model,
+            )
+          "hugging_face:#{mixtral_model}"
         else
-          "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
+          "ollama:mistral"
         end
       when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
         "google:gemini-pro"

diff --git a/lib/automation.rb b/lib/automation.rb
index d1604fa6..b755f1db 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -40,8 +40,10 @@ module DiscourseAi
       if model.start_with?("mistral")
         if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model)
           return "vllm:#{model}"
+        elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(model)
+          return "hugging_face:#{model}"
         else
-          return "hugging_face:#{model}"
+          return "ollama:mistral"
         end
       end

diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 7368deff..9383019c 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -6,14 +6,7 @@ module DiscourseAi
       class ChatGpt < Dialect
         class << self
           def can_translate?(model_name)
-            %w[
-              gpt-3.5-turbo
-              gpt-4
-              gpt-3.5-turbo-16k
-              gpt-4-32k
-              gpt-4-turbo
-              gpt-4-vision-preview
-            ].include?(model_name)
+            model_name.starts_with?("gpt-")
           end

           def tokenizer
@@ -23,72 +16,17 @@ module DiscourseAi

         VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+        def native_tool_support?
+          true
+        end
+
         def translate
-          messages = prompt.messages
-
-          # ChatGPT doesn't use an assistant msg to improve long-context responses.
-          if messages.last[:type] == :model
-            messages = messages.dup
-            messages.pop
-          end
-
-          trimmed_messages = trim_messages(messages)
-
-          embed_user_ids =
-            trimmed_messages.any? do |m|
+          @embed_user_ids =
+            prompt.messages.any? do |m|
               m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
             end

-          trimmed_messages.map do |msg|
-            if msg[:type] == :system
-              { role: "system", content: msg[:content] }
-            elsif msg[:type] == :model
-              { role: "assistant", content: msg[:content] }
-            elsif msg[:type] == :tool_call
-              call_details = JSON.parse(msg[:content], symbolize_names: true)
-              call_details[:arguments] = call_details[:arguments].to_json
-              call_details[:name] = msg[:name]
-
-              {
-                role: "assistant",
-                content: nil,
-                tool_calls: [{ type: "function", function: call_details, id: msg[:id] }],
-              }
-            elsif msg[:type] == :tool
-              { role: "tool", tool_call_id: msg[:id], content: msg[:content], name: msg[:name] }
-            else
-              user_message = { role: "user", content: msg[:content] }
-              if msg[:id]
-                if embed_user_ids
-                  user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
-                else
-                  user_message[:name] = msg[:id]
-                end
-              end
-              user_message[:content] = inline_images(user_message[:content], msg)
-              user_message
-            end
-          end
-        end
-
-        def tools
-          prompt.tools.map do |t|
-            tool = t.dup
-
-            tool[:parameters] = t[:parameters]
-              .to_a
-              .reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
-                name = p[:name]
-                memo[:required] << name if p[:required]
-
-                memo[:properties][name] = p.except(:name, :required, :item_type)
-
-                memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
-                memo
-              end
-
-            { type: "function", function: tool }
-          end
+          super
         end

         def max_prompt_tokens
@@ -107,6 +45,41 @@ module DiscourseAi

         private

+        def tools_dialect
+          @tools_dialect ||= DiscourseAi::Completions::Dialects::OpenAiTools.new(prompt.tools)
+        end
+
+        def system_msg(msg)
+          { role: "system", content: msg[:content] }
+        end
+
+        def model_msg(msg)
+          { role: "assistant", content: msg[:content] }
+        end
+
+        def tool_call_msg(msg)
+          tools_dialect.from_raw_tool_call(msg)
+        end
+
+        def tool_msg(msg)
+          tools_dialect.from_raw_tool(msg)
+        end
+
+        def user_msg(msg)
+          user_message = { role: "user", content: msg[:content] }
+
+          if msg[:id]
+            if @embed_user_ids
+              user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
+            else
+              user_message[:name] = msg[:id]
+            end
+          end
+
+          user_message[:content] = inline_images(user_message[:content], msg)
+          user_message
+        end
+
         def inline_images(content, message)
           if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo"
             content = message[:content]

diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index e3c93a59..9a15b293 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -27,41 +27,13 @@ module DiscourseAi
         end

         def translate
-          messages = prompt.messages
-          system_prompt = +""
+          messages = super

-          messages =
-            trim_messages(messages)
-              .map do |msg|
-                case msg[:type]
-                when :system
-                  system_prompt << msg[:content]
-                  nil
-                when :tool_call
-                  { role: "assistant", content: tool_call_to_xml(msg) }
-                when :tool
-                  { role: "user", content: tool_result_to_xml(msg) }
-                when :model
-                  { role: "assistant", content: msg[:content] }
-                when :user
-                  content = +""
-                  content << "#{msg[:id]}: " if msg[:id]
-                  content << msg[:content]
-                  content = inline_images(content, msg)
-
-                  { role: "user", content: content }
-                end
-              end
-              .compact
-
-          if prompt.tools.present?
-            system_prompt << "\n\n"
-            system_prompt << build_tools_prompt
-          end
+          system_prompt = messages.shift[:content] if messages.first[:role] == "system"

           interleving_messages = []
           previous_message = nil

           messages.each do |message|
             if previous_message
               if previous_message[:role] == "user" && message[:role] == "user"
@@ -84,6 +56,29 @@ module DiscourseAi

         private

+        def model_msg(msg)
+          { role: "assistant", content: msg[:content] }
+        end
+
+        def system_msg(msg)
+          msg = { role: "system", content: msg[:content] }
+
+          if tools_dialect.instructions.present?
+            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+          end
+
+          msg
+        end
+
+        def user_msg(msg)
+          content = +""
+          content << "#{msg[:id]}: " if msg[:id]
+          content << msg[:content]
+          content = inline_images(content, msg)
+
+          { role: "user", content: content }
+        end
+
         def inline_images(content, message)
           if model_name.include?("claude-3")
             encoded_uploads = prompt.encoded_uploads(message)

diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb
index 8b4bf67d..f119aba8 100644
--- a/lib/completions/dialects/command.rb
+++ b/lib/completions/dialects/command.rb
@@ -19,57 +19,17 @@ module DiscourseAi
         VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

         def translate
-          messages = prompt.messages
+          messages = super

-          # ChatGPT doesn't use an assistant msg to improve long-context responses.
-          if messages.last[:type] == :model
-            messages = messages.dup
-            messages.pop
-          end
+          system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"

-          trimmed_messages = trim_messages(messages)
+          prompt = { preamble: +"#{system_message}" }
+          prompt[:chat_history] = messages if messages.present?

-          chat_history = []
-          system_message = nil
-
-          prompt = {}
-
-          trimmed_messages.each do |msg|
-            case msg[:type]
-            when :system
-              if system_message
-                chat_history << { role: "SYSTEM", message: msg[:content] }
-              else
-                system_message = msg[:content]
-              end
-            when :model
-              chat_history << { role: "CHATBOT", message: msg[:content] }
-            when :tool_call
-              chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
-            when :tool
-              chat_history << { role: "USER", message: tool_result_to_xml(msg) }
-            when :user
-              user_message = { role: "USER", message: msg[:content] }
-              user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
-              chat_history << user_message
-            end
-          end
-
-          tools_prompt = build_tools_prompt
-          prompt[:preamble] = +"#{system_message}"
-          if tools_prompt.present?
-            prompt[:preamble] << "\n#{tools_prompt}"
-            prompt[
-              :preamble
-            ] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
-          end
-
-          prompt[:chat_history] = chat_history if chat_history.present?
-
-          chat_history.reverse_each do |msg|
+          messages.reverse_each do |msg|
             if msg[:role] == "USER"
               prompt[:message] = msg[:message]
-              chat_history.delete(msg)
+              messages.delete(msg)
               break
             end
           end
@@ -101,6 +61,43 @@ module DiscourseAi
         def calculate_message_token(context)
           self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
         end
+
+        def tools_dialect
+          @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
+        end
+
+        def system_msg(msg)
+          cmd_msg = { role: "SYSTEM", message: msg[:content] }
+
+          if tools_dialect.instructions.present?
+            cmd_msg[:message] = [
+              msg[:content],
+              tools_dialect.instructions,
+              "NEVER attempt to run tools using JSON, always use XML. Lives depend on it.",
+            ].join("\n")
+          end
+
+          cmd_msg
+        end
+
+        def model_msg(msg)
+          { role: "CHATBOT", message: msg[:content] }
+        end
+
+        def tool_call_msg(msg)
+          { role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
+        end
+
+        def tool_msg(msg)
+          { role: "USER", message: tools_dialect.from_raw_tool(msg) }
+        end
+
+        def user_msg(msg)
+          user_message = { role: "USER", message: msg[:content] }
+          user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
+
+          user_message
+        end
       end
     end
   end
 end

diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 465884e2..865b5509 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -11,11 +11,9 @@ module DiscourseAi

         def dialect_for(model_name)
           dialects = [
-            DiscourseAi::Completions::Dialects::Llama2Classic,
             DiscourseAi::Completions::Dialects::ChatGpt,
-            DiscourseAi::Completions::Dialects::OrcaStyle,
             DiscourseAi::Completions::Dialects::Gemini,
-            DiscourseAi::Completions::Dialects::Mixtral,
+            DiscourseAi::Completions::Dialects::Mistral,
             DiscourseAi::Completions::Dialects::Claude,
             DiscourseAi::Completions::Dialects::Command,
           ]
@@ -32,40 +30,6 @@ module DiscourseAi
         def tokenizer
           raise NotImplemented
         end
-
-        def tool_preamble(include_array_tip: true)
-          array_tip =
-            if include_array_tip
-              <<~TEXT
-                If a parameter type is an array, return an array of values. For example:
-                <$PARAMETER_NAME>["one","two","three"]
-              TEXT
-            else
-              ""
-            end
-
-          <<~TEXT
-            In this environment you have access to a set of tools you can use to answer the user's question.
-            You may call them like this.
-
-
-
-            $TOOL_NAME
-
-            <$PARAMETER_NAME>$PARAMETER_VALUE
-            ...
-
-
-
-            #{array_tip}
-            If you wish to call multiple function in one reply, wrap multiple
-            block in a single block.
-
-            Always prefer to lead with tool calls, if you need to execute any.
-            Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
-            Here are the complete list of tools available:
-          TEXT
-        end
       end

       def initialize(generic_prompt, model_name, opts: {})
@@ -74,74 +38,30 @@ module DiscourseAi
        @opts = opts
      end

-      def translate
-        raise NotImplemented
+      VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
+
+      def can_end_with_assistant_msg?
+        false
       end

-      def tool_result_to_xml(message)
-        (<<~TEXT).strip
-
-
-          #{message[:name] || message[:id]}
-
-          #{message[:content]}
-
-
-
-        TEXT
-      end
-
-      def tool_call_to_xml(message)
-        parsed = JSON.parse(message[:content], symbolize_names: true)
-        parameters = +""
-
-        if parsed[:arguments]
-          parameters << "\n"
-          parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}\n" }
-          parameters << "\n"
-        end
-
-        (<<~TEXT).strip
-
-
-          #{message[:name] || parsed[:name]}
-          #{parameters}
-
-        TEXT
+      def native_tool_support?
+        false
       end

       def tools
-        tools = +""
+        @tools ||= tools_dialect.translated_tools
+      end

-        prompt.tools.each do |function|
-          parameters = +""
-          if function[:parameters].present?
-            function[:parameters].each do |parameter|
-              parameters << <<~PARAMETER
-
-                #{parameter[:name]}
-                #{parameter[:type]}
-                #{parameter[:description]}
-                #{parameter[:required]}
-              PARAMETER
-              if parameter[:enum]
-                parameters << "#{parameter[:enum].join(",")}\n"
-              end
-              parameters << "\n"
-            end
-          end
+      def translate
+        messages = prompt.messages

-          tools << <<~TOOLS
-
-            #{function[:name]}
-            #{function[:description]}
-
-            #{parameters}
-
-          TOOLS
+        # Some models use an assistant msg to improve long-context responses.
+        if messages.last[:type] == :model && !can_end_with_assistant_msg?
+          messages = messages.dup
+          messages.pop
         end

-        tools
+        trim_messages(messages).map { |msg| send("#{msg[:type]}_msg", msg) }.compact
       end

       def conversation_context
         raise NotImplemented
       end

       attr_reader :prompt

-      def build_tools_prompt
-        return "" if prompt.tools.blank?
-
-        has_arrays =
-          prompt.tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
-
-        (<<~TEXT).strip
-          #{self.class.tool_preamble(include_array_tip: has_arrays)}
-
-          #{tools}
-        TEXT
-      end
-
       private

       attr_reader :model_name, :opts
@@ -230,6 +137,30 @@ module DiscourseAi
       def calculate_message_token(msg)
         self.class.tokenizer.size(msg[:content].to_s)
       end
+
+      def tools_dialect
+        @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
+      end
+
+      def system_msg(msg)
+        raise NotImplemented
+      end
+
+      def model_msg(msg)
+        raise NotImplemented
+      end
+
+      def user_msg(msg)
+        raise NotImplemented
+      end
+
+      def tool_call_msg(msg)
+        { role: "assistant", content: tools_dialect.from_raw_tool_call(msg) }
+      end
+
+      def tool_msg(msg)
+        { role: "user", content: tools_dialect.from_raw_tool(msg) }
+      end
     end
   end
 end

diff --git a/lib/completions/dialects/fake.rb b/lib/completions/dialects/fake.rb
index c569ee28..898f3364 100644
--- a/lib/completions/dialects/fake.rb
+++ b/lib/completions/dialects/fake.rb
@@ -9,14 +9,14 @@ module DiscourseAi
           model_name == "fake"
         end

-        def translate
-          ""
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
       end
+
+      def translate
+        ""
+      end
     end
   end
 end

diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index a425d4f5..678dc0cd 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -14,59 +14,30 @@ module DiscourseAi
           end
         end

+        def native_tool_support?
+          true
+        end
+
         def translate
           # Gemini complains if we don't alternate model/user roles.
           noop_model_response = { role: "model", parts: { text: "Ok." } }
+          messages = super

-          messages = prompt.messages
+          interleving_messages = []
+          previous_message = nil

-          # Gemini doesn't use an assistant msg to improve long-context responses.
-          messages.pop if messages.last[:type] == :model
-
-          memo = []
-
-          trim_messages(messages).each do |msg|
-            if msg[:type] == :system
-              memo << { role: "user", parts: { text: msg[:content] } }
-              memo << noop_model_response.dup
-            elsif msg[:type] == :model
-              memo << { role: "model", parts: { text: msg[:content] } }
-            elsif msg[:type] == :tool_call
-              call_details = JSON.parse(msg[:content], symbolize_names: true)
-
-              memo << {
-                role: "model",
-                parts: {
-                  functionCall: {
-                    name: msg[:name] || call_details[:name],
-                    args: call_details[:arguments],
-                  },
-                },
-              }
-            elsif msg[:type] == :tool
-              memo << {
-                role: "function",
-                parts: {
-                  functionResponse: {
-                    name: msg[:name] || msg[:id],
-                    response: {
-                      content: msg[:content],
-                    },
-                  },
-                },
-              }
-            else
-              # Gemini quirk. Doesn't accept tool -> user or user -> user msgs.
- previous_msg_role = memo.last&.dig(:role) - if previous_msg_role == "user" || previous_msg_role == "function" - memo << noop_model_response.dup + messages.each do |message| + if previous_message + if (previous_message[:role] == "user" || previous_message[:role] == "function") && + message[:role] == "user" + interleving_messages << noop_model_response.dup end - - memo << { role: "user", parts: { text: msg[:content] } } end + interleving_messages << message + previous_message = message end - memo + interleving_messages end def tools @@ -110,6 +81,46 @@ module DiscourseAi def calculate_message_token(context) self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) end + + def system_msg(msg) + { role: "user", parts: { text: msg[:content] } } + end + + def model_msg(msg) + { role: "model", parts: { text: msg[:content] } } + end + + def user_msg(msg) + { role: "user", parts: { text: msg[:content] } } + end + + def tool_call_msg(msg) + call_details = JSON.parse(msg[:content], symbolize_names: true) + + { + role: "model", + parts: { + functionCall: { + name: msg[:name] || call_details[:name], + args: call_details[:arguments], + }, + }, + } + end + + def tool_msg(msg) + { + role: "function", + parts: { + functionResponse: { + name: msg[:name] || msg[:id], + response: { + content: msg[:content], + }, + }, + }, + } + end end end end diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb deleted file mode 100644 index 3b5675f1..00000000 --- a/lib/completions/dialects/llama2_classic.rb +++ /dev/null @@ -1,68 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Completions - module Dialects - class Llama2Classic < Dialect - class << self - def can_translate?(model_name) - %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name) - end - - def tokenizer - DiscourseAi::Tokenizer::Llama2Tokenizer - end - end - - def translate - messages = prompt.messages - - llama2_prompt = - trim_messages(messages).reduce(+"") do |memo, msg| - next(memo) if msg[:type] == :tool_call - - if msg[:type] == :system - memo << (<<~TEXT).strip - [INST] - <> - #{msg[:content]} - #{build_tools_prompt} - <> - [/INST] - TEXT - elsif msg[:type] == :model - memo << "\n#{msg[:content]}" - elsif msg[:type] == :tool - JSON.parse(msg[:content], symbolize_names: true) - memo << "\n[INST]\n" - - memo << (<<~TEXT).strip - - - #{msg[:id]} - - #{msg[:content]} - - - - [/INST] - TEXT - else - memo << "\n[INST]#{msg[:content]}[/INST]" - end - - memo - end - - llama2_prompt << "\n" if llama2_prompt.ends_with?("[/INST]") - - llama2_prompt - end - - def max_prompt_tokens - SiteSetting.ai_hugging_face_token_limit - end - end - end - end -end diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb new file mode 100644 index 00000000..7752a876 --- /dev/null +++ b/lib/completions/dialects/mistral.rb @@ -0,0 +1,57 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class Mistral < Dialect + class << self + def can_translate?(model_name) + %w[ + mistralai/Mixtral-8x7B-Instruct-v0.1 + mistralai/Mistral-7B-Instruct-v0.2 + mistral + ].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::MixtralTokenizer + end + end + + def tools + @tools ||= tools_dialect.translated_tools + end + + def max_prompt_tokens + 32_000 + end + + private + + def system_msg(msg) + { role: "assistant", content: "#{msg[:content]}" } + end + + def model_msg(msg) + { role: "assistant", content: msg[:content] } + end 
+ + def tool_call_msg(msg) + tools_dialect.from_raw_tool_call(msg) + end + + def tool_msg(msg) + tools_dialect.from_raw_tool(msg) + end + + def user_msg(msg) + content = +"" + content << "#{msg[:id]}: " if msg[:id] + content << msg[:content] + + { role: "user", content: content } + end + end + end + end +end diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb deleted file mode 100644 index 425d741e..00000000 --- a/lib/completions/dialects/mixtral.rb +++ /dev/null @@ -1,57 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Completions - module Dialects - class Mixtral < Dialect - class << self - def can_translate?(model_name) - %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( - model_name, - ) - end - - def tokenizer - DiscourseAi::Tokenizer::MixtralTokenizer - end - end - - def translate - messages = prompt.messages - - mixtral_prompt = - trim_messages(messages).reduce(+"") do |memo, msg| - if msg[:type] == :tool_call - memo << "\n" - memo << tool_call_to_xml(msg) - elsif msg[:type] == :system - memo << (<<~TEXT).strip - [INST] - #{msg[:content]} - #{build_tools_prompt} - [/INST] Ok - TEXT - elsif msg[:type] == :model - memo << "\n#{msg[:content]}" - elsif msg[:type] == :tool - memo << "\n" - memo << tool_result_to_xml(msg) - else - memo << "\n[INST]#{msg[:content]}[/INST]" - end - - memo - end - - mixtral_prompt << "\n" if mixtral_prompt.ends_with?("[/INST]") - - mixtral_prompt - end - - def max_prompt_tokens - 32_000 - end - end - end - end -end diff --git a/lib/completions/dialects/open_ai_tools.rb b/lib/completions/dialects/open_ai_tools.rb new file mode 100644 index 00000000..b990e379 --- /dev/null +++ b/lib/completions/dialects/open_ai_tools.rb @@ -0,0 +1,62 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class OpenAiTools + def initialize(tools) + @raw_tools = tools + end + + def translated_tools + raw_tools.map do |t| + tool = t.dup + + tool[:parameters] = t[:parameters] + .to_a + .reduce({ type: "object", properties: {}, required: [] }) do |memo, p| + name = p[:name] + memo[:required] << name if p[:required] + + memo[:properties][name] = p.except(:name, :required, :item_type) + + memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type] + memo + end + + { type: "function", function: tool } + end + end + + def instructions + "" # Noop. Tools are listed separate. 
+ end + + def from_raw_tool_call(raw_message) + call_details = JSON.parse(raw_message[:content], symbolize_names: true) + call_details[:arguments] = call_details[:arguments].to_json + call_details[:name] = raw_message[:name] + + { + role: "assistant", + content: nil, + tool_calls: [{ type: "function", function: call_details, id: raw_message[:id] }], + } + end + + def from_raw_tool(raw_message) + { + role: "tool", + tool_call_id: raw_message[:id], + content: raw_message[:content], + name: raw_message[:name], + } + end + + private + + attr_reader :raw_tools + end + end + end +end diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb deleted file mode 100644 index 4ec42a36..00000000 --- a/lib/completions/dialects/orca_style.rb +++ /dev/null @@ -1,59 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Completions - module Dialects - class OrcaStyle < Dialect - class << self - def can_translate?(model_name) - %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name) - end - - def tokenizer - DiscourseAi::Tokenizer::Llama2Tokenizer - end - end - - def translate - messages = prompt.messages - trimmed_messages = trim_messages(messages) - - # Need to include this differently - last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil - - llama2_prompt = - trimmed_messages.reduce(+"") do |memo, msg| - if msg[:type] == :tool_call - memo << "\n### Assistant:\n" - memo << tool_call_to_xml(msg) - elsif msg[:type] == :system - memo << (<<~TEXT).strip - ### System: - #{msg[:content]} - #{build_tools_prompt} - TEXT - elsif msg[:type] == :model - memo << "\n### Assistant:\n#{msg[:content]}" - elsif msg[:type] == :tool - memo << "\n### User:\n" - memo << tool_result_to_xml(msg) - else - memo << "\n### User:\n#{msg[:content]}" - end - - memo - end - - llama2_prompt << "\n### Assistant:\n" - llama2_prompt << "#{last_message[:content]}:" if last_message - - llama2_prompt - end - - def max_prompt_tokens - SiteSetting.ai_hugging_face_token_limit - end - end - end - end -end diff --git a/lib/completions/dialects/xml_tools.rb b/lib/completions/dialects/xml_tools.rb new file mode 100644 index 00000000..47988a71 --- /dev/null +++ b/lib/completions/dialects/xml_tools.rb @@ -0,0 +1,125 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class XmlTools + def initialize(tools) + @raw_tools = tools + end + + def translated_tools + raw_tools.reduce(+"") do |tools, function| + parameters = +"" + if function[:parameters].present? + function[:parameters].each do |parameter| + parameters << <<~PARAMETER + + #{parameter[:name]} + #{parameter[:type]} + #{parameter[:description]} + #{parameter[:required]} + PARAMETER + if parameter[:enum] + parameters << "#{parameter[:enum].join(",")}\n" + end + parameters << "\n" + end + end + + tools << <<~TOOLS + + #{function[:name]} + #{function[:description]} + + #{parameters} + + TOOLS + end + end + + def instructions + return "" if raw_tools.blank? + + has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? 
{ |p| p[:type] == "array" } } + + (<<~TEXT).strip + #{tool_preamble(include_array_tip: has_arrays)} + + #{translated_tools} + TEXT + end + + def from_raw_tool(raw_message) + (<<~TEXT).strip + + + #{raw_message[:name] || raw_message[:id]} + + #{raw_message[:content]} + + + + TEXT + end + + def from_raw_tool_call(raw_message) + parsed = JSON.parse(raw_message[:content], symbolize_names: true) + parameters = +"" + + if parsed[:arguments] + parameters << "\n" + parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}\n" } + parameters << "\n" + end + + (<<~TEXT).strip + + + #{raw_message[:name] || parsed[:name]} + #{parameters} + + TEXT + end + + private + + attr_reader :raw_tools + + def tool_preamble(include_array_tip: true) + array_tip = + if include_array_tip + <<~TEXT + If a parameter type is an array, return an array of values. For example: + <$PARAMETER_NAME>["one","two","three"] + TEXT + else + "" + end + + <<~TEXT + In this environment you have access to a set of tools you can use to answer the user's question. + You may call them like this. + + + + $TOOL_NAME + + <$PARAMETER_NAME>$PARAMETER_VALUE + ... + + + + #{array_tip} + If you wish to call multiple function in one reply, wrap multiple + block in a single block. + + Always prefer to lead with tool calls, if you need to execute any. + Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc. + Here are the complete list of tools available: + TEXT + end + end + end + end +end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index 8c27a269..ee9d4f17 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -62,7 +62,7 @@ module DiscourseAi # this is an approximation, we will update it later if request goes through def prompt_size(prompt) - super(prompt.system_prompt.to_s + " " + prompt.messages.to_s) + tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s) end def model_uri diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 5f17cad1..b7f3e8bf 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -51,7 +51,7 @@ module DiscourseAi def prompt_size(prompt) # approximation - super(prompt.system_prompt.to_s + " " + prompt.messages.to_s) + tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s) end def model_uri diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 0766a9a4..7a914a1e 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -19,6 +19,8 @@ module DiscourseAi DiscourseAi::Completions::Endpoints::Cohere, ] + endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development? + if Rails.env.test? || Rails.env.development? endpoints << DiscourseAi::Completions::Endpoints::Fake end @@ -67,6 +69,10 @@ module DiscourseAi false end + def use_ssl? + true + end + def perform_completion!(dialect, user, model_params = {}, &blk) allow_tools = dialect.prompt.has_tools? 
model_params = normalize_model_params(model_params) @@ -78,7 +84,7 @@ module DiscourseAi FinalDestination::HTTP.start( model_uri.host, model_uri.port, - use_ssl: true, + use_ssl: use_ssl?, read_timeout: TIMEOUT, open_timeout: TIMEOUT, write_timeout: TIMEOUT, @@ -315,7 +321,7 @@ module DiscourseAi end def extract_prompt_for_tokenizer(prompt) - prompt + prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") end def build_buffer diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb index 5542c73f..d6237c05 100644 --- a/lib/completions/endpoints/hugging_face.rb +++ b/lib/completions/endpoints/hugging_face.rb @@ -8,14 +8,9 @@ module DiscourseAi def can_contact?(endpoint_name, model_name) return false unless endpoint_name == "hugging_face" - %w[ - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf - mistralai/Mixtral-8x7B-Instruct-v0.1 - mistralai/Mistral-7B-Instruct-v0.2 - ].include?(model_name) + %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( + model_name, + ) end def dependant_setting_names @@ -31,24 +26,21 @@ module DiscourseAi end end - def default_options - { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } } - end - def normalize_model_params(model_params) model_params = model_params.dup + # max_tokens, temperature are already supported if model_params[:stop_sequences] model_params[:stop] = model_params.delete(:stop_sequences) end - if model_params[:max_tokens] - model_params[:max_new_tokens] = model_params.delete(:max_tokens) - end - model_params end + def default_options + { model: model, temperature: 0.7 } + end + def provider_id AiApiAuditLog::Provider::HuggingFaceTextGeneration end @@ -61,13 +53,14 @@ module DiscourseAi def prepare_payload(prompt, model_params, _dialect) default_options - .merge(inputs: prompt) + .merge(model_params) + .merge(messages: prompt) .tap do |payload| - payload[:parameters].merge!(model_params) + if !payload[:max_tokens] + token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 - token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 - - payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt) + payload[:max_tokens] = token_limit - prompt_size(prompt) + end payload[:stream] = true if @streaming_mode end @@ -85,16 +78,13 @@ module DiscourseAi end def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) + # half a line sent here + return if !parsed - if @streaming_mode - # Last chunk contains full response, which we already yielded. - return if parsed.dig(:token, :special) + response_h = @streaming_mode ? 
parsed.dig(:delta) : parsed.dig(:message) - parsed.dig(:token, :text).to_s - else - parsed[0][:generated_text].to_s - end + response_h.dig(:content) end def partials_from(decoded_chunk) diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb new file mode 100644 index 00000000..0fd748d4 --- /dev/null +++ b/lib/completions/endpoints/ollama.rb @@ -0,0 +1,89 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Endpoints + class Ollama < Base + class << self + def can_contact?(endpoint_name, model_name) + endpoint_name == "ollama" && %w[mistral].include?(model_name) + end + + def dependant_setting_names + %w[ai_ollama_endpoint] + end + + def correctly_configured?(_model_name) + SiteSetting.ai_ollama_endpoint.present? + end + + def endpoint_name(model_name) + "Ollama - #{model_name}" + end + end + + def normalize_model_params(model_params) + model_params = model_params.dup + + # max_tokens, temperature are already supported + if model_params[:stop_sequences] + model_params[:stop] = model_params.delete(:stop_sequences) + end + + model_params + end + + def default_options + { max_tokens: 2000, model: model } + end + + def provider_id + AiApiAuditLog::Provider::Ollama + end + + def use_ssl? + false + end + + private + + def model_uri + URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions") + end + + def prepare_payload(prompt, model_params, _dialect) + default_options + .merge(model_params) + .merge(messages: prompt) + .tap { |payload| payload[:stream] = true if @streaming_mode } + end + + def prepare_request(payload) + headers = { "Content-Type" => "application/json" } + + Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } + end + + def partials_from(decoded_chunk) + decoded_chunk + .split("\n") + .map do |line| + data = line.split("data: ", 2)[1] + data == "[DONE]" ? nil : data + end + .compact + end + + def extract_completion_from(response_raw) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) + # half a line sent here + return if !parsed + + response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + + response_h.dig(:content) + end + end + end + end +end diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 16e0b886..2ccd817e 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -153,10 +153,6 @@ module DiscourseAi .compact end - def extract_prompt_for_tokenizer(prompt) - prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") - end - def has_tool?(_response_data) @has_function_call end diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 310bcdc9..7db1452d 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -7,14 +7,9 @@ module DiscourseAi class << self def can_contact?(endpoint_name, model_name) endpoint_name == "vllm" && - %w[ - mistralai/Mixtral-8x7B-Instruct-v0.1 - mistralai/Mistral-7B-Instruct-v0.2 - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf - ].include?(model_name) + %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( + model_name, + ) end def dependant_setting_names @@ -54,9 +49,9 @@ module DiscourseAi def model_uri service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv) if service.present? 
- api_endpoint = "https://#{service.target}:#{service.port}/v1/completions" + api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions" else - api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/completions" + api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions" end @uri ||= URI(api_endpoint) end @@ -64,7 +59,7 @@ module DiscourseAi def prepare_payload(prompt, model_params, _dialect) default_options .merge(model_params) - .merge(prompt: prompt) + .merge(messages: prompt) .tap { |payload| payload[:stream] = true if @streaming_mode } end @@ -76,15 +71,6 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - - # half a line sent here - return if !parsed - - parsed.dig(:text) - end - def partials_from(decoded_chunk) decoded_chunk .split("\n") @@ -94,6 +80,16 @@ module DiscourseAi end .compact end + + def extract_completion_from(response_raw) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) + # half a line sent here + return if !parsed + + response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + + response_h.dig(:content) + end end end end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index c54c0dd4..47190644 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -31,21 +31,10 @@ module DiscourseAi claude-3-opus ], anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus], - vllm: %w[ - mistralai/Mixtral-8x7B-Instruct-v0.1 - mistralai/Mistral-7B-Instruct-v0.2 - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf - ], + vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2], hugging_face: %w[ mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2 - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf ], cohere: %w[command-light command command-r command-r-plus], open_ai: %w[ @@ -57,7 +46,10 @@ module DiscourseAi gpt-4-vision-preview ], google: %w[gemini-pro gemini-1.5-pro], - }.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? } + }.tap do |h| + h[:ollama] = ["mistral"] if Rails.env.development? + h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? + end end def valid_provider_models @@ -120,8 +112,6 @@ module DiscourseAi @gateway = gateway end - delegate :tokenizer, to: :dialect_klass - # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object # @param user { User } - User requesting the summary. 
# @@ -184,6 +174,8 @@ module DiscourseAi dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens end + delegate :tokenizer, to: :dialect_klass + attr_reader :model_name private diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb index 9e60eb67..37c72725 100644 --- a/lib/summarization/entry_point.rb +++ b/lib/summarization/entry_point.rb @@ -10,14 +10,6 @@ module DiscourseAi Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000), Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096), Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384), - Models::Llama2.new( - "hugging_face:Llama2-chat-hf", - max_tokens: SiteSetting.ai_hugging_face_token_limit, - ), - Models::Llama2FineTunedOrcaStyle.new( - "hugging_face:StableBeluga2", - max_tokens: SiteSetting.ai_hugging_face_token_limit, - ), Models::Gemini.new("google:gemini-pro", max_tokens: 32_768), Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000), ] diff --git a/spec/lib/completions/dialects/dialect_spec.rb b/spec/lib/completions/dialects/dialect_spec.rb index 2410f415..c54d1838 100644 --- a/spec/lib/completions/dialects/dialect_spec.rb +++ b/spec/lib/completions/dialects/dialect_spec.rb @@ -17,47 +17,6 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect end RSpec.describe DiscourseAi::Completions::Dialects::Dialect do - describe "#build_tools_prompt" do - it "can exclude array instructions" do - prompt = DiscourseAi::Completions::Prompt.new("12345") - prompt.tools = [ - { - name: "weather", - description: "lookup weather in a city", - parameters: [{ name: "city", type: "string", description: "city name", required: true }], - }, - ] - - dialect = TestDialect.new(prompt, "test") - - expect(dialect.build_tools_prompt).not_to include("array") - end - - it "can include array instructions" do - prompt = DiscourseAi::Completions::Prompt.new("12345") - prompt.tools = [ - { - name: "weather", - description: "lookup weather in a city", - parameters: [{ name: "city", type: "array", description: "city names", required: true }], - }, - ] - - dialect = TestDialect.new(prompt, "test") - - expect(dialect.build_tools_prompt).to include("array") - end - - it "does not break if there are no params" do - prompt = DiscourseAi::Completions::Prompt.new("12345") - prompt.tools = [{ name: "categories", description: "lookup all categories" }] - - dialect = TestDialect.new(prompt, "test") - - expect(dialect.build_tools_prompt).not_to include("array") - end - end - describe "#trim_messages" do it "should trim tool messages if tool_calls are trimmed" do prompt = DiscourseAi::Completions::Prompt.new("12345") diff --git a/spec/lib/completions/dialects/llama2_classic_spec.rb b/spec/lib/completions/dialects/llama2_classic_spec.rb deleted file mode 100644 index 0242ebf6..00000000 --- a/spec/lib/completions/dialects/llama2_classic_spec.rb +++ /dev/null @@ -1,62 +0,0 @@ -# frozen_string_literal: true - -require_relative "dialect_context" - -RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do - let(:model_name) { "Llama2-chat-hf" } - let(:context) { DialectContext.new(described_class, model_name) } - - describe "#translate" do - it "translates a prompt written in our generic format to the Llama2 format" do - llama2_classic_version = <<~TEXT - [INST] - <> - #{context.system_insts} - #{described_class.tool_preamble(include_array_tip: false)} - - #{context.dialect_tools} - <> - [/INST] - [INST]#{context.simple_user_input}[/INST] - TEXT - - translated 
= context.system_user_scenario - - expect(translated).to eq(llama2_classic_version) - end - - it "translates tool messages" do - expected = +(<<~TEXT) - [INST] - <> - #{context.system_insts} - #{described_class.tool_preamble(include_array_tip: false)} - - #{context.dialect_tools} - <> - [/INST] - [INST]This is a message by a user[/INST] - I'm a previous bot reply, that's why there's no user - [INST]This is a new message by a user[/INST] - [INST] - - - tool_id - - "I'm a tool result" - - - - [/INST] - TEXT - - expect(context.multi_turn_scenario).to eq(expected) - end - - it "trims content if it's getting too long" do - translated = context.long_user_input_scenario - - expect(translated.length).to be < context.long_message_text.length - end - end -end diff --git a/spec/lib/completions/dialects/mixtral_spec.rb b/spec/lib/completions/dialects/mixtral_spec.rb deleted file mode 100644 index 499dad73..00000000 --- a/spec/lib/completions/dialects/mixtral_spec.rb +++ /dev/null @@ -1,66 +0,0 @@ -# frozen_string_literal: true - -require_relative "dialect_context" - -RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do - let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" } - let(:context) { DialectContext.new(described_class, model_name) } - - describe "#translate" do - it "translates a prompt written in our generic format to the Llama2 format" do - llama2_classic_version = <<~TEXT - [INST] - #{context.system_insts} - #{described_class.tool_preamble(include_array_tip: false)} - - #{context.dialect_tools} - [/INST] Ok - [INST]#{context.simple_user_input}[/INST] - TEXT - - translated = context.system_user_scenario - - expect(translated).to eq(llama2_classic_version) - end - - it "translates tool messages" do - expected = +(<<~TEXT).strip - [INST] - #{context.system_insts} - #{described_class.tool_preamble(include_array_tip: false)} - - #{context.dialect_tools} - [/INST] Ok - [INST]This is a message by a user[/INST] - I'm a previous bot reply, that's why there's no user - [INST]This is a new message by a user[/INST] - - - get_weather - - Sydney - c - - - - - - get_weather - - "I'm a tool result" - - - - TEXT - - expect(context.multi_turn_scenario).to eq(expected) - end - - it "trims content if it's getting too long" do - length = 6_000 - translated = context.long_user_input_scenario(length: length) - - expect(translated.length).to be < context.long_message_text(length: length).length - end - end -end diff --git a/spec/lib/completions/dialects/orca_style_spec.rb b/spec/lib/completions/dialects/orca_style_spec.rb deleted file mode 100644 index 63b414b8..00000000 --- a/spec/lib/completions/dialects/orca_style_spec.rb +++ /dev/null @@ -1,71 +0,0 @@ -# frozen_string_literal: true - -require_relative "dialect_context" - -RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do - let(:model_name) { "StableBeluga2" } - let(:context) { DialectContext.new(described_class, model_name) } - - describe "#translate" do - it "translates a prompt written in our generic format to the Llama2 format" do - llama2_classic_version = <<~TEXT - ### System: - #{context.system_insts} - #{described_class.tool_preamble(include_array_tip: false)} - - #{context.dialect_tools} - ### User: - #{context.simple_user_input} - ### Assistant: - TEXT - - translated = context.system_user_scenario - - expect(translated).to eq(llama2_classic_version) - end - - it "translates tool messages" do - expected = +(<<~TEXT) - ### System: - #{context.system_insts} - #{described_class.tool_preamble(include_array_tip: false)} - - 
#{context.dialect_tools} - ### User: - This is a message by a user - ### Assistant: - I'm a previous bot reply, that's why there's no user - ### User: - This is a new message by a user - ### Assistant: - - - get_weather - - Sydney - c - - - - ### User: - - - get_weather - - "I'm a tool result" - - - - ### Assistant: - TEXT - - expect(context.multi_turn_scenario).to eq(expected) - end - - it "trims content if it's getting too long" do - translated = context.long_user_input_scenario - - expect(translated.length).to be < context.long_message_text.length - end - end -end diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb index bfdfa74e..5b9bd9f5 100644 --- a/spec/lib/completions/endpoints/hugging_face_spec.rb +++ b/spec/lib/completions/endpoints/hugging_face_spec.rb @@ -4,7 +4,20 @@ require_relative "endpoint_compliance" class HuggingFaceMock < EndpointMock def response(content) - [{ generated_text: content }] + { + id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", + object: "chat.completion", + created: 1_678_464_820, + model: "Llama2-*-chat-hf", + usage: { + prompt_tokens: 337, + completion_tokens: 162, + total_tokens: 499, + }, + choices: [ + { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 }, + ], + } end def stub_response(prompt, response_text, tool_call: false) @@ -14,26 +27,32 @@ class HuggingFaceMock < EndpointMock .to_return(status: 200, body: JSON.dump(response(response_text))) end - def stream_line(delta, deltas, finish_reason: nil) + def stream_line(delta, finish_reason: nil) +"data: " << { - token: { - id: 29_889, - text: delta, - logprob: -0.08319092, - special: !!finish_reason, - }, - generated_text: finish_reason ? deltas.join : nil, - details: nil, + id: "chatcmpl-#{SecureRandom.hex}", + object: "chat.completion.chunk", + created: 1_681_283_881, + model: "Llama2-*-chat-hf", + choices: [{ delta: { content: delta } }], + finish_reason: finish_reason, + index: 0, }.to_json end + def stub_raw(chunks) + WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return( + status: 200, + body: chunks, + ) + end + def stub_streamed_response(prompt, deltas, tool_call: false) chunks = deltas.each_with_index.map do |_, index| if index == (deltas.length - 1) - stream_line(deltas[index], deltas, finish_reason: true) + stream_line(deltas[index], finish_reason: "stop_sequence") else - stream_line(deltas[index], deltas) + stream_line(deltas[index]) end end @@ -43,16 +62,18 @@ class HuggingFaceMock < EndpointMock .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}") .with(body: request_body(prompt, stream: true)) .to_return(status: 200, body: chunks) + + yield if block_given? 
end - def request_body(prompt, stream: false) + def request_body(prompt, stream: false, tool_call: false) model .default_options - .merge(inputs: prompt) - .tap do |payload| - payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) - + .merge(messages: prompt) + .tap do |b| + b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) - model.prompt_size(prompt) - payload[:stream] = true if stream + b[:stream] = true if stream end .to_json end @@ -70,7 +91,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do let(:hf_mock) { HuggingFaceMock.new(endpoint) } let(:compliance) do - EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user) + EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user) end describe "#perform_completion!" do diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb index 52d87007..d879ad09 100644 --- a/spec/lib/completions/endpoints/vllm_spec.rb +++ b/spec/lib/completions/endpoints/vllm_spec.rb @@ -6,7 +6,7 @@ class VllmMock < EndpointMock def response(content) { id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", - object: "text_completion", + object: "chat.completion", created: 1_678_464_820, model: "mistralai/Mixtral-8x7B-Instruct-v0.1", usage: { @@ -14,14 +14,16 @@ class VllmMock < EndpointMock completion_tokens: 162, total_tokens: 499, }, - choices: [{ text: content, finish_reason: "stop", index: 0 }], + choices: [ + { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 }, + ], } end def stub_response(prompt, response_text, tool_call: false) WebMock - .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") - .with(body: model.default_options.merge(prompt: prompt).to_json) + .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions") + .with(body: model.default_options.merge(messages: prompt).to_json) .to_return(status: 200, body: JSON.dump(response(response_text))) end @@ -30,7 +32,7 @@ class VllmMock < EndpointMock id: "cmpl-#{SecureRandom.hex}", created: 1_681_283_881, model: "mistralai/Mixtral-8x7B-Instruct-v0.1", - choices: [{ text: delta, finish_reason: finish_reason, index: 0 }], + choices: [{ delta: { content: delta } }], index: 0, }.to_json end @@ -48,8 +50,8 @@ class VllmMock < EndpointMock chunks = (chunks.join("\n\n") << "data: [DONE]").split("") WebMock - .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") - .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json) + .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions") + .with(body: model.default_options.merge(messages: prompt, stream: true).to_json) .to_return(status: 200, body: chunks) end end @@ -67,14 +69,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do let(:anthropic_mock) { VllmMock.new(endpoint) } let(:compliance) do - EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user) + EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user) end - let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) } + let(:dialect) { DiscourseAi::Completions::Dialects::Mistral.new(generic_prompt, model_name) } let(:prompt) { dialect.translate } - let(:request_body) { model.default_options.merge(prompt: prompt).to_json } - let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: 
true).to_json } + let(:request_body) { model.default_options.merge(messages: prompt).to_json } + let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json } before { SiteSetting.ai_vllm_endpoint = "https://test.dev" } diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb index 3aeeb04a..e91a4a60 100644 --- a/spec/lib/completions/llm_spec.rb +++ b/spec/lib/completions/llm_spec.rb @@ -3,8 +3,8 @@ RSpec.describe DiscourseAi::Completions::Llm do subject(:llm) do described_class.new( - DiscourseAi::Completions::Dialects::OrcaStyle, - nil, + DiscourseAi::Completions::Dialects::Mistral, + canned_response, "hugging_face:Upstage-Llama-2-*-instruct-v2", gateway: canned_response, )