diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index c3121624..5cc75d6e 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
       class ChatGpt < Dialect
         class << self
-          def can_translate?(model_provider)
-            model_provider == "open_ai" || model_provider == "azure"
+          def can_translate?(llm_model)
+            llm_model.provider == "open_ai" || llm_model.provider == "azure"
           end
         end
 
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 916bd90c..a9c0aba7 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -5,8 +5,10 @@ module DiscourseAi
     module Dialects
       class Claude < Dialect
         class << self
-          def can_translate?(provider_name)
-            provider_name == "anthropic" || provider_name == "aws_bedrock"
+          def can_translate?(llm_model)
+            llm_model.provider == "anthropic" ||
+              (llm_model.provider == "aws_bedrock") &&
+                (llm_model.name.include?("anthropic") || llm_model.name.include?("claude"))
           end
         end
 
diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb
index 43e252a7..25561433 100644
--- a/lib/completions/dialects/command.rb
+++ b/lib/completions/dialects/command.rb
@@ -6,8 +6,8 @@ module DiscourseAi
   module Completions
     module Dialects
       class Command < Dialect
-        def self.can_translate?(model_provider)
-          model_provider == "cohere"
+        def self.can_translate?(llm_model)
+          llm_model.provider == "cohere"
         end
 
         VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index f97da195..041a5f1e 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -5,7 +5,7 @@ module DiscourseAi
     module Dialects
       class Dialect
        class << self
-          def can_translate?(model_provider)
+          def can_translate?(llm_model)
             raise NotImplemented
           end
 
@@ -17,11 +17,12 @@ module DiscourseAi
             DiscourseAi::Completions::Dialects::Command,
             DiscourseAi::Completions::Dialects::Ollama,
             DiscourseAi::Completions::Dialects::Mistral,
+            DiscourseAi::Completions::Dialects::Nova,
             DiscourseAi::Completions::Dialects::OpenAiCompatible,
           ]
         end
 
-        def dialect_for(model_provider)
+        def dialect_for(llm_model)
           dialects = []
 
           if Rails.env.test? || Rails.env.development?
@@ -30,7 +31,7 @@ module DiscourseAi
 
           dialects = dialects.concat(all_dialects)
 
-          dialect = dialects.find { |d| d.can_translate?(model_provider) }
+          dialect = dialects.find { |d| d.can_translate?(llm_model) }
           raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect
 
           dialect
diff --git a/lib/completions/dialects/fake.rb b/lib/completions/dialects/fake.rb
index 464d3279..cda44110 100644
--- a/lib/completions/dialects/fake.rb
+++ b/lib/completions/dialects/fake.rb
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
       class Fake < Dialect
         class << self
-          def can_translate?(model_name)
-            model_name == "fake"
+          def can_translate?(llm_model)
+            llm_model.provider == "fake"
           end
         end
 
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index 899634a1..563ed88b 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
       class Gemini < Dialect
         class << self
-          def can_translate?(model_provider)
-            model_provider == "google"
+          def can_translate?(llm_model)
+            llm_model.provider == "google"
           end
         end
 
@@ -80,7 +80,7 @@ module DiscourseAi
         end
 
         def beta_api?
-          @beta_api ||= llm_model.name.start_with?("gemini-1.5")
+          @beta_api ||= !llm_model.name.start_with?("gemini-1.0")
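+          # assumption: every model newer than gemini-1.0 (1.5, 2.0, ...) speaks the
+          # beta API, so the check is inverted rather than listing each new model name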
         end
 
         def system_msg(msg)
diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb
index d6968e82..eaf5eab8 100644
--- a/lib/completions/dialects/mistral.rb
+++ b/lib/completions/dialects/mistral.rb
@@ -7,8 +7,8 @@ module DiscourseAi
   module Completions
     module Dialects
      class Mistral < ChatGpt
        class << self
-          def can_translate?(model_provider)
-            model_provider == "mistral"
+          def can_translate?(llm_model)
+            llm_model.provider == "mistral"
           end
         end
 
diff --git a/lib/completions/dialects/nova.rb b/lib/completions/dialects/nova.rb
new file mode 100644
index 00000000..aa184a7a
--- /dev/null
+++ b/lib/completions/dialects/nova.rb
@@ -0,0 +1,177 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Nova < Dialect
+        class << self
+          def can_translate?(llm_model)
+            llm_model.provider == "aws_bedrock" && llm_model.name.include?("amazon.nova")
+          end
+        end
+
+        class NovaPrompt
+          attr_reader :system, :messages, :inference_config, :tool_config
+
+          def initialize(system, messages, inference_config = nil, tool_config = nil)
+            @system = system
+            @messages = messages
+            @inference_config = inference_config
+            @tool_config = tool_config
+          end
+
+          def system_prompt
+            # small hack for size estimation
+            system.to_s
+          end
+
+          def has_tools?
+            tool_config.present?
+          end
+
+          def to_payload(options = nil)
+            # the default is nil, but the lookups below expect a hash
+            options ||= {}
+            stop_sequences = options[:stop_sequences]
+            max_tokens = options[:max_tokens]
+
+            inference_config = options.slice(:temperature, :top_p, :top_k)
+
+            inference_config[:stopSequences] = stop_sequences if stop_sequences.present?
+
+            inference_config[:max_new_tokens] = max_tokens if max_tokens.present?
+
+            result = { system: system, messages: messages }
+            result[:inferenceConfig] = inference_config if inference_config.present?
+            result[:toolConfig] = tool_config if tool_config.present?
+
+            result
+          end
+        end
+
+        def translate
+          messages = super
+
+          system = messages.shift[:content] if messages.first&.dig(:role) == "system"
+          nova_messages = messages.map { |msg| { role: msg[:role], content: build_content(msg) } }
+
+          inference_config = build_inference_config
+          tool_config = tools_dialect.translated_tools if native_tool_support?
+
+          NovaPrompt.new(
+            system.presence && [{ text: system }],
+            nova_messages,
+            inference_config,
+            tool_config,
+          )
+        end
+
+        def max_prompt_tokens
+          llm_model.max_prompt_tokens
+        end
+
+        def native_tool_support?
+          !llm_model.lookup_custom_param("disable_native_tools")
+        end
+
+        def tools_dialect
+          if native_tool_support?
+            @tools_dialect ||= DiscourseAi::Completions::Dialects::NovaTools.new(prompt.tools)
+          else
+            super
+          end
+        end
+
+        private
+
+        def build_content(msg)
+          content = []
+
+          existing_content = msg[:content]
+
+          if existing_content.is_a?(Hash)
+            content << existing_content
+          elsif existing_content.is_a?(String)
+            content << { text: existing_content }
+          end
+
+          msg[:images]&.each { |image| content << image }
+
+          content
+        end
+
+        def build_inference_config
+          return unless opts[:inference_config]
+
+          config = {}
+          ic = opts[:inference_config]
+
+          config[:max_new_tokens] = ic[:max_new_tokens] if ic[:max_new_tokens]
+          config[:temperature] = ic[:temperature] if ic[:temperature]
+          config[:top_p] = ic[:top_p] if ic[:top_p]
+          config[:top_k] = ic[:top_k] if ic[:top_k]
+          config[:stopSequences] = ic[:stop_sequences] if ic[:stop_sequences]
+
+          config.present? ? config : nil
+        end
+
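+        # Bedrock image blocks take a bare format token rather than a full MIME
+        # type; anything unrecognized falls back to jpeg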
+        def detect_format(mime_type)
+          case mime_type
+          when "image/jpeg"
+            "jpeg"
+          when "image/png"
+            "png"
+          when "image/gif"
+            "gif"
+          when "image/webp"
+            "webp"
+          else
+            "jpeg" # default
+          end
+        end
+
+        def system_msg(msg)
+          msg = { role: "system", content: msg[:content] }
+
+          if tools_dialect.instructions.present?
+            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+          end
+
+          msg
+        end
+
+        def user_msg(msg)
+          images = nil
+          if vision_support?
+            encoded_uploads = prompt.encoded_uploads(msg)
+            encoded_uploads&.each do |upload|
+              images ||= []
+              images << {
+                image: {
+                  format: upload[:format] || detect_format(upload[:mime_type]),
+                  source: {
+                    bytes: upload[:base64],
+                  },
+                },
+              }
+            end
+          end
+
+          { role: "user", content: msg[:content], images: images }
+        end
+
+        def model_msg(msg)
+          { role: "assistant", content: msg[:content] }
+        end
+
+        def tool_msg(msg)
+          translated = tools_dialect.from_raw_tool(msg)
+          { role: "user", content: translated }
+        end
+
+        def tool_call_msg(msg)
+          translated = tools_dialect.from_raw_tool_call(msg)
+          { role: "assistant", content: translated }
+        end
+      end
+    end
+  end
+end
diff --git a/lib/completions/dialects/nova_tools.rb b/lib/completions/dialects/nova_tools.rb
new file mode 100644
index 00000000..67cff2ea
--- /dev/null
+++ b/lib/completions/dialects/nova_tools.rb
@@ -0,0 +1,83 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class NovaTools
+        def initialize(tools)
+          @raw_tools = tools
+        end
+
+        def translated_tools
+          return if !@raw_tools.present?
+
+          # note: forced tools are not supported yet; toolChoice is always auto
+          {
+            tools:
+              @raw_tools.map do |tool|
+                {
+                  toolSpec: {
+                    name: tool[:name],
+                    description: tool[:description],
+                    inputSchema: {
+                      json: convert_tool_to_input_schema(tool),
+                    },
+                  },
+                }
+              end,
+          }
+        end
+
+        # native tools require no system instructions
+        def instructions
+          ""
+        end
+
+        def from_raw_tool_call(raw_message)
+          {
+            toolUse: {
+              toolUseId: raw_message[:id],
+              name: raw_message[:name],
+              input: JSON.parse(raw_message[:content])["arguments"],
+            },
+          }
+        end
+
+        def from_raw_tool(raw_message)
+          {
+            toolResult: {
+              toolUseId: raw_message[:id],
+              content: [{ json: JSON.parse(raw_message[:content]) }],
+            },
+          }
+        end
+
+        private
+
+        def convert_tool_to_input_schema(tool)
+          tool = tool.transform_keys(&:to_sym)
+          properties = {}
+          tool[:parameters].each do |param|
+            schema = {}
+            type = param[:type]
+            type = "string" if !%w[string number boolean integer array].include?(type)
+
+            schema[:type] = type
+
+            if enum = param[:enum]
+              schema[:enum] = enum
+            end
+
+            schema[:items] = { type: param[:item_type] } if type == "array"
+
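+            # note: required is emitted per property here, not as JSON Schema's
+            # object-level required array; the specs assert this exact shape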
+            schema[:required] = true if param[:required]
+
+            properties[param[:name]] = schema
+          end
+
+          { type: "object", properties: properties }
+        end
+      end
+    end
+  end
+end
diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb
index 54114400..4546c827 100644
--- a/lib/completions/dialects/ollama.rb
+++ b/lib/completions/dialects/ollama.rb
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
      class Ollama < Dialect
        class << self
-          def can_translate?(model_provider)
-            model_provider == "ollama"
+          def can_translate?(llm_model)
+            llm_model.provider == "ollama"
           end
         end
 
diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb
index d22d9708..2d648ac1 100644
--- a/lib/completions/dialects/open_ai_compatible.rb
+++ b/lib/completions/dialects/open_ai_compatible.rb
@@ -5,7 +5,8 @@ module DiscourseAi
     module Dialects
      class OpenAiCompatible < Dialect
        class << self
-          def can_translate?(_model_name)
+          def can_translate?(_llm_model)
+            # fallback dialect
             true
           end
         end
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 1c6f67f1..882c3467 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -6,6 +6,8 @@ module DiscourseAi
   module Completions
     module Endpoints
       class AwsBedrock < Base
+        attr_reader :dialect
+
         def self.can_contact?(model_provider)
           model_provider == "aws_bedrock"
         end
@@ -19,10 +21,15 @@ module DiscourseAi
         end
 
         def default_options(dialect)
-          max_tokens = 4096
-          max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
+          options =
+            if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
+              max_tokens = 4096
+              max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
 
-          options = { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
+              { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
+            else
+              {}
+            end
 
           options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
             dialect.prompt.has_tools?
@@ -86,17 +93,25 @@ module DiscourseAi
 
         def prepare_payload(prompt, model_params, dialect)
           @native_tool_support = dialect.native_tool_support?
+          @dialect = dialect
 
-          payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
-          payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+          payload = nil
 
-          if prompt.has_tools?
-            payload[:tools] = prompt.tools
-            if dialect.tool_choice.present?
-              payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+          if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
+            payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+            payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+
+            if prompt.has_tools?
+              payload[:tools] = prompt.tools
+              if dialect.tool_choice.present?
+                payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+              end
             end
+          elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
+            payload = prompt.to_payload(default_options(dialect).merge(model_params))
+          else
+            raise "Unsupported dialect"
           end
-
           payload
         end
@@ -151,6 +166,11 @@ module DiscourseAi
             i = 0
             while decoded
               parsed = JSON.parse(decoded.payload.string)
+              if exception = decoded.headers[":exception-type"]
+                Rails.logger.error("#{self.class.name}: #{exception}: #{parsed}")
+                # TODO: depending on how often this happens, we may want to raise so we
+                # can retry; this may catch rate limits, for example
+              end
               # perhaps some control message we can just ignore
               messages << Base64.decode64(parsed["bytes"]) if parsed && parsed["bytes"]
@@ -180,8 +200,15 @@ module DiscourseAi
         end
 
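+        # Claude on Bedrock streams Anthropic-format events while Nova streams
+        # Converse-format events, so each dialect gets its own message processor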
         def processor
-          @processor ||=
-            DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+          if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
+            @processor ||=
+              DiscourseAi::Completions::AnthropicMessageProcessor.new(
+                streaming_mode: @streaming_mode,
+              )
+          else
+            @processor ||=
+              DiscourseAi::Completions::NovaMessageProcessor.new(streaming_mode: @streaming_mode)
+          end
         end
 
         def xml_tools_enabled?
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 1707460c..51ba0464 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -176,8 +176,7 @@ module DiscourseAi
 
         raise UNKNOWN_MODEL if llm_model.nil?
 
-        model_provider = llm_model.provider
-        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_provider)
+        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model)
 
         if @canned_response
           if @canned_llm && @canned_llm != model
@@ -187,6 +186,7 @@ module DiscourseAi
           return new(dialect_klass, nil, llm_model, gateway: @canned_response)
         end
 
+        model_provider = llm_model.provider
         gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_provider)
 
         new(dialect_klass, gateway_klass, llm_model)
diff --git a/lib/completions/nova_message_processor.rb b/lib/completions/nova_message_processor.rb
new file mode 100644
index 00000000..efe54330
--- /dev/null
+++ b/lib/completions/nova_message_processor.rb
@@ -0,0 +1,94 @@
+# frozen_string_literal: true
+
+class DiscourseAi::Completions::NovaMessageProcessor
+  class NovaToolCall
+    attr_reader :name, :raw_json, :id
+
+    def initialize(name, id, partial_tool_calls: false)
+      @name = name
+      @id = id
+      @raw_json = +""
+      @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
+      @streaming_parser =
+        DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
+    end
+
+    def append(json)
+      @raw_json << json
+      @streaming_parser << json if @streaming_parser
+    end
+
+    def notify_progress(key, value)
+      @tool_call.partial = true
+      @tool_call.parameters[key.to_sym] = value
+      @has_new_data = true
+    end
+
+    def has_partial?
+      @has_new_data
+    end
+
+    def partial_tool_call
+      @has_new_data = false
+      @tool_call
+    end
+
+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      # we dupe to avoid poisoning the original tool call
+      @tool_call = @tool_call.dup
+      @tool_call.partial = false
+      @tool_call.parameters = parameters
+      @tool_call
+    end
+  end
+
+  attr_reader :tool_calls, :input_tokens, :output_tokens
+
+  def initialize(streaming_mode:, partial_tool_calls: false)
+    @streaming_mode = streaming_mode
+    @tool_calls = []
+    @current_tool_call = nil
+    @partial_tool_calls = partial_tool_calls
+  end
+
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
+  end
+
+  def process_streamed_message(parsed)
+    return if !parsed
+
+    result = nil
+
+    if tool_start = parsed.dig(:contentBlockStart, :start, :toolUse)
+      @current_tool_call = NovaToolCall.new(tool_start[:name], tool_start[:toolUseId])
+    end
+
+    if tool_progress = parsed.dig(:contentBlockDelta, :delta, :toolUse, :input)
+      @current_tool_call.append(tool_progress)
+    end
+
+    result = @current_tool_call.to_tool_call if parsed[:contentBlockStop] && @current_tool_call
+
+    if metadata = parsed[:metadata]
+      @input_tokens = metadata.dig(:usage, :inputTokens)
+      @output_tokens = metadata.dig(:usage, :outputTokens)
+    end
+
+    result || parsed.dig(:contentBlockDelta, :delta, :text)
+  end
+
+  def process_message(payload)
+    result = []
+    parsed = payload
+    parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)
+
+    result << parsed.dig(:output, :message, :content, 0, :text)
+
+    @input_tokens = parsed.dig(:usage, :inputTokens)
+    @output_tokens = parsed.dig(:usage, :outputTokens)
+
+    result
+  end
+end
diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb
index 3195b3f5..37c0fcbb 100644
--- a/spec/fabricators/llm_model_fabricator.rb
+++ b/spec/fabricators/llm_model_fabricator.rb
@@ -64,6 +64,17 @@ Fabricator(:bedrock_model, from: :anthropic_model) do
   provider_params { { region: "us-east-1", access_key_id: "123456" } }
 end
 
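+# assumption: no Nova-specific tokenizer is available in the plugin, so the
+# OpenAI tokenizer below stands in for token estimation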
+Fabricator(:nova_model, from: :llm_model) do
+  display_name "Amazon Nova Pro"
+  name "amazon.nova-pro-v1:0"
+  provider "aws_bedrock"
+  tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
+  max_prompt_tokens 300_000
+  api_key "fake"
+  url ""
+  provider_params { { region: "us-east-1", access_key_id: "123456" } }
+end
+
 Fabricator(:cohere_model, from: :llm_model) do
   display_name "Cohere Command R+"
   name "command-r-plus"
diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb
index 624c431f..0c9a9590 100644
--- a/spec/lib/completions/dialects/claude_spec.rb
+++ b/spec/lib/completions/dialects/claude_spec.rb
@@ -1,12 +1,12 @@
 # frozen_string_literal: true
 
 RSpec.describe DiscourseAi::Completions::Dialects::Claude do
-  let :opus_dialect_klass do
-    DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
-  end
-
   fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }
 
+  let :opus_dialect_klass do
+    DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model)
+  end
+
   describe "#translate" do
     it "can insert OKs to make stuff interleave properly" do
       messages = [
diff --git a/spec/lib/completions/dialects/nova_spec.rb b/spec/lib/completions/dialects/nova_spec.rb
new file mode 100644
index 00000000..865426e2
--- /dev/null
+++ b/spec/lib/completions/dialects/nova_spec.rb
@@ -0,0 +1,190 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::Dialects::Nova do
+  fab!(:llm_model) { Fabricate(:nova_model, vision_enabled: true) }
+
+  let(:nova_dialect_klass) { DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model) }
+
+  it "finds the right dialect" do
+    expect(nova_dialect_klass).to eq(DiscourseAi::Completions::Dialects::Nova)
+  end
+
+  describe "#translate" do
+    it "properly formats a basic conversation" do
+      messages = [
+        { type: :user, id: "user1", content: "Hello" },
+        { type: :model, content: "Hi there!" },
+      ]
+
+      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
+      dialect = nova_dialect_klass.new(prompt, llm_model)
+      translated = dialect.translate
+
+      expect(translated.system).to eq([{ text: "You are a helpful bot" }])
+      expect(translated.messages).to eq(
+        [
+          { role: "user", content: [{ text: "Hello" }] },
+          { role: "assistant", content: [{ text: "Hi there!" }] },
+        ],
+      )
+    end
+
+    context "with image content" do
+      let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
+      let(:upload) do
+        UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
+      end
+
+      it "properly formats messages with images" do
+        messages = [
+          { type: :user, id: "user1", content: "What's in this image?", upload_ids: [upload.id] },
+        ]
+
+        prompt = DiscourseAi::Completions::Prompt.new(messages: messages)
+
+        dialect = nova_dialect_klass.new(prompt, llm_model)
+        translated = dialect.translate
+
+        encoded = prompt.encoded_uploads(messages.first).first[:base64]
+
+        expect(translated.messages.first[:content]).to eq(
+          [
+            { text: "What's in this image?" },
+            { image: { format: "jpeg", source: { bytes: encoded } } },
+          ],
+        )
+      end
+    end
+
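+    # tool definitions below should come out wrapped in Bedrock's toolSpec
+    # envelope, as produced by NovaTools#translated_tools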
+    context "with tools" do
+      it "properly formats tool configuration" do
+        tools = [
+          {
+            name: "get_weather",
+            description: "Get the weather in a city",
+            parameters: [
+              { name: "location", type: "string", description: "the city name", required: true },
+            ],
+          },
+        ]
+
+        messages = [{ type: :user, content: "What's the weather?" }]
+
+        prompt =
+          DiscourseAi::Completions::Prompt.new(
+            "You are a helpful bot",
+            messages: messages,
+            tools: tools,
+          )
+
+        dialect = nova_dialect_klass.new(prompt, llm_model)
+        translated = dialect.translate
+
+        expect(translated.tool_config).to eq(
+          {
+            tools: [
+              {
+                toolSpec: {
+                  name: "get_weather",
+                  description: "Get the weather in a city",
+                  inputSchema: {
+                    json: {
+                      type: "object",
+                      properties: {
+                        "location" => {
+                          type: "string",
+                          required: true,
+                        },
+                      },
+                    },
+                  },
+                },
+              },
+            ],
+          },
+        )
+      end
+    end
+
+    context "with inference configuration" do
+      it "includes inference configuration when provided" do
+        messages = [{ type: :user, content: "Hello" }]
+
+        prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
+
+        dialect = nova_dialect_klass.new(prompt, llm_model)
+
+        options = { temperature: 0.7, top_p: 0.9, max_tokens: 100, stop_sequences: ["STOP"] }
+
+        translated = dialect.translate
+
+        expected = {
+          system: [{ text: "You are a helpful bot" }],
+          messages: [{ role: "user", content: [{ text: "Hello" }] }],
+          inferenceConfig: {
+            temperature: 0.7,
+            top_p: 0.9,
+            stopSequences: ["STOP"],
+            max_new_tokens: 100,
+          },
+        }
+
+        expect(translated.to_payload(options)).to eq(expected)
+      end
+
+      it "omits inference configuration when not provided" do
+        messages = [{ type: :user, content: "Hello" }]
+
+        prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
+
+        dialect = nova_dialect_klass.new(prompt, llm_model)
+        translated = dialect.translate
+
+        expect(translated.inference_config).to be_nil
+      end
+    end
+
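+    # round trip: tool_call messages become assistant toolUse blocks and tool
+    # results come back as user toolResult blocks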
+    it "handles tool calls and responses" do
+      tool_call_prompt = { name: "get_weather", arguments: { location: "London" } }
+
+      messages = [
+        { type: :user, id: "user1", content: "What's the weather in London?" },
+        { type: :tool_call, name: "get_weather", id: "tool_id", content: tool_call_prompt.to_json },
+        { type: :tool, id: "tool_id", content: "Sunny, 22°C".to_json },
+        { type: :model, content: "The weather in London is sunny with 22°C" },
+      ]
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a helpful bot",
+          messages: messages,
+          tools: [
+            {
+              name: "get_weather",
+              description: "Get the weather in a city",
+              parameters: [
+                { name: "location", type: "string", description: "the city name", required: true },
+              ],
+            },
+          ],
+        )
+
+      dialect = nova_dialect_klass.new(prompt, llm_model)
+      translated = dialect.translate
+
+      expect(translated.messages.map { |m| m[:role] }).to eq(%w[user assistant user assistant])
+      expect(translated.messages.last[:content]).to eq(
+        [{ text: "The weather in London is sunny with 22°C" }],
+      )
+    end
+  end
+
+  describe "#max_prompt_tokens" do
+    it "returns the model's max prompt tokens" do
+      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot")
+      dialect = nova_dialect_klass.new(prompt, llm_model)
+
+      expect(dialect.max_prompt_tokens).to eq(llm_model.max_prompt_tokens)
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/nova_spec.rb b/spec/lib/completions/endpoints/nova_spec.rb
new file mode 100644
index 00000000..d0f9a538
--- /dev/null
+++ b/spec/lib/completions/endpoints/nova_spec.rb
@@ -0,0 +1,284 @@
+# frozen_string_literal: true
+
+require_relative "endpoint_compliance"
+require "aws-eventstream"
+require "aws-sigv4"
+
+class BedrockMock < EndpointMock
+end
+
+# Nova is implemented inside the Bedrock endpoint; its specs are split out here
+RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
+  fab!(:user)
+  fab!(:nova_model)
+
+  subject(:endpoint) { described_class.new(nova_model) }
+
+  let(:bedrock_mock) { BedrockMock.new(endpoint) }
+
+  let(:stream_url) do
+    "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke-with-response-stream"
+  end
+
+  def encode_message(message)
+    wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
+    io = StringIO.new(wrapped)
+    aws_message = Aws::EventStream::Message.new(payload: io)
+    Aws::EventStream::Encoder.new.encode(aws_message)
+  end
+
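+  # each streamed event is an AWS event-stream frame whose JSON payload carries
+  # the real message base64-encoded under "bytes"; encode_message mirrors that
+  # wrapping so the endpoint's decoder is exercised end to end
+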
+  it "should be able to make a simple request" do
+    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
+
+    content = {
+      "output" => {
+        "message" => {
+          "content" => [{ "text" => "it is 2." }],
+          "role" => "assistant",
+        },
+      },
+      "stopReason" => "end_turn",
+      "usage" => {
+        "inputTokens" => 14,
+        "outputTokens" => 119,
+        "totalTokens" => 133,
+        "cacheReadInputTokenCount" => nil,
+        "cacheWriteInputTokenCount" => nil,
+      },
+    }.to_json
+
+    stub_request(
+      :post,
+      "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke",
+    ).to_return(status: 200, body: content)
+
+    response = proxy.generate("hello world", user: user)
+    expect(response).to eq("it is 2.")
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(14)
+    expect(log.response_tokens).to eq(119)
+  end
+
+  it "should be able to make a streaming request" do
+    messages =
+      [
+        { messageStart: { role: "assistant" } },
+        { contentBlockDelta: { delta: { text: "Hello" }, contentBlockIndex: 0 } },
+        { contentBlockStop: { contentBlockIndex: 0 } },
+        { contentBlockDelta: { delta: { text: "!" }, contentBlockIndex: 1 } },
+        { contentBlockStop: { contentBlockIndex: 1 } },
+        {
+          metadata: {
+            usage: {
+              inputTokens: 14,
+              outputTokens: 18,
+            },
+            metrics: {
+            },
+            trace: {
+            },
+          },
+          "amazon-bedrock-invocationMetrics": {
+            inputTokenCount: 14,
+            outputTokenCount: 18,
+            invocationLatency: 402,
+            firstByteLatency: 72,
+          },
+        },
+      ].map { |message| encode_message(message) }
+
+    stub_request(:post, stream_url).to_return(status: 200, body: messages.join)
+
+    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
+    responses = []
+    proxy.generate("Hello!", user: user) { |partial| responses << partial }
+
+    expect(responses).to eq(%w[Hello !])
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(14)
+    expect(log.response_tokens).to eq(18)
+  end
+
+  it "should support native streaming tool calls" do
+    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are a helpful assistant.",
+        messages: [{ type: :user, content: "what is the time in EST" }],
+      )
+
+    tool = {
+      name: "time",
+      description: "Will look up the current time",
+      parameters: [
+        { name: "timezone", description: "The timezone", type: "string", required: true },
+      ],
+    }
+
+    prompt.tools = [tool]
+
+    messages =
+      [
+        { messageStart: { role: "assistant" } },
+        {
+          contentBlockStart: {
+            start: {
+              toolUse: {
+                toolUseId: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
+                name: "time",
+              },
+            },
+            contentBlockIndex: 0,
+          },
+        },
+        {
+          contentBlockDelta: {
+            delta: {
+              toolUse: {
+                input: "{\"timezone\":\"EST\"}",
+              },
+            },
+            contentBlockIndex: 0,
+          },
+        },
+        { contentBlockStop: { contentBlockIndex: 0 } },
+        { messageStop: { stopReason: "end_turn" } },
+        {
+          metadata: {
+            usage: {
+              inputTokens: 481,
+              outputTokens: 28,
+            },
+            metrics: {
+            },
+            trace: {
+            },
+          },
+          "amazon-bedrock-invocationMetrics": {
+            inputTokenCount: 481,
+            outputTokenCount: 28,
+            invocationLatency: 383,
+            firstByteLatency: 57,
+          },
+        },
+      ].map { |message| encode_message(message) }
+
+    request = nil
+    stub_request(:post, stream_url)
+      .with do |inner_request|
+        request = inner_request
+        true
+      end
+      .to_return(status: 200, body: messages)
+
+    response = []
+    bedrock_mock.with_chunk_array_support do
+      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
+    end
+
+    parsed_request = JSON.parse(request.body)
+
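+    # the raw request should match Nova's payload shape exactly: camelCase
+    # envelope keys (inferenceConfig, toolConfig) with snake_case max_new_tokens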
+    expected = {
+      "system" => [{ "text" => "You are a helpful assistant." }],
+      "messages" => [{ "role" => "user", "content" => [{ "text" => "what is the time in EST" }] }],
+      "inferenceConfig" => {
+        "max_new_tokens" => 200,
+      },
+      "toolConfig" => {
+        "tools" => [
+          {
+            "toolSpec" => {
+              "name" => "time",
+              "description" => "Will look up the current time",
+              "inputSchema" => {
+                "json" => {
+                  "type" => "object",
+                  "properties" => {
+                    "timezone" => {
+                      "type" => "string",
+                      "required" => true,
+                    },
+                  },
+                },
+              },
+            },
+          },
+        ],
+      },
+    }
+
+    expect(parsed_request).to eq(expected)
+    expect(response).to eq(
+      [
+        DiscourseAi::Completions::ToolCall.new(
+          name: "time",
+          id: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
+          parameters: {
+            "timezone" => "EST",
+          },
+        ),
+      ],
+    )
+
+    # let's continue and ensure all messages are mapped correctly
+    prompt.push(type: :tool_call, name: "time", content: { timezone: "EST" }.to_json, id: "111")
+    prompt.push(type: :tool, name: "time", content: "1pm".to_json, id: "111")
+
+    # let's just return the tool call again; this is about ensuring we encode the prompt right
+    stub_request(:post, stream_url)
+      .with do |inner_request|
+        request = inner_request
+        true
+      end
+      .to_return(status: 200, body: messages)
+
+    response = []
+    bedrock_mock.with_chunk_array_support do
+      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
+    end
+
+    # input is nil below because the pushed tool_call content has no "arguments"
+    # wrapper, which is the key NovaTools#from_raw_tool_call looks for
+    expected = {
+      system: [{ text: "You are a helpful assistant." }],
+      messages: [
+        { role: "user", content: [{ text: "what is the time in EST" }] },
+        {
+          role: "assistant",
+          content: [{ toolUse: { toolUseId: "111", name: "time", input: nil } }],
+        },
+        {
+          role: "user",
+          content: [{ toolResult: { toolUseId: "111", content: [{ json: "1pm" }] } }],
+        },
+      ],
+      inferenceConfig: {
+        max_new_tokens: 200,
+      },
+      toolConfig: {
+        tools: [
+          {
+            toolSpec: {
+              name: "time",
+              description: "Will look up the current time",
+              inputSchema: {
+                json: {
+                  type: "object",
+                  properties: {
+                    timezone: {
+                      type: "string",
+                      required: true,
+                    },
+                  },
+                },
+              },
+            },
+          },
+        ],
+      },
+    }
+
+    expect(JSON.parse(request.body, symbolize_names: true)).to eq(expected)
+  end
+end