diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 7001bc08..f426925a 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -11,6 +11,7 @@ class AiApiAuditLog < ActiveRecord::Base
     Gemini = 4
     Vllm = 5
     Cohere = 6
+    Ollama = 7
   end
 end
diff --git a/config/settings.yml b/config/settings.yml
index 0dbcf624..cff41ce5 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -185,6 +185,9 @@ discourse_ai:
   ai_strict_token_counting:
     default: false
     hidden: true
+  ai_ollama_endpoint:
+    hidden: true
+    default: ""

   composer_ai_helper_enabled:
     default: false
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index d935b336..3e3c932e 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -164,12 +164,15 @@ module DiscourseAi
       when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
         "open_ai:gpt-3.5-turbo-16k"
       when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
-        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
-             "mistralai/Mixtral-8x7B-Instruct-v0.1",
-           )
-          "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
+        mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+        if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
+          "vllm:#{mixtral_model}"
+        elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
+              mixtral_model,
+            )
+          "hugging_face:#{mixtral_model}"
         else
-          "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
+          "ollama:mistral"
         end
       when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
         "google:gemini-pro"
diff --git a/lib/automation.rb b/lib/automation.rb
index d1604fa6..b755f1db 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -40,8 +40,10 @@ module DiscourseAi
       if model.start_with?("mistral")
         if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model)
           return "vllm:#{model}"
+        elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(model)
+          return "hugging_face:#{model}"
         else
-          return "hugging_face:#{model}"
+          return "ollama:mistral"
         end
       end

diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 7368deff..9383019c 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -6,14 +6,7 @@ module DiscourseAi
       class ChatGpt < Dialect
         class << self
           def can_translate?(model_name)
-            %w[
-              gpt-3.5-turbo
-              gpt-4
-              gpt-3.5-turbo-16k
-              gpt-4-32k
-              gpt-4-turbo
-              gpt-4-vision-preview
-            ].include?(model_name)
+            model_name.starts_with?("gpt-")
           end

           def tokenizer
@@ -23,72 +16,17 @@ module DiscourseAi

         VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+        def native_tool_support?
+          true
+        end
+
         def translate
-          messages = prompt.messages
-
-          # ChatGPT doesn't use an assistant msg to improve long-context responses.
-          if messages.last[:type] == :model
-            messages = messages.dup
-            messages.pop
-          end
-
-          trimmed_messages = trim_messages(messages)
-
-          embed_user_ids =
-            trimmed_messages.any? do |m|
+          @embed_user_ids =
+            prompt.messages.any? 
do |m| m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX) end - trimmed_messages.map do |msg| - if msg[:type] == :system - { role: "system", content: msg[:content] } - elsif msg[:type] == :model - { role: "assistant", content: msg[:content] } - elsif msg[:type] == :tool_call - call_details = JSON.parse(msg[:content], symbolize_names: true) - call_details[:arguments] = call_details[:arguments].to_json - call_details[:name] = msg[:name] - - { - role: "assistant", - content: nil, - tool_calls: [{ type: "function", function: call_details, id: msg[:id] }], - } - elsif msg[:type] == :tool - { role: "tool", tool_call_id: msg[:id], content: msg[:content], name: msg[:name] } - else - user_message = { role: "user", content: msg[:content] } - if msg[:id] - if embed_user_ids - user_message[:content] = "#{msg[:id]}: #{msg[:content]}" - else - user_message[:name] = msg[:id] - end - end - user_message[:content] = inline_images(user_message[:content], msg) - user_message - end - end - end - - def tools - prompt.tools.map do |t| - tool = t.dup - - tool[:parameters] = t[:parameters] - .to_a - .reduce({ type: "object", properties: {}, required: [] }) do |memo, p| - name = p[:name] - memo[:required] << name if p[:required] - - memo[:properties][name] = p.except(:name, :required, :item_type) - - memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type] - memo - end - - { type: "function", function: tool } - end + super end def max_prompt_tokens @@ -107,6 +45,41 @@ module DiscourseAi private + def tools_dialect + @tools_dialect ||= DiscourseAi::Completions::Dialects::OpenAiTools.new(prompt.tools) + end + + def system_msg(msg) + { role: "system", content: msg[:content] } + end + + def model_msg(msg) + { role: "assistant", content: msg[:content] } + end + + def tool_call_msg(msg) + tools_dialect.from_raw_tool_call(msg) + end + + def tool_msg(msg) + tools_dialect.from_raw_tool(msg) + end + + def user_msg(msg) + user_message = { role: "user", content: msg[:content] } + + if msg[:id] + if @embed_user_ids + user_message[:content] = "#{msg[:id]}: #{msg[:content]}" + else + user_message[:name] = msg[:id] + end + end + + user_message[:content] = inline_images(user_message[:content], msg) + user_message + end + def inline_images(content, message) if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo" content = message[:content] diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index e3c93a59..9a15b293 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -27,41 +27,13 @@ module DiscourseAi end def translate - messages = prompt.messages - system_prompt = +"" + messages = super - messages = - trim_messages(messages) - .map do |msg| - case msg[:type] - when :system - system_prompt << msg[:content] - nil - when :tool_call - { role: "assistant", content: tool_call_to_xml(msg) } - when :tool - { role: "user", content: tool_result_to_xml(msg) } - when :model - { role: "assistant", content: msg[:content] } - when :user - content = +"" - content << "#{msg[:id]}: " if msg[:id] - content << msg[:content] - content = inline_images(content, msg) - - { role: "user", content: content } - end - end - .compact - - if prompt.tools.present? 
-          system_prompt << "\n\n"
-          system_prompt << build_tools_prompt
-        end
+        system_prompt = messages.shift[:content] if messages.first[:role] == "system"

         interleving_messages = []
         previous_message = nil
+
         messages.each do |message|
           if previous_message
             if previous_message[:role] == "user" && message[:role] == "user"
@@ -84,6 +56,29 @@ module DiscourseAi

       private

+      def model_msg(msg)
+        { role: "assistant", content: msg[:content] }
+      end
+
+      def system_msg(msg)
+        msg = { role: "system", content: msg[:content] }
+
+        if tools_dialect.instructions.present?
+          msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+        end
+
+        msg
+      end
+
+      def user_msg(msg)
+        content = +""
+        content << "#{msg[:id]}: " if msg[:id]
+        content << msg[:content]
+        content = inline_images(content, msg)
+
+        { role: "user", content: content }
+      end
+
       def inline_images(content, message)
         if model_name.include?("claude-3")
           encoded_uploads = prompt.encoded_uploads(message)
diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb
index 8b4bf67d..f119aba8 100644
--- a/lib/completions/dialects/command.rb
+++ b/lib/completions/dialects/command.rb
@@ -19,57 +19,17 @@ module DiscourseAi
         VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

         def translate
-          messages = prompt.messages
+          messages = super

-          # ChatGPT doesn't use an assistant msg to improve long-context responses.
-          if messages.last[:type] == :model
-            messages = messages.dup
-            messages.pop
-          end
+          system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"

-          trimmed_messages = trim_messages(messages)
+          prompt = { preamble: +"#{system_message}" }
+          prompt[:chat_history] = messages if messages.present?

-          chat_history = []
-          system_message = nil
-
-          prompt = {}
-
-          trimmed_messages.each do |msg|
-            case msg[:type]
-            when :system
-              if system_message
-                chat_history << { role: "SYSTEM", message: msg[:content] }
-              else
-                system_message = msg[:content]
-              end
-            when :model
-              chat_history << { role: "CHATBOT", message: msg[:content] }
-            when :tool_call
-              chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
-            when :tool
-              chat_history << { role: "USER", message: tool_result_to_xml(msg) }
-            when :user
-              user_message = { role: "USER", message: msg[:content] }
-              user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
-              chat_history << user_message
-            end
-          end
-
-          tools_prompt = build_tools_prompt
-          prompt[:preamble] = +"#{system_message}"
-          if tools_prompt.present?
-            prompt[:preamble] << "\n#{tools_prompt}"
-            prompt[
-              :preamble
-            ] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
-          end
-
-          prompt[:chat_history] = chat_history if chat_history.present?
-
-          chat_history.reverse_each do |msg|
+          messages.reverse_each do |msg|
             if msg[:role] == "USER"
               prompt[:message] = msg[:message]
-              chat_history.delete(msg)
+              messages.delete(msg)
               break
             end
           end
@@ -101,6 +61,43 @@ module DiscourseAi
         def calculate_message_token(context)
           self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
         end
+
+        def tools_dialect
+          @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
+        end
+
+        def system_msg(msg)
+          cmd_msg = { role: "SYSTEM", message: msg[:content] }
+
+          if tools_dialect.instructions.present?
+            cmd_msg[:message] = [
+              msg[:content],
+              tools_dialect.instructions,
+              "NEVER attempt to run tools using JSON, always use XML. 
Lives depend on it.",
+            ].join("\n")
+          end
+
+          cmd_msg
+        end
+
+        def model_msg(msg)
+          { role: "CHATBOT", message: msg[:content] }
+        end
+
+        def tool_call_msg(msg)
+          { role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
+        end
+
+        def tool_msg(msg)
+          { role: "USER", message: tools_dialect.from_raw_tool(msg) }
+        end
+
+        def user_msg(msg)
+          user_message = { role: "USER", message: msg[:content] }
+          user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
+
+          user_message
+        end
       end
     end
   end
 end
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 465884e2..865b5509 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -11,11 +11,9 @@ module DiscourseAi

         def dialect_for(model_name)
           dialects = [
-            DiscourseAi::Completions::Dialects::Llama2Classic,
             DiscourseAi::Completions::Dialects::ChatGpt,
-            DiscourseAi::Completions::Dialects::OrcaStyle,
             DiscourseAi::Completions::Dialects::Gemini,
-            DiscourseAi::Completions::Dialects::Mixtral,
+            DiscourseAi::Completions::Dialects::Mistral,
             DiscourseAi::Completions::Dialects::Claude,
             DiscourseAi::Completions::Dialects::Command,
           ]
@@ -32,40 +30,6 @@ module DiscourseAi
         def tokenizer
           raise NotImplemented
         end
-
-        def tool_preamble(include_array_tip: true)
-          array_tip =
-            if include_array_tip
-              <<~TEXT
-                If a parameter type is an array, return an array of values. For example:
-                <$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
-              TEXT
-            else
-              ""
-            end
-
-          <<~TEXT
-            In this environment you have access to a set of tools you can use to answer the user's question.
-            You may call them like this.
-
-            <function_calls>
-            <invoke>
-            <tool_name>$TOOL_NAME</tool_name>
-            <parameters>
-            <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
-            ...
-            </parameters>
-            </invoke>
-            </function_calls>
-
-            #{array_tip}
-            If you wish to call multiple function in one reply, wrap multiple
-            <invoke> block in a single <function_calls> block.
-
-            Always prefer to lead with tool calls, if you need to execute any.
-            Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
-            Here are the complete list of tools available:
-          TEXT
-        end
       end

       def initialize(generic_prompt, model_name, opts: {})
@@ -74,74 +38,30 @@ module DiscourseAi
         @opts = opts
       end

-      def translate
-        raise NotImplemented
+      VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
+
+      def can_end_with_assistant_msg?
+        false
       end

-      def tool_result_to_xml(message)
-        (<<~TEXT).strip
-          <function_results>
-          <result>
-          <tool_name>#{message[:name] || message[:id]}</tool_name>
-          <json>
-          #{message[:content]}
-          </json>
-          </result>
-          </function_results>
-        TEXT
-      end
-
-      def tool_call_to_xml(message)
-        parsed = JSON.parse(message[:content], symbolize_names: true)
-        parameters = +""
-
-        if parsed[:arguments]
-          parameters << "<parameters>\n"
-          parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
-          parameters << "</parameters>\n"
-        end
-
-        (<<~TEXT).strip
-          <invoke>
-          <tool_name>#{message[:name] || parsed[:name]}</tool_name>
-          #{parameters}
-          </invoke>
-        TEXT
+      def native_tool_support?
+        false
       end

       def tools
-        tools = +""
+        @tools ||= tools_dialect.translated_tools
+      end

-        prompt.tools.each do |function|
-          parameters = +""
-          if function[:parameters].present?
-            function[:parameters].each do |parameter|
-              parameters << <<~PARAMETER
-                <parameter>
-                <parameter_name>#{parameter[:name]}</parameter_name>
-                <parameter_type>#{parameter[:type]}</parameter_type>
-                <parameter_description>#{parameter[:description]}</parameter_description>
-                <required>#{parameter[:required]}</required>
-              PARAMETER
-              if parameter[:enum]
-                parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
-              end
-              parameters << "</parameter>\n"
-            end
-          end
+      def translate
+        messages = prompt.messages

-          tools << <<~TOOLS
-            <tool_description>
-            <tool_name>#{function[:name]}</tool_name>
-            <description>#{function[:description]}</description>
-            <parameters>
-            #{parameters}
-            </parameters>
-            </tool_description>
-          TOOLS
+        # Some models use an assistant msg to improve long-context responses.
+        if messages.last[:type] == :model && can_end_with_assistant_msg?
+          messages = messages.dup
+          messages.pop
         end

-        tools
+        trim_messages(messages).map { |msg| send("#{msg[:type]}_msg", msg) }.compact
       end

       def conversation_context
@@ -154,19 +74,6 @@ module DiscourseAi

       attr_reader :prompt

-      def build_tools_prompt
-        return "" if prompt.tools.blank?
-
-        has_arrays =
-          prompt.tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
-
-        (<<~TEXT).strip
-          #{self.class.tool_preamble(include_array_tip: has_arrays)}
-
-          <tools>
-          #{tools}
-          </tools>
-        TEXT
-      end
-
       private

       attr_reader :model_name, :opts
@@ -230,6 +137,30 @@ module DiscourseAi
       def calculate_message_token(msg)
         self.class.tokenizer.size(msg[:content].to_s)
       end
+
+      def tools_dialect
+        @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
+      end
+
+      def system_msg(msg)
+        raise NotImplemented
+      end
+
+      def assistant_msg(msg)
+        raise NotImplemented
+      end
+
+      def user_msg(msg)
+        raise NotImplemented
+      end
+
+      def tool_call_msg(msg)
+        { role: "assistant", content: tools_dialect.from_raw_tool_call(msg) }
+      end
+
+      def tool_msg(msg)
+        { role: "user", content: tools_dialect.from_raw_tool(msg) }
+      end
     end
   end
 end
diff --git a/lib/completions/dialects/fake.rb b/lib/completions/dialects/fake.rb
index c569ee28..898f3364 100644
--- a/lib/completions/dialects/fake.rb
+++ b/lib/completions/dialects/fake.rb
@@ -9,14 +9,14 @@ module DiscourseAi
           model_name == "fake"
         end

-        def translate
-          ""
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
       end
+
+      def translate
+        ""
+      end
     end
   end
 end
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index a425d4f5..678dc0cd 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -14,59 +14,30 @@ module DiscourseAi
           end
         end

+        def native_tool_support?
+          true
+        end
+
         def translate
           # Gemini complains if we don't alternate model/user roles.
           noop_model_response = { role: "model", parts: { text: "Ok." } }
+          messages = super

-          messages = prompt.messages
+          interleving_messages = []
+          previous_message = nil

-          # Gemini doesn't use an assistant msg to improve long-context responses.
-          messages.pop if messages.last[:type] == :model
-
-          memo = []
-
-          trim_messages(messages).each do |msg|
-            if msg[:type] == :system
-              memo << { role: "user", parts: { text: msg[:content] } }
-              memo << noop_model_response.dup
-            elsif msg[:type] == :model
-              memo << { role: "model", parts: { text: msg[:content] } }
-            elsif msg[:type] == :tool_call
-              call_details = JSON.parse(msg[:content], symbolize_names: true)
-
-              memo << {
-                role: "model",
-                parts: {
-                  functionCall: {
-                    name: msg[:name] || call_details[:name],
-                    args: call_details[:arguments],
-                  },
-                },
-              }
-            elsif msg[:type] == :tool
-              memo << {
-                role: "function",
-                parts: {
-                  functionResponse: {
-                    name: msg[:name] || msg[:id],
-                    response: {
-                      content: msg[:content],
-                    },
-                  },
-                },
-              }
-            else
-              # Gemini quirk. Doesn't accept tool -> user or user -> user msgs. 
-              previous_msg_role = memo.last&.dig(:role)
-              if previous_msg_role == "user" || previous_msg_role == "function"
-                memo << noop_model_response.dup
+          messages.each do |message|
+            if previous_message
+              if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
+                   message[:role] == "user"
+                interleving_messages << noop_model_response.dup
               end
-
-              memo << { role: "user", parts: { text: msg[:content] } }
             end
+            interleving_messages << message
+            previous_message = message
           end

-          memo
+          interleving_messages
         end

         def tools
@@ -110,6 +81,46 @@ module DiscourseAi
         def calculate_message_token(context)
           self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
         end
+
+        def system_msg(msg)
+          { role: "user", parts: { text: msg[:content] } }
+        end
+
+        def model_msg(msg)
+          { role: "model", parts: { text: msg[:content] } }
+        end
+
+        def user_msg(msg)
+          { role: "user", parts: { text: msg[:content] } }
+        end
+
+        def tool_call_msg(msg)
+          call_details = JSON.parse(msg[:content], symbolize_names: true)
+
+          {
+            role: "model",
+            parts: {
+              functionCall: {
+                name: msg[:name] || call_details[:name],
+                args: call_details[:arguments],
+              },
+            },
+          }
+        end
+
+        def tool_msg(msg)
+          {
+            role: "function",
+            parts: {
+              functionResponse: {
+                name: msg[:name] || msg[:id],
+                response: {
+                  content: msg[:content],
+                },
+              },
+            },
+          }
+        end
       end
     end
   end
 end
diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb
deleted file mode 100644
index 3b5675f1..00000000
--- a/lib/completions/dialects/llama2_classic.rb
+++ /dev/null
@@ -1,68 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Completions
-    module Dialects
-      class Llama2Classic < Dialect
-        class << self
-          def can_translate?(model_name)
-            %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
-          end
-
-          def tokenizer
-            DiscourseAi::Tokenizer::Llama2Tokenizer
-          end
-        end
-
-        def translate
-          messages = prompt.messages
-
-          llama2_prompt =
-            trim_messages(messages).reduce(+"") do |memo, msg|
-              next(memo) if msg[:type] == :tool_call
-
-              if msg[:type] == :system
-                memo << (<<~TEXT).strip
-                  [INST]
-                  <<SYS>>
-                  #{msg[:content]}
-                  #{build_tools_prompt}
-                  <</SYS>>
-                  [/INST]
-                TEXT
-              elsif msg[:type] == :model
-                memo << "\n#{msg[:content]}"
-              elsif msg[:type] == :tool
-                JSON.parse(msg[:content], symbolize_names: true)
-                memo << "\n[INST]\n"
-
-                memo << (<<~TEXT).strip
-                  <function_results>
-                  <result>
-                  <tool_name>#{msg[:id]}</tool_name>
-                  <json>
-                  #{msg[:content]}
-                  </json>
-                  </result>
-                  </function_results>
-                  [/INST]
-                TEXT
-              else
-                memo << "\n[INST]#{msg[:content]}[/INST]"
-              end
-
-              memo
-            end
-
-          llama2_prompt << "\n" if llama2_prompt.ends_with?("[/INST]")
-
-          llama2_prompt
-        end
-
-        def max_prompt_tokens
-          SiteSetting.ai_hugging_face_token_limit
-        end
-      end
-    end
-  end
-end
diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb
new file mode 100644
index 00000000..7752a876
--- /dev/null
+++ b/lib/completions/dialects/mistral.rb
@@ -0,0 +1,57 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Mistral < Dialect
+        class << self
+          def can_translate?(model_name)
+            %w[
+              mistralai/Mixtral-8x7B-Instruct-v0.1
+              mistralai/Mistral-7B-Instruct-v0.2
+              mistral
+            ].include?(model_name)
+          end
+
+          def tokenizer
+            DiscourseAi::Tokenizer::MixtralTokenizer
+          end
+        end
+
+        def tools
+          @tools ||= tools_dialect.translated_tools
+        end
+
+        def max_prompt_tokens
+          32_000
+        end
+
+        private
+
+        def system_msg(msg)
+          { role: "assistant", content: "#{msg[:content]}" }
+        end
+
+        def model_msg(msg)
+          { role: "assistant", content: msg[:content] }
+        end
+ + def tool_call_msg(msg) + tools_dialect.from_raw_tool_call(msg) + end + + def tool_msg(msg) + tools_dialect.from_raw_tool(msg) + end + + def user_msg(msg) + content = +"" + content << "#{msg[:id]}: " if msg[:id] + content << msg[:content] + + { role: "user", content: content } + end + end + end + end +end diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb deleted file mode 100644 index 425d741e..00000000 --- a/lib/completions/dialects/mixtral.rb +++ /dev/null @@ -1,57 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Completions - module Dialects - class Mixtral < Dialect - class << self - def can_translate?(model_name) - %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( - model_name, - ) - end - - def tokenizer - DiscourseAi::Tokenizer::MixtralTokenizer - end - end - - def translate - messages = prompt.messages - - mixtral_prompt = - trim_messages(messages).reduce(+"") do |memo, msg| - if msg[:type] == :tool_call - memo << "\n" - memo << tool_call_to_xml(msg) - elsif msg[:type] == :system - memo << (<<~TEXT).strip - [INST] - #{msg[:content]} - #{build_tools_prompt} - [/INST] Ok - TEXT - elsif msg[:type] == :model - memo << "\n#{msg[:content]}" - elsif msg[:type] == :tool - memo << "\n" - memo << tool_result_to_xml(msg) - else - memo << "\n[INST]#{msg[:content]}[/INST]" - end - - memo - end - - mixtral_prompt << "\n" if mixtral_prompt.ends_with?("[/INST]") - - mixtral_prompt - end - - def max_prompt_tokens - 32_000 - end - end - end - end -end diff --git a/lib/completions/dialects/open_ai_tools.rb b/lib/completions/dialects/open_ai_tools.rb new file mode 100644 index 00000000..b990e379 --- /dev/null +++ b/lib/completions/dialects/open_ai_tools.rb @@ -0,0 +1,62 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class OpenAiTools + def initialize(tools) + @raw_tools = tools + end + + def translated_tools + raw_tools.map do |t| + tool = t.dup + + tool[:parameters] = t[:parameters] + .to_a + .reduce({ type: "object", properties: {}, required: [] }) do |memo, p| + name = p[:name] + memo[:required] << name if p[:required] + + memo[:properties][name] = p.except(:name, :required, :item_type) + + memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type] + memo + end + + { type: "function", function: tool } + end + end + + def instructions + "" # Noop. Tools are listed separate. 
+        end
+
+        def from_raw_tool_call(raw_message)
+          call_details = JSON.parse(raw_message[:content], symbolize_names: true)
+          call_details[:arguments] = call_details[:arguments].to_json
+          call_details[:name] = raw_message[:name]
+
+          {
+            role: "assistant",
+            content: nil,
+            tool_calls: [{ type: "function", function: call_details, id: raw_message[:id] }],
+          }
+        end
+
+        def from_raw_tool(raw_message)
+          {
+            role: "tool",
+            tool_call_id: raw_message[:id],
+            content: raw_message[:content],
+            name: raw_message[:name],
+          }
+        end
+
+        private
+
+        attr_reader :raw_tools
+      end
+    end
+  end
+end
diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb
deleted file mode 100644
index 4ec42a36..00000000
--- a/lib/completions/dialects/orca_style.rb
+++ /dev/null
@@ -1,59 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Completions
-    module Dialects
-      class OrcaStyle < Dialect
-        class << self
-          def can_translate?(model_name)
-            %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
-          end
-
-          def tokenizer
-            DiscourseAi::Tokenizer::Llama2Tokenizer
-          end
-        end
-
-        def translate
-          messages = prompt.messages
-          trimmed_messages = trim_messages(messages)
-
-          # Need to include this differently
-          last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
-
-          llama2_prompt =
-            trimmed_messages.reduce(+"") do |memo, msg|
-              if msg[:type] == :tool_call
-                memo << "\n### Assistant:\n"
-                memo << tool_call_to_xml(msg)
-              elsif msg[:type] == :system
-                memo << (<<~TEXT).strip
-                  ### System:
-                  #{msg[:content]}
-                  #{build_tools_prompt}
-                TEXT
-              elsif msg[:type] == :model
-                memo << "\n### Assistant:\n#{msg[:content]}"
-              elsif msg[:type] == :tool
-                memo << "\n### User:\n"
-                memo << tool_result_to_xml(msg)
-              else
-                memo << "\n### User:\n#{msg[:content]}"
-              end
-
-              memo
-            end
-
-          llama2_prompt << "\n### Assistant:\n"
-          llama2_prompt << "#{last_message[:content]}:" if last_message
-
-          llama2_prompt
-        end
-
-        def max_prompt_tokens
-          SiteSetting.ai_hugging_face_token_limit
-        end
-      end
-    end
-  end
-end
diff --git a/lib/completions/dialects/xml_tools.rb b/lib/completions/dialects/xml_tools.rb
new file mode 100644
index 00000000..47988a71
--- /dev/null
+++ b/lib/completions/dialects/xml_tools.rb
@@ -0,0 +1,125 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class XmlTools
+        def initialize(tools)
+          @raw_tools = tools
+        end
+
+        def translated_tools
+          raw_tools.reduce(+"") do |tools, function|
+            parameters = +""
+            if function[:parameters].present?
+              function[:parameters].each do |parameter|
+                parameters << <<~PARAMETER
+                  <parameter>
+                  <parameter_name>#{parameter[:name]}</parameter_name>
+                  <parameter_type>#{parameter[:type]}</parameter_type>
+                  <parameter_description>#{parameter[:description]}</parameter_description>
+                  <required>#{parameter[:required]}</required>
+                PARAMETER
+                if parameter[:enum]
+                  parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
+                end
+                parameters << "</parameter>\n"
+              end
+            end
+
+            tools << <<~TOOLS
+              <tool_description>
+              <tool_name>#{function[:name]}</tool_name>
+              <description>#{function[:description]}</description>
+              <parameters>
+              #{parameters}
+              </parameters>
+              </tool_description>
+            TOOLS
+          end
+        end
+
+        def instructions
+          return "" if raw_tools.blank?
+
+          has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? 
{ |p| p[:type] == "array" } }
+
+          (<<~TEXT).strip
+            #{tool_preamble(include_array_tip: has_arrays)}
+
+            <tools>
+            #{translated_tools}</tools>
+          TEXT
+        end
+
+        def from_raw_tool(raw_message)
+          (<<~TEXT).strip
+            <function_results>
+            <result>
+            <tool_name>#{raw_message[:name] || raw_message[:id]}</tool_name>
+            <json>
+            #{raw_message[:content]}
+            </json>
+            </result>
+            </function_results>
+          TEXT
+        end
+
+        def from_raw_tool_call(raw_message)
+          parsed = JSON.parse(raw_message[:content], symbolize_names: true)
+          parameters = +""
+
+          if parsed[:arguments]
+            parameters << "<parameters>\n"
+            parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
+            parameters << "</parameters>\n"
+          end
+
+          (<<~TEXT).strip
+            <invoke>
+            <tool_name>#{raw_message[:name] || parsed[:name]}</tool_name>
+            #{parameters}
+            </invoke>
+          TEXT
+        end
+
+        private
+
+        attr_reader :raw_tools
+
+        def tool_preamble(include_array_tip: true)
+          array_tip =
+            if include_array_tip
+              <<~TEXT
+                If a parameter type is an array, return an array of values. For example:
+                <$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
+              TEXT
+            else
+              ""
+            end
+
+          <<~TEXT
+            In this environment you have access to a set of tools you can use to answer the user's question.
+            You may call them like this.
+
+            <function_calls>
+            <invoke>
+            <tool_name>$TOOL_NAME</tool_name>
+            <parameters>
+            <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
+            ...
+            </parameters>
+            </invoke>
+            </function_calls>
+
+            #{array_tip}
+            If you wish to call multiple function in one reply, wrap multiple
+            <invoke> block in a single <function_calls> block.
+
+            Always prefer to lead with tool calls, if you need to execute any.
+            Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
+            Here are the complete list of tools available:
+          TEXT
+        end
+      end
+    end
+  end
+end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 8c27a269..ee9d4f17 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -62,7 +62,7 @@ module DiscourseAi

       # this is an approximation, we will update it later if request goes through
       def prompt_size(prompt)
-        super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
+        tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
       end

       def model_uri
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 5f17cad1..b7f3e8bf 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -51,7 +51,7 @@ module DiscourseAi

       def prompt_size(prompt)
         # approximation
-        super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
+        tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
       end

       def model_uri
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 0766a9a4..7a914a1e 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -19,6 +19,8 @@ module DiscourseAi
           DiscourseAi::Completions::Endpoints::Cohere,
         ]

+        endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
+
         if Rails.env.test? || Rails.env.development?
           endpoints << DiscourseAi::Completions::Endpoints::Fake
         end
@@ -67,6 +69,10 @@ module DiscourseAi
         false
       end

+      def use_ssl?
+        true
+      end
+
      def perform_completion!(dialect, user, model_params = {}, &blk)
        allow_tools = dialect.prompt.has_tools? 
model_params = normalize_model_params(model_params) @@ -78,7 +84,7 @@ module DiscourseAi FinalDestination::HTTP.start( model_uri.host, model_uri.port, - use_ssl: true, + use_ssl: use_ssl?, read_timeout: TIMEOUT, open_timeout: TIMEOUT, write_timeout: TIMEOUT, @@ -315,7 +321,7 @@ module DiscourseAi end def extract_prompt_for_tokenizer(prompt) - prompt + prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") end def build_buffer diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb index 5542c73f..d6237c05 100644 --- a/lib/completions/endpoints/hugging_face.rb +++ b/lib/completions/endpoints/hugging_face.rb @@ -8,14 +8,9 @@ module DiscourseAi def can_contact?(endpoint_name, model_name) return false unless endpoint_name == "hugging_face" - %w[ - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf - mistralai/Mixtral-8x7B-Instruct-v0.1 - mistralai/Mistral-7B-Instruct-v0.2 - ].include?(model_name) + %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( + model_name, + ) end def dependant_setting_names @@ -31,24 +26,21 @@ module DiscourseAi end end - def default_options - { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } } - end - def normalize_model_params(model_params) model_params = model_params.dup + # max_tokens, temperature are already supported if model_params[:stop_sequences] model_params[:stop] = model_params.delete(:stop_sequences) end - if model_params[:max_tokens] - model_params[:max_new_tokens] = model_params.delete(:max_tokens) - end - model_params end + def default_options + { model: model, temperature: 0.7 } + end + def provider_id AiApiAuditLog::Provider::HuggingFaceTextGeneration end @@ -61,13 +53,14 @@ module DiscourseAi def prepare_payload(prompt, model_params, _dialect) default_options - .merge(inputs: prompt) + .merge(model_params) + .merge(messages: prompt) .tap do |payload| - payload[:parameters].merge!(model_params) + if !payload[:max_tokens] + token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 - token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 - - payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt) + payload[:max_tokens] = token_limit - prompt_size(prompt) + end payload[:stream] = true if @streaming_mode end @@ -85,16 +78,13 @@ module DiscourseAi end def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) + # half a line sent here + return if !parsed - if @streaming_mode - # Last chunk contains full response, which we already yielded. - return if parsed.dig(:token, :special) + response_h = @streaming_mode ? 
parsed.dig(:delta) : parsed.dig(:message) - parsed.dig(:token, :text).to_s - else - parsed[0][:generated_text].to_s - end + response_h.dig(:content) end def partials_from(decoded_chunk) diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb new file mode 100644 index 00000000..0fd748d4 --- /dev/null +++ b/lib/completions/endpoints/ollama.rb @@ -0,0 +1,89 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Endpoints + class Ollama < Base + class << self + def can_contact?(endpoint_name, model_name) + endpoint_name == "ollama" && %w[mistral].include?(model_name) + end + + def dependant_setting_names + %w[ai_ollama_endpoint] + end + + def correctly_configured?(_model_name) + SiteSetting.ai_ollama_endpoint.present? + end + + def endpoint_name(model_name) + "Ollama - #{model_name}" + end + end + + def normalize_model_params(model_params) + model_params = model_params.dup + + # max_tokens, temperature are already supported + if model_params[:stop_sequences] + model_params[:stop] = model_params.delete(:stop_sequences) + end + + model_params + end + + def default_options + { max_tokens: 2000, model: model } + end + + def provider_id + AiApiAuditLog::Provider::Ollama + end + + def use_ssl? + false + end + + private + + def model_uri + URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions") + end + + def prepare_payload(prompt, model_params, _dialect) + default_options + .merge(model_params) + .merge(messages: prompt) + .tap { |payload| payload[:stream] = true if @streaming_mode } + end + + def prepare_request(payload) + headers = { "Content-Type" => "application/json" } + + Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } + end + + def partials_from(decoded_chunk) + decoded_chunk + .split("\n") + .map do |line| + data = line.split("data: ", 2)[1] + data == "[DONE]" ? nil : data + end + .compact + end + + def extract_completion_from(response_raw) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) + # half a line sent here + return if !parsed + + response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + + response_h.dig(:content) + end + end + end + end +end diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 16e0b886..2ccd817e 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -153,10 +153,6 @@ module DiscourseAi .compact end - def extract_prompt_for_tokenizer(prompt) - prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") - end - def has_tool?(_response_data) @has_function_call end diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 310bcdc9..7db1452d 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -7,14 +7,9 @@ module DiscourseAi class << self def can_contact?(endpoint_name, model_name) endpoint_name == "vllm" && - %w[ - mistralai/Mixtral-8x7B-Instruct-v0.1 - mistralai/Mistral-7B-Instruct-v0.2 - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf - ].include?(model_name) + %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?( + model_name, + ) end def dependant_setting_names @@ -54,9 +49,9 @@ module DiscourseAi def model_uri service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv) if service.present? 
- api_endpoint = "https://#{service.target}:#{service.port}/v1/completions" + api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions" else - api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/completions" + api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions" end @uri ||= URI(api_endpoint) end @@ -64,7 +59,7 @@ module DiscourseAi def prepare_payload(prompt, model_params, _dialect) default_options .merge(model_params) - .merge(prompt: prompt) + .merge(messages: prompt) .tap { |payload| payload[:stream] = true if @streaming_mode } end @@ -76,15 +71,6 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - - # half a line sent here - return if !parsed - - parsed.dig(:text) - end - def partials_from(decoded_chunk) decoded_chunk .split("\n") @@ -94,6 +80,16 @@ module DiscourseAi end .compact end + + def extract_completion_from(response_raw) + parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) + # half a line sent here + return if !parsed + + response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + + response_h.dig(:content) + end end end end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index c54c0dd4..47190644 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -31,21 +31,10 @@ module DiscourseAi claude-3-opus ], anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus], - vllm: %w[ - mistralai/Mixtral-8x7B-Instruct-v0.1 - mistralai/Mistral-7B-Instruct-v0.2 - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf - ], + vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2], hugging_face: %w[ mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2 - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 - Llama2-*-chat-hf - Llama2-chat-hf ], cohere: %w[command-light command command-r command-r-plus], open_ai: %w[ @@ -57,7 +46,10 @@ module DiscourseAi gpt-4-vision-preview ], google: %w[gemini-pro gemini-1.5-pro], - }.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? } + }.tap do |h| + h[:ollama] = ["mistral"] if Rails.env.development? + h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? + end end def valid_provider_models @@ -120,8 +112,6 @@ module DiscourseAi @gateway = gateway end - delegate :tokenizer, to: :dialect_klass - # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object # @param user { User } - User requesting the summary. 
      #
@@ -184,6 +174,8 @@ module DiscourseAi
        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
      end

+      delegate :tokenizer, to: :dialect_klass
+
      attr_reader :model_name

      private
diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb
index 9e60eb67..37c72725 100644
--- a/lib/summarization/entry_point.rb
+++ b/lib/summarization/entry_point.rb
@@ -10,14 +10,6 @@ module DiscourseAi
          Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000),
          Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
          Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
-          Models::Llama2.new(
-            "hugging_face:Llama2-chat-hf",
-            max_tokens: SiteSetting.ai_hugging_face_token_limit,
-          ),
-          Models::Llama2FineTunedOrcaStyle.new(
-            "hugging_face:StableBeluga2",
-            max_tokens: SiteSetting.ai_hugging_face_token_limit,
-          ),
          Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
          Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
        ]
diff --git a/spec/lib/completions/dialects/dialect_spec.rb b/spec/lib/completions/dialects/dialect_spec.rb
index 2410f415..c54d1838 100644
--- a/spec/lib/completions/dialects/dialect_spec.rb
+++ b/spec/lib/completions/dialects/dialect_spec.rb
@@ -17,47 +17,6 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
 end

 RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
-  describe "#build_tools_prompt" do
-    it "can exclude array instructions" do
-      prompt = DiscourseAi::Completions::Prompt.new("12345")
-      prompt.tools = [
-        {
-          name: "weather",
-          description: "lookup weather in a city",
-          parameters: [{ name: "city", type: "string", description: "city name", required: true }],
-        },
-      ]
-
-      dialect = TestDialect.new(prompt, "test")
-
-      expect(dialect.build_tools_prompt).not_to include("array")
-    end
-
-    it "can include array instructions" do
-      prompt = DiscourseAi::Completions::Prompt.new("12345")
-      prompt.tools = [
-        {
-          name: "weather",
-          description: "lookup weather in a city",
-          parameters: [{ name: "city", type: "array", description: "city names", required: true }],
-        },
-      ]
-
-      dialect = TestDialect.new(prompt, "test")
-
-      expect(dialect.build_tools_prompt).to include("array")
-    end
-
-    it "does not break if there are no params" do
-      prompt = DiscourseAi::Completions::Prompt.new("12345")
-      prompt.tools = [{ name: "categories", description: "lookup all categories" }]
-
-      dialect = TestDialect.new(prompt, "test")
-
-      expect(dialect.build_tools_prompt).not_to include("array")
-    end
-  end
-
   describe "#trim_messages" do
     it "should trim tool messages if tool_calls are trimmed" do
       prompt = DiscourseAi::Completions::Prompt.new("12345")
diff --git a/spec/lib/completions/dialects/llama2_classic_spec.rb b/spec/lib/completions/dialects/llama2_classic_spec.rb
deleted file mode 100644
index 0242ebf6..00000000
--- a/spec/lib/completions/dialects/llama2_classic_spec.rb
+++ /dev/null
@@ -1,62 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "dialect_context"
-
-RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
-  let(:model_name) { "Llama2-chat-hf" }
-  let(:context) { DialectContext.new(described_class, model_name) }
-
-  describe "#translate" do
-    it "translates a prompt written in our generic format to the Llama2 format" do
-      llama2_classic_version = <<~TEXT
-        [INST]
-        <<SYS>>
-        #{context.system_insts}
-        #{described_class.tool_preamble(include_array_tip: false)}
-        <tools>
-        #{context.dialect_tools}</tools>
-        <</SYS>>
-        [/INST]
-        [INST]#{context.simple_user_input}[/INST]
-      TEXT
-
-      translated = context.system_user_scenario
-
-      expect(translated).to eq(llama2_classic_version)
-    end
-
-    it "translates tool messages" do
-      expected = +(<<~TEXT)
-        [INST]
-        <<SYS>>
-        #{context.system_insts}
-        #{described_class.tool_preamble(include_array_tip: false)}
-        <tools>
-        #{context.dialect_tools}</tools>
-        <</SYS>>
-        [/INST]
-        [INST]This is a message by a user[/INST]
-        I'm a previous bot reply, that's why there's no user
-        [INST]This is a new message by a user[/INST]
-        [INST]
-        <function_results>
-        <result>
-        <tool_name>tool_id</tool_name>
-        <json>
-        "I'm a tool result"
-        </json>
-        </result>
-        </function_results>
-        [/INST]
-      TEXT
-
-      expect(context.multi_turn_scenario).to eq(expected)
-    end
-
-    it "trims content if it's getting too long" do
-      translated = context.long_user_input_scenario
-
-      expect(translated.length).to be < context.long_message_text.length
-    end
-  end
-end
diff --git a/spec/lib/completions/dialects/mixtral_spec.rb b/spec/lib/completions/dialects/mixtral_spec.rb
deleted file mode 100644
index 499dad73..00000000
--- a/spec/lib/completions/dialects/mixtral_spec.rb
+++ /dev/null
@@ -1,66 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "dialect_context"
-
-RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
-  let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
-  let(:context) { DialectContext.new(described_class, model_name) }
-
-  describe "#translate" do
-    it "translates a prompt written in our generic format to the Llama2 format" do
-      llama2_classic_version = <<~TEXT
-        [INST]
-        #{context.system_insts}
-        #{described_class.tool_preamble(include_array_tip: false)}
-        <tools>
-        #{context.dialect_tools}</tools>
-        [/INST] Ok
-        [INST]#{context.simple_user_input}[/INST]
-      TEXT
-
-      translated = context.system_user_scenario
-
-      expect(translated).to eq(llama2_classic_version)
-    end
-
-    it "translates tool messages" do
-      expected = +(<<~TEXT).strip
-        [INST]
-        #{context.system_insts}
-        #{described_class.tool_preamble(include_array_tip: false)}
-        <tools>
-        #{context.dialect_tools}</tools>
-        [/INST] Ok
-        [INST]This is a message by a user[/INST]
-        I'm a previous bot reply, that's why there's no user
-        [INST]This is a new message by a user[/INST]
-        <function_calls>
-        <invoke>
-        <tool_name>get_weather</tool_name>
-        <parameters>
-        <location>Sydney</location>
-        <unit>c</unit>
-        </parameters>
-        </invoke>
-        </function_calls>
-
-        <function_results>
-        <result>
-        <tool_name>get_weather</tool_name>
-        <json>
-        "I'm a tool result"
-        </json>
-        </result>
-        </function_results>
-      TEXT
-
-      expect(context.multi_turn_scenario).to eq(expected)
-    end
-
-    it "trims content if it's getting too long" do
-      length = 6_000
-      translated = context.long_user_input_scenario(length: length)
-
-      expect(translated.length).to be < context.long_message_text(length: length).length
-    end
-  end
-end
diff --git a/spec/lib/completions/dialects/orca_style_spec.rb b/spec/lib/completions/dialects/orca_style_spec.rb
deleted file mode 100644
index 63b414b8..00000000
--- a/spec/lib/completions/dialects/orca_style_spec.rb
+++ /dev/null
@@ -1,71 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "dialect_context"
-
-RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
-  let(:model_name) { "StableBeluga2" }
-  let(:context) { DialectContext.new(described_class, model_name) }
-
-  describe "#translate" do
-    it "translates a prompt written in our generic format to the Llama2 format" do
-      llama2_classic_version = <<~TEXT
-        ### System:
-        #{context.system_insts}
-        #{described_class.tool_preamble(include_array_tip: false)}
-        <tools>
-        #{context.dialect_tools}</tools>
-        ### User:
-        #{context.simple_user_input}
-        ### Assistant:
-      TEXT
-
-      translated = context.system_user_scenario
-
-      expect(translated).to eq(llama2_classic_version)
-    end
-
-    it "translates tool messages" do
-      expected = +(<<~TEXT)
-        ### System:
-        #{context.system_insts}
-        #{described_class.tool_preamble(include_array_tip: false)}
-        <tools>
-        #{context.dialect_tools}</tools>
-        ### User:
-        This is a message by a user
-        ### Assistant:
-        I'm a previous bot reply, that's why there's no user
-        ### User:
-        This is a new message by a user
-        ### Assistant:
-        <function_calls>
-        <invoke>
-        <tool_name>get_weather</tool_name>
-        <parameters>
-        <location>Sydney</location>
-        <unit>c</unit>
-        </parameters>
-        </invoke>
-        </function_calls>
-        ### User:
-        <function_results>
-        <result>
-        <tool_name>get_weather</tool_name>
-        <json>
-        "I'm a tool result"
-        </json>
-        </result>
-        </function_results>
-        ### Assistant:
-      TEXT
-
-      expect(context.multi_turn_scenario).to eq(expected)
-    end
-
-    it "trims content if it's getting too long" do
-      translated = context.long_user_input_scenario
-
-      expect(translated.length).to be < context.long_message_text.length
-    end
-  end
-end
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index bfdfa74e..5b9bd9f5 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -4,7 +4,20 @@ require_relative "endpoint_compliance"

 class HuggingFaceMock < EndpointMock
   def response(content)
-    [{ generated_text: content }]
+    {
+      id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
+      object: "chat.completion",
+      created: 1_678_464_820,
+      model: "Llama2-*-chat-hf",
+      usage: {
+        prompt_tokens: 337,
+        completion_tokens: 162,
+        total_tokens: 499,
+      },
+      choices: [
+        { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
+      ],
+    }
   end

   def stub_response(prompt, response_text, tool_call: false)
@@ -14,26 +27,32 @@ class HuggingFaceMock < EndpointMock
       .to_return(status: 200, body: JSON.dump(response(response_text)))
   end

-  def stream_line(delta, deltas, finish_reason: nil)
+  def stream_line(delta, finish_reason: nil)
     +"data: " << {
-      token: {
-        id: 29_889,
-        text: delta,
-        logprob: -0.08319092,
-        special: !!finish_reason,
-      },
-      generated_text: finish_reason ? deltas.join : nil,
-      details: nil,
+      id: "chatcmpl-#{SecureRandom.hex}",
+      object: "chat.completion.chunk",
+      created: 1_681_283_881,
+      model: "Llama2-*-chat-hf",
+      choices: [{ delta: { content: delta } }],
+      finish_reason: finish_reason,
+      index: 0,
     }.to_json
   end

+  def stub_raw(chunks)
+    WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return(
+      status: 200,
+      body: chunks,
+    )
+  end
+
   def stub_streamed_response(prompt, deltas, tool_call: false)
     chunks =
       deltas.each_with_index.map do |_, index|
         if index == (deltas.length - 1)
-          stream_line(deltas[index], deltas, finish_reason: true)
+          stream_line(deltas[index], finish_reason: "stop_sequence")
         else
-          stream_line(deltas[index], deltas)
+          stream_line(deltas[index])
         end
       end

@@ -43,16 +62,18 @@ class HuggingFaceMock < EndpointMock
       .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
       .with(body: request_body(prompt, stream: true))
       .to_return(status: 200, body: chunks)
+
+    yield if block_given? 
end - def request_body(prompt, stream: false) + def request_body(prompt, stream: false, tool_call: false) model .default_options - .merge(inputs: prompt) - .tap do |payload| - payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) - + .merge(messages: prompt) + .tap do |b| + b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) - model.prompt_size(prompt) - payload[:stream] = true if stream + b[:stream] = true if stream end .to_json end @@ -70,7 +91,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do let(:hf_mock) { HuggingFaceMock.new(endpoint) } let(:compliance) do - EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user) + EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user) end describe "#perform_completion!" do diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb index 52d87007..d879ad09 100644 --- a/spec/lib/completions/endpoints/vllm_spec.rb +++ b/spec/lib/completions/endpoints/vllm_spec.rb @@ -6,7 +6,7 @@ class VllmMock < EndpointMock def response(content) { id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", - object: "text_completion", + object: "chat.completion", created: 1_678_464_820, model: "mistralai/Mixtral-8x7B-Instruct-v0.1", usage: { @@ -14,14 +14,16 @@ class VllmMock < EndpointMock completion_tokens: 162, total_tokens: 499, }, - choices: [{ text: content, finish_reason: "stop", index: 0 }], + choices: [ + { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 }, + ], } end def stub_response(prompt, response_text, tool_call: false) WebMock - .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") - .with(body: model.default_options.merge(prompt: prompt).to_json) + .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions") + .with(body: model.default_options.merge(messages: prompt).to_json) .to_return(status: 200, body: JSON.dump(response(response_text))) end @@ -30,7 +32,7 @@ class VllmMock < EndpointMock id: "cmpl-#{SecureRandom.hex}", created: 1_681_283_881, model: "mistralai/Mixtral-8x7B-Instruct-v0.1", - choices: [{ text: delta, finish_reason: finish_reason, index: 0 }], + choices: [{ delta: { content: delta } }], index: 0, }.to_json end @@ -48,8 +50,8 @@ class VllmMock < EndpointMock chunks = (chunks.join("\n\n") << "data: [DONE]").split("") WebMock - .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") - .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json) + .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions") + .with(body: model.default_options.merge(messages: prompt, stream: true).to_json) .to_return(status: 200, body: chunks) end end @@ -67,14 +69,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do let(:anthropic_mock) { VllmMock.new(endpoint) } let(:compliance) do - EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user) + EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user) end - let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) } + let(:dialect) { DiscourseAi::Completions::Dialects::Mistral.new(generic_prompt, model_name) } let(:prompt) { dialect.translate } - let(:request_body) { model.default_options.merge(prompt: prompt).to_json } - let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: 
true).to_json } + let(:request_body) { model.default_options.merge(messages: prompt).to_json } + let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json } before { SiteSetting.ai_vllm_endpoint = "https://test.dev" } diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb index 3aeeb04a..e91a4a60 100644 --- a/spec/lib/completions/llm_spec.rb +++ b/spec/lib/completions/llm_spec.rb @@ -3,8 +3,8 @@ RSpec.describe DiscourseAi::Completions::Llm do subject(:llm) do described_class.new( - DiscourseAi::Completions::Dialects::OrcaStyle, - nil, + DiscourseAi::Completions::Dialects::Mistral, + canned_response, "hugging_face:Upstage-Llama-2-*-instruct-v2", gateway: canned_response, )
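
Note: with the pieces above in place, the new provider can be smoke-tested from a development console. A minimal sketch, assuming a local Ollama server is already serving the mistral model on its default port (the endpoint URL is an assumption; adjust it for your setup):

    # Development-only sketch: point the new hidden setting at a local Ollama server.
    # http://localhost:11434 is Ollama's default bind address, assumed here.
    SiteSetting.ai_ollama_endpoint = "http://localhost:11434"

    # Llm.proxy resolves "provider:model" pairs; the Ollama endpoint is only
    # registered in development, per the base.rb change above.
    llm = DiscourseAi::Completions::Llm.proxy("ollama:mistral")

    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful bot",
        messages: [{ type: :user, content: "Say hello" }],
      )

    # Non-streaming call; pass a block to receive streamed partial replies instead.
    puts llm.generate(prompt, user: Discourse.system_user)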