diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 71702469..40595b67 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -178,9 +178,16 @@ module DiscourseAi
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake"
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
+ # no Bedrock support yet as of 2024-03-18
"anthropic:claude-3-opus"
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
- "anthropic:claude-3-sonnet"
+ if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+ "claude-3-sonnet",
+ )
+ "aws_bedrock:claude-3-sonnet"
+ else
+ "anthropic:claude-3-sonnet"
+ end
else
nil
end
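
The mapping above prefers AWS Bedrock for claude-3-sonnet whenever the Bedrock endpoint is fully configured and falls back to the direct Anthropic API otherwise. A minimal standalone sketch of that preference pattern (the provider prefixes are the real ones; the helper and its bedrock_configured: parameter are illustrative):

    # Hypothetical helper showing the Bedrock-first fallback; the plugin's real
    # predicate is DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?
    def provider_string_for(model, bedrock_configured:)
      bedrock_configured ? "aws_bedrock:#{model}" : "anthropic:#{model}"
    end

    provider_string_for("claude-3-sonnet", bedrock_configured: true)  # => "aws_bedrock:claude-3-sonnet"
    provider_string_for("claude-3-sonnet", bedrock_configured: false) # => "anthropic:claude-3-sonnet"
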
diff --git a/lib/automation/report_runner.rb b/lib/automation/report_runner.rb
index 029b22f2..1ab2eb9f 100644
--- a/lib/automation/report_runner.rb
+++ b/lib/automation/report_runner.rb
@@ -216,13 +216,16 @@ Follow the provided writing composition instructions carefully and precisely ste
def translate_model(model)
return "google:gemini-pro" if model == "gemini-pro"
return "open_ai:#{model}" if model.start_with? "gpt"
- return "anthropic:#{model}" if model.start_with? "claude-3"
- if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
- "aws_bedrock:#{model}"
- else
- "anthropic:#{model}"
+ if model.start_with? "claude"
+ if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(model)
+ return "aws_bedrock:#{model}"
+ else
+ return "anthropic:#{model}"
+ end
end
+
+ raise "Unknown model #{model}"
end
private
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 6043a042..616fd168 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -6,7 +6,9 @@ module DiscourseAi
class Claude < Dialect
class << self
def can_translate?(model_name)
- %w[claude-instant-1 claude-2].include?(model_name)
+ %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?(
+ model_name,
+ )
end
def tokenizer
@@ -14,53 +16,69 @@ module DiscourseAi
end
end
+ class ClaudePrompt
+ attr_reader :system_prompt
+ attr_reader :messages
+
+ def initialize(system_prompt, messages)
+ @system_prompt = system_prompt
+ @messages = messages
+ end
+ end
+
def translate
messages = prompt.messages
+ system_prompt = +""
- trimmed_messages = trim_messages(messages)
+ messages =
+ trim_messages(messages)
+ .map do |msg|
+ case msg[:type]
+ when :system
+ system_prompt << msg[:content]
+ nil
+ when :tool_call
+ { role: "assistant", content: tool_call_to_xml(msg) }
+ when :tool
+ { role: "user", content: tool_result_to_xml(msg) }
+ when :model
+ { role: "assistant", content: msg[:content] }
+ when :user
+ content = +""
+ content << "#{msg[:id]}: " if msg[:id]
+ content << msg[:content]
- # Need to include this differently
- last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
-
- claude_prompt =
- trimmed_messages.reduce(+"") do |memo, msg|
- if msg[:type] == :tool_call
- memo << "\n\nAssistant: #{tool_call_to_xml(msg)}"
- elsif msg[:type] == :system
- memo << "Human: " unless uses_system_message?
- memo << msg[:content]
- if prompt.tools.present?
- memo << "\n"
- memo << build_tools_prompt
+ { role: "user", content: content }
end
- elsif msg[:type] == :model
- memo << "\n\nAssistant: #{msg[:content]}"
- elsif msg[:type] == :tool
- memo << "\n\nHuman:\n"
- memo << tool_result_to_xml(msg)
- else
- memo << "\n\nHuman: "
- memo << "#{msg[:id]}: " if msg[:id]
- memo << msg[:content]
end
+ .compact
- memo
+ if prompt.tools.present?
+ system_prompt << "\n\n"
+ system_prompt << build_tools_prompt
+ end
+
+ interleaving_messages = []
+
+ previous_message = nil
+ messages.each do |message|
+ if previous_message
+ if previous_message[:role] == "user" && message[:role] == "user"
+ interleving_messages << { role: "assistant", content: "OK" }
+ elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
+ interleving_messages << { role: "user", content: "OK" }
+ end
end
+ interleaving_messages << message
+ previous_message = message
+ end
- claude_prompt << "\n\nAssistant:"
- claude_prompt << " #{last_message[:content]}:" if last_message
-
- claude_prompt
+ ClaudePrompt.new(system_prompt.presence, interleaving_messages)
end
def max_prompt_tokens
- 100_000 # Claude-2.1 has a 200k context window.
- end
-
- private
-
- def uses_system_message?
- model_name == "claude-2"
+ # Longer term it will have over 1 million
+ 200_000 # Claude-3 has a 200k context window for now
end
end
end
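
Anthropic's messages API expects user and assistant turns to alternate strictly, which is why the rewritten #translate inserts filler "OK" turns between two consecutive messages from the same role. A simplified standalone sketch of that fixup (plain role/content hashes assumed):

    # Simplified version of the interleaving loop in #translate above.
    def interleave(messages)
      result = []
      messages.each do |message|
        previous = result.last
        if previous && previous[:role] == message[:role]
          filler = previous[:role] == "user" ? "assistant" : "user"
          result << { role: filler, content: "OK" } # keep roles alternating
        end
        result << message
      end
      result
    end

    interleave([{ role: "user", content: "4" }, { role: "user", content: "5" }])
    # => user "4", then assistant "OK", then user "5"
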
diff --git a/lib/completions/dialects/claude_messages.rb b/lib/completions/dialects/claude_messages.rb
deleted file mode 100644
index c0e9feb0..00000000
--- a/lib/completions/dialects/claude_messages.rb
+++ /dev/null
@@ -1,85 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Completions
- module Dialects
- class ClaudeMessages < Dialect
- class << self
- def can_translate?(model_name)
- # TODO: add haiku not released yet as of 2024-03-05
- %w[claude-3-sonnet claude-3-opus].include?(model_name)
- end
-
- def tokenizer
- DiscourseAi::Tokenizer::AnthropicTokenizer
- end
- end
-
- class ClaudePrompt
- attr_reader :system_prompt
- attr_reader :messages
-
- def initialize(system_prompt, messages)
- @system_prompt = system_prompt
- @messages = messages
- end
- end
-
- def translate
- messages = prompt.messages
- system_prompt = +""
-
- messages =
- trim_messages(messages)
- .map do |msg|
- case msg[:type]
- when :system
- system_prompt << msg[:content]
- nil
- when :tool_call
- { role: "assistant", content: tool_call_to_xml(msg) }
- when :tool
- { role: "user", content: tool_result_to_xml(msg) }
- when :model
- { role: "assistant", content: msg[:content] }
- when :user
- content = +""
- content << "#{msg[:id]}: " if msg[:id]
- content << msg[:content]
-
- { role: "user", content: content }
- end
- end
- .compact
-
- if prompt.tools.present?
- system_prompt << "\n\n"
- system_prompt << build_tools_prompt
- end
-
- interleving_messages = []
-
- previous_message = nil
- messages.each do |message|
- if previous_message
- if previous_message[:role] == "user" && message[:role] == "user"
- interleving_messages << { role: "assistant", content: "OK" }
- elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
- interleving_messages << { role: "user", content: "OK" }
- end
- end
- interleving_messages << message
- previous_message = message
- end
-
- ClaudePrompt.new(system_prompt.presence, interleving_messages)
- end
-
- def max_prompt_tokens
- # Longer term it will have over 1 million
- 200_000 # Claude-3 has a 200k context window for now
- end
- end
- end
- end
-end
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index fb7b48da..3baaa68d 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -11,13 +11,12 @@ module DiscourseAi
def dialect_for(model_name)
dialects = [
- DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mixtral,
- DiscourseAi::Completions::Dialects::ClaudeMessages,
+ DiscourseAi::Completions::Dialects::Claude,
]
if Rails.env.test? || Rails.env.development?
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 0f766c4d..8c27a269 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -6,7 +6,10 @@ module DiscourseAi
class Anthropic < Base
class << self
def can_contact?(endpoint_name, model_name)
- endpoint_name == "anthropic" && %w[claude-instant-1 claude-2].include?(model_name)
+ endpoint_name == "anthropic" &&
+ %w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?(
+ model_name,
+ )
end
def dependant_setting_names
@@ -23,23 +26,32 @@ module DiscourseAi
end
def normalize_model_params(model_params)
- model_params = model_params.dup
-
- # temperature, stop_sequences are already supported
- #
- if model_params[:max_tokens]
- model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
- end
-
+ # max_tokens, temperature, stop_sequences are already supported
model_params
end
- def default_options
- {
- model: model == "claude-2" ? "claude-2.1" : model,
- max_tokens_to_sample: 3_000,
- stop_sequences: ["\n\nHuman:", ""],
- }
+ def default_options(dialect)
+ # skipping 2.0 support for now, since other models are better
+ mapped_model =
+ case model
+ when "claude-2"
+ "claude-2.1"
+ when "claude-instant-1"
+ "claude-instant-1.2"
+ when "claude-3-haiku"
+ "claude-3-haiku-20240307"
+ when "claude-3-sonnet"
+ "claude-3-sonnet-20240229"
+ when "claude-3-opus"
+ "claude-3-opus-20240229"
+ else
+ raise "Unsupported model: #{model}"
+ end
+
+ options = { model: mapped_model, max_tokens: 3_000 }
+
+ options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
+ options
end
def provider_id
@@ -48,15 +60,22 @@ module DiscourseAi
private
- def model_uri
- @uri ||= URI("https://api.anthropic.com/v1/complete")
+ # this is an approximation; we update it later if the request goes through
+ def prompt_size(prompt)
+ super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
- def prepare_payload(prompt, model_params, _dialect)
- default_options
- .merge(model_params)
- .merge(prompt: prompt)
- .tap { |payload| payload[:stream] = true if @streaming_mode }
+ def model_uri
+ @uri ||= URI("https://api.anthropic.com/v1/messages")
+ end
+
+ def prepare_payload(prompt, model_params, dialect)
+ payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+
+ payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+ payload[:stream] = true if @streaming_mode
+
+ payload
end
def prepare_request(payload)
@@ -69,8 +88,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
+ def final_log_update(log)
+ log.request_tokens = @input_tokens if @input_tokens
+ log.response_tokens = @output_tokens if @output_tokens
+ end
+
def extract_completion_from(response_raw)
- JSON.parse(response_raw, symbolize_names: true)[:completion].to_s
+ result = ""
+ parsed = JSON.parse(response_raw, symbolize_names: true)
+
+ if @streaming_mode
+ if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+ result = parsed.dig(:delta, :text).to_s
+ elsif parsed[:type] == "message_start"
+ @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+ elsif parsed[:type] == "message_delta"
+ @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
+ end
+ else
+ result = parsed.dig(:content, 0, :text).to_s
+ @input_tokens = parsed.dig(:usage, :input_tokens)
+ @output_tokens = parsed.dig(:usage, :output_tokens)
+ end
+
+ result
end
def partials_from(decoded_chunk)
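
The reworked extract_completion_from speaks the messages API event stream: message_start carries input token usage, content_block_start/content_block_delta carry text, and message_delta carries output token usage. A rough standalone sketch of folding already-split "data:" payloads (the shape partials_from yields) into text plus token counts:

    require "json"

    # Illustrative fold mirroring the streaming branches above; event payloads
    # follow the shapes stubbed in the specs further down.
    def fold_stream(raw_events)
      text = +""
      input_tokens = output_tokens = nil

      raw_events.each do |raw|
        event = JSON.parse(raw, symbolize_names: true)
        case event[:type]
        when "content_block_start", "content_block_delta"
          text << event.dig(:delta, :text).to_s # nil-safe; start blocks carry no delta
        when "message_start"
          input_tokens = event.dig(:message, :usage, :input_tokens)
        when "message_delta"
          output_tokens = event.dig(:delta, :usage, :output_tokens)
        end
      end

      [text, input_tokens, output_tokens]
    end
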
diff --git a/lib/completions/endpoints/anthropic_messages.rb b/lib/completions/endpoints/anthropic_messages.rb
deleted file mode 100644
index 7485e0e6..00000000
--- a/lib/completions/endpoints/anthropic_messages.rb
+++ /dev/null
@@ -1,103 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Completions
- module Endpoints
- class AnthropicMessages < Base
- class << self
- def can_contact?(endpoint_name, model_name)
- endpoint_name == "anthropic" && %w[claude-3-opus claude-3-sonnet].include?(model_name)
- end
-
- def dependant_setting_names
- %w[ai_anthropic_api_key]
- end
-
- def correctly_configured?(_model_name)
- SiteSetting.ai_anthropic_api_key.present?
- end
-
- def endpoint_name(model_name)
- "Anthropic - #{model_name}"
- end
- end
-
- def normalize_model_params(model_params)
- # max_tokens, temperature, stop_sequences are already supported
- model_params
- end
-
- def default_options(dialect)
- options = { model: model + "-20240229", max_tokens: 3_000 }
-
- options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
- options
- end
-
- def provider_id
- AiApiAuditLog::Provider::Anthropic
- end
-
- private
-
- # this is an approximation, we will update it later if request goes through
- def prompt_size(prompt)
- super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
- end
-
- def model_uri
- @uri ||= URI("https://api.anthropic.com/v1/messages")
- end
-
- def prepare_payload(prompt, model_params, dialect)
- payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
-
- payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
- payload[:stream] = true if @streaming_mode
-
- payload
- end
-
- def prepare_request(payload)
- headers = {
- "anthropic-version" => "2023-06-01",
- "x-api-key" => SiteSetting.ai_anthropic_api_key,
- "content-type" => "application/json",
- }
-
- Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
- end
-
- def final_log_update(log)
- log.request_tokens = @input_tokens if @input_tokens
- log.response_tokens = @output_tokens if @output_tokens
- end
-
- def extract_completion_from(response_raw)
- result = ""
- parsed = JSON.parse(response_raw, symbolize_names: true)
-
- if @streaming_mode
- if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
- result = parsed.dig(:delta, :text).to_s
- elsif parsed[:type] == "message_start"
- @input_tokens = parsed.dig(:message, :usage, :input_tokens)
- elsif parsed[:type] == "message_delta"
- @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
- end
- else
- result = parsed.dig(:content, 0, :text).to_s
- @input_tokens = parsed.dig(:usage, :input_tokens)
- @output_tokens = parsed.dig(:usage, :output_tokens)
- end
-
- result
- end
-
- def partials_from(decoded_chunk)
- decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
- end
- end
- end
- end
-end
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 92c5e16f..d62cfe41 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -8,17 +8,18 @@ module DiscourseAi
class AwsBedrock < Base
class << self
def can_contact?(endpoint_name, model_name)
- endpoint_name == "aws_bedrock" && %w[claude-instant-1 claude-2].include?(model_name)
+ endpoint_name == "aws_bedrock" &&
+ %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet].include?(model_name)
end
def dependant_setting_names
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
end
- def correctly_configured?(_model_name)
+ def correctly_configured?(model)
SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
- SiteSetting.ai_bedrock_region.present?
+ SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model)
end
def endpoint_name(model_name)
@@ -29,17 +30,15 @@ module DiscourseAi
def normalize_model_params(model_params)
model_params = model_params.dup
- # temperature, stop_sequences are already supported
- #
- if model_params[:max_tokens]
- model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
- end
+ # max_tokens, temperature, stop_sequences, top_p are already supported
model_params
end
- def default_options
- { max_tokens_to_sample: 3_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
+ def default_options(dialect)
+ options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
+ options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
+ options
end
def provider_id
@@ -48,25 +47,40 @@ module DiscourseAi
private
- def model_uri
- # Bedrock uses slightly different names
- # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
- bedrock_model_id = model.split("-")
- bedrock_model_id[-1] = "v#{bedrock_model_id.last}"
- bedrock_model_id = +(bedrock_model_id.join("-"))
+ def prompt_size(prompt)
+ # approximation
+ super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
+ end
- bedrock_model_id << ":1" if model == "claude-2" # For claude-2.1
+ def model_uri
+ # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
+ #
+ # FYI there is a 2.0 version of Claude, but there is little need to support it
+ # given haiku/sonnet are better fits anyway, so we map claude-2 to claude-2.1
+ bedrock_model_id =
+ case model
+ when "claude-2"
+ "anthropic.claude-v2:1"
+ when "claude-3-haiku"
+ "anthropic.claude-3-haiku-20240307-v1:0"
+ when "claude-3-sonnet"
+ "anthropic.claude-3-sonnet-20240229-v1:0"
+ when "claude-instant-1"
+ "anthropic.claude-instant-v1"
+ end
api_url =
- "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_model_id}/invoke"
+ "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
URI(api_url)
end
- def prepare_payload(prompt, model_params, _dialect)
- default_options.merge(prompt: prompt).merge(model_params)
+ def prepare_payload(prompt, model_params, dialect)
+ payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+ payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+ payload
end
def prepare_request(payload)
@@ -117,8 +131,30 @@ module DiscourseAi
nil
end
+ def final_log_update(log)
+ log.request_tokens = @input_tokens if @input_tokens
+ log.response_tokens = @output_tokens if @output_tokens
+ end
+
def extract_completion_from(response_raw)
- JSON.parse(response_raw, symbolize_names: true)[:completion].to_s
+ result = ""
+ parsed = JSON.parse(response_raw, symbolize_names: true)
+
+ if @streaming_mode
+ if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+ result = parsed.dig(:delta, :text).to_s
+ elsif parsed[:type] == "message_start"
+ @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+ elsif parsed[:type] == "message_delta"
+ @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
+ end
+ else
+ result = parsed.dig(:content, 0, :text).to_s
+ @input_tokens = parsed.dig(:usage, :input_tokens)
+ @output_tokens = parsed.dig(:usage, :output_tokens)
+ end
+
+ result
end
def partials_from(decoded_chunk)
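
Bedrock shuffles the routing metadata relative to the direct API: the model id moves into the invoke URL, the API version travels in the payload as anthropic_version, and authentication is SigV4 request signing rather than an x-api-key header. A sketch of the URL construction mirroring #model_uri above (the helper itself is illustrative; the host and route shapes are the real ones):

    require "uri"

    # Illustrative builder matching the shape of #model_uri above.
    def bedrock_invoke_url(region, bedrock_model_id, streaming: false)
      url = "https://bedrock-runtime.#{region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
      url += "-with-response-stream" if streaming # streaming uses a separate route
      URI(url)
    end

    bedrock_invoke_url("us-east-1", "anthropic.claude-3-sonnet-20240229-v1:0", streaming: true)
    # => .../model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream
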
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index ec813cab..51c8c012 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -11,12 +11,11 @@ module DiscourseAi
def endpoint_for(provider_name, model_name)
endpoints = [
DiscourseAi::Completions::Endpoints::AwsBedrock,
- DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace,
DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm,
- DiscourseAi::Completions::Endpoints::AnthropicMessages,
+ DiscourseAi::Completions::Endpoints::Anthropic,
]
if Rails.env.test? || Rails.env.development?
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index ba4bdf34..ed1b216b 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -23,8 +23,8 @@ module DiscourseAi
# However, since they use the same URL/key settings, there's no reason to duplicate them.
@models_by_provider ||=
{
- aws_bedrock: %w[claude-instant-1 claude-2],
- anthropic: %w[claude-instant-1 claude-2 claude-3-sonnet claude-3-opus],
+ aws_bedrock: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet],
+ anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
vllm: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
diff --git a/spec/lib/completions/dialects/claude_messages_spec.rb b/spec/lib/completions/dialects/claude_messages_spec.rb
deleted file mode 100644
index 7e1f1edc..00000000
--- a/spec/lib/completions/dialects/claude_messages_spec.rb
+++ /dev/null
@@ -1,87 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe DiscourseAi::Completions::Dialects::ClaudeMessages do
- describe "#translate" do
- it "can insert OKs to make stuff interleve properly" do
- messages = [
- { type: :user, id: "user1", content: "1" },
- { type: :model, content: "2" },
- { type: :user, id: "user1", content: "4" },
- { type: :user, id: "user1", content: "5" },
- { type: :model, content: "6" },
- ]
-
- prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
-
- dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
- dialect = dialectKlass.new(prompt, "claude-3-opus")
- translated = dialect.translate
-
- expected_messages = [
- { role: "user", content: "user1: 1" },
- { role: "assistant", content: "2" },
- { role: "user", content: "user1: 4" },
- { role: "assistant", content: "OK" },
- { role: "user", content: "user1: 5" },
- { role: "assistant", content: "6" },
- ]
-
- expect(translated.messages).to eq(expected_messages)
- end
-
- it "can properly translate a prompt" do
- dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
-
- tools = [
- {
- name: "echo",
- description: "echo a string",
- parameters: [
- { name: "text", type: "string", description: "string to echo", required: true },
- ],
- },
- ]
-
- tool_call_prompt = { name: "echo", arguments: { text: "something" } }
-
- messages = [
- { type: :user, id: "user1", content: "echo something" },
- { type: :tool_call, content: tool_call_prompt.to_json },
- { type: :tool, id: "tool_id", content: "something".to_json },
- { type: :model, content: "I did it" },
- { type: :user, id: "user1", content: "echo something else" },
- ]
-
- prompt =
- DiscourseAi::Completions::Prompt.new(
- "You are a helpful bot",
- messages: messages,
- tools: tools,
- )
-
- dialect = dialect.new(prompt, "claude-3-opus")
- translated = dialect.translate
-
- expect(translated.system_prompt).to start_with("You are a helpful bot")
- expect(translated.system_prompt).to include("echo a string")
-
- expected = [
- { role: "user", content: "user1: echo something" },
- {
- role: "assistant",
- content:
- "\n\necho\n\nsomething\n\n\n",
- },
- {
- role: "user",
- content:
- "\n\ntool_id\n\n\"something\"\n\n\n",
- },
- { role: "assistant", content: "I did it" },
- { role: "user", content: "user1: echo something else" },
- ]
-
- expect(translated.messages).to eq(expected)
- end
- end
-end
diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb
index efb62bb3..af082c22 100644
--- a/spec/lib/completions/dialects/claude_spec.rb
+++ b/spec/lib/completions/dialects/claude_spec.rb
@@ -1,100 +1,87 @@
# frozen_string_literal: true
-require_relative "dialect_context"
-
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
- let(:model_name) { "claude-2" }
- let(:context) { DialectContext.new(described_class, model_name) }
-
describe "#translate" do
- it "translates a prompt written in our generic format to Claude's format" do
- anthropic_version = (<<~TEXT).strip
- #{context.system_insts}
- #{described_class.tool_preamble}
-
- #{context.dialect_tools}
+ it "can insert OKs to make stuff interleve properly" do
+ messages = [
+ { type: :user, id: "user1", content: "1" },
+ { type: :model, content: "2" },
+ { type: :user, id: "user1", content: "4" },
+ { type: :user, id: "user1", content: "5" },
+ { type: :model, content: "6" },
+ ]
- Human: #{context.simple_user_input}
+ prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
- Assistant:
- TEXT
+ dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
+ dialect = dialect_klass.new(prompt, "claude-3-opus")
+ translated = dialect.translate
- translated = context.system_user_scenario
+ expected_messages = [
+ { role: "user", content: "user1: 1" },
+ { role: "assistant", content: "2" },
+ { role: "user", content: "user1: 4" },
+ { role: "assistant", content: "OK" },
+ { role: "user", content: "user1: 5" },
+ { role: "assistant", content: "6" },
+ ]
- expect(translated).to eq(anthropic_version)
+ expect(translated.messages).to eq(expected_messages)
end
- it "translates tool messages" do
- expected = +(<<~TEXT).strip
- #{context.system_insts}
- #{described_class.tool_preamble}
-
- #{context.dialect_tools}
+ it "can properly translate a prompt" do
+ dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
- Human: user1: This is a message by a user
+ tools = [
+ {
+ name: "echo",
+ description: "echo a string",
+ parameters: [
+ { name: "text", type: "string", description: "string to echo", required: true },
+ ],
+ },
+ ]
- Assistant: I'm a previous bot reply, that's why there's no user
+ tool_call_prompt = { name: "echo", arguments: { text: "something" } }
- Human: user1: This is a new message by a user
+ messages = [
+ { type: :user, id: "user1", content: "echo something" },
+ { type: :tool_call, content: tool_call_prompt.to_json },
+ { type: :tool, id: "tool_id", content: "something".to_json },
+ { type: :model, content: "I did it" },
+ { type: :user, id: "user1", content: "echo something else" },
+ ]
- Assistant: <function_calls>
- <invoke>
- <tool_name>get_weather</tool_name>
- <parameters>
- <location>Sydney</location>
- <unit>c</unit>
- </parameters>
- </invoke>
- </function_calls>
-
- Human:
- <function_results>
- <result>
- <tool_name>get_weather</tool_name>
- <json>
- "I'm a tool result"
- </json>
- </result>
- </function_results>
-
- Assistant:
- TEXT
-
- expect(context.multi_turn_scenario).to eq(expected)
- end
-
- it "trims content if it's getting too long" do
- length = 19_000
-
- translated = context.long_user_input_scenario(length: length)
-
- expect(translated.length).to be < context.long_message_text(length: length).length
- end
-
- it "retains usernames in generated prompt" do
prompt =
DiscourseAi::Completions::Prompt.new(
- "You are a bot",
- messages: [
- { id: "👻", type: :user, content: "Message1" },
- { type: :model, content: "Ok" },
- { id: "joe", type: :user, content: "Message2" },
- ],
+ "You are a helpful bot",
+ messages: messages,
+ tools: tools,
)
- translated = context.dialect(prompt).translate
+ dialect = dialect_klass.new(prompt, "claude-3-opus")
+ translated = dialect.translate
- expect(translated).to eq(<<~TEXT.strip)
- You are a bot
+ expect(translated.system_prompt).to start_with("You are a helpful bot")
+ expect(translated.system_prompt).to include("echo a string")
- Human: 👻: Message1
+ expected = [
+ { role: "user", content: "user1: echo something" },
+ {
+ role: "assistant",
+ content:
+ "\n\necho\n\nsomething\n\n\n",
+ },
+ {
+ role: "user",
+ content:
+ "\n\ntool_id\n\n\"something\"\n\n\n",
+ },
+ { role: "assistant", content: "I did it" },
+ { role: "user", content: "user1: echo something else" },
+ ]
- Assistant: Ok
-
- Human: joe: Message2
-
- Assistant:
- TEXT
+ expect(translated.messages).to eq(expected)
end
end
end
diff --git a/spec/lib/completions/endpoints/anthropic_messages_spec.rb b/spec/lib/completions/endpoints/anthropic_messages_spec.rb
deleted file mode 100644
index 1d52e46c..00000000
--- a/spec/lib/completions/endpoints/anthropic_messages_spec.rb
+++ /dev/null
@@ -1,395 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
- let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
-
- let(:prompt) do
- DiscourseAi::Completions::Prompt.new(
- "You are hello bot",
- messages: [type: :user, id: "user1", content: "hello"],
- )
- end
-
- let(:echo_tool) do
- {
- name: "echo",
- description: "echo something",
- parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
- }
- end
-
- let(:google_tool) do
- {
- name: "google",
- description: "google something",
- parameters: [
- { name: "query", type: "string", description: "text to google", required: true },
- ],
- }
- end
-
- let(:prompt_with_echo_tool) do
- prompt_with_tools = prompt
- prompt.tools = [echo_tool]
- prompt_with_tools
- end
-
- let(:prompt_with_google_tool) do
- prompt_with_tools = prompt
- prompt.tools = [echo_tool]
- prompt_with_tools
- end
-
- before { SiteSetting.ai_anthropic_api_key = "123" }
-
- it "does not eat spaces with tool calls" do
- body = <<~STRING
- event: message_start
- data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
-
- event: content_block_start
- data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
-
- event: ping
- data: {"type": "ping"}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
-
- event: content_block_delta
- data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
- event: content_block_stop
- data: {"type":"content_block_stop","index":0}
-
- event: message_delta
- data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":""},"usage":{"output_tokens":57}}
-
- event: message_stop
- data: {"type":"message_stop"}
- STRING
-
- stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
-
- result = +""
- llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
- result << partial
- end
-
- expected = (<<~TEXT).strip
- <function_calls>
- <invoke>
- <tool_name>google</tool_name>
- <parameters>
- <query>top 10 things to do in japan for tourists</query>
- </parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- </function_calls>
- TEXT
-
- expect(result.strip).to eq(expected)
- end
-
- it "can stream a response" do
- body = (<<~STRING).strip
- event: message_start
- data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
-
- event: content_block_start
- data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
-
- event: ping
- data: {"type": "ping"}
-
- event: content_block_delta
- data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
-
- event: content_block_delta
- data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
-
- event: content_block_stop
- data: {"type": "content_block_stop", "index": 0}
-
- event: message_delta
- data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
-
- event: message_stop
- data: {"type": "message_stop"}
- STRING
-
- parsed_body = nil
-
- stub_request(:post, "https://api.anthropic.com/v1/messages").with(
- body:
- proc do |req_body|
- parsed_body = JSON.parse(req_body, symbolize_names: true)
- true
- end,
- headers: {
- "Content-Type" => "application/json",
- "X-Api-Key" => "123",
- "Anthropic-Version" => "2023-06-01",
- },
- ).to_return(status: 200, body: body)
-
- result = +""
- llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial }
-
- expect(result).to eq("Hello!")
-
- expected_body = {
- model: "claude-3-opus-20240229",
- max_tokens: 3000,
- messages: [{ role: "user", content: "user1: hello" }],
- system: "You are hello bot",
- stream: true,
- }
- expect(parsed_body).to eq(expected_body)
-
- log = AiApiAuditLog.order(:id).last
- expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
- expect(log.request_tokens).to eq(25)
- expect(log.response_tokens).to eq(15)
- end
-
- it "can return multiple function calls" do
- functions = <<~FUNCTIONS
- <function_calls>
- <invoke>
- <tool_name>echo</tool_name>
- <parameters>
- <text>something</text>
- </parameters>
- </invoke>
- <invoke>
- <tool_name>echo</tool_name>
- <parameters>
- <text>something else</text>
- </parameters>
- </invoke>
- FUNCTIONS
-
- body = <<~STRING
- {
- "content": [
- {
- "text": "Hello!\n\n#{functions}\njunk",
- "type": "text"
- }
- ],
- "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
- "model": "claude-3-opus-20240229",
- "role": "assistant",
- "stop_reason": "end_turn",
- "stop_sequence": null,
- "type": "message",
- "usage": {
- "input_tokens": 10,
- "output_tokens": 25
- }
- }
- STRING
-
- stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
-
- result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
-
- expected = (<<~EXPECTED).strip
- <function_calls>
- <invoke>
- <tool_name>echo</tool_name>
- <parameters>
- <text>something</text>
- </parameters>
- <tool_id>tool_0</tool_id>
- </invoke>
- <invoke>
- <tool_name>echo</tool_name>
- <parameters>
- <text>something else</text>
- </parameters>
- <tool_id>tool_1</tool_id>
- </invoke>
- </function_calls>
- EXPECTED
-
- expect(result.strip).to eq(expected)
- end
-
- it "can operate in regular mode" do
- body = <<~STRING
- {
- "content": [
- {
- "text": "Hello!",
- "type": "text"
- }
- ],
- "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
- "model": "claude-3-opus-20240229",
- "role": "assistant",
- "stop_reason": "end_turn",
- "stop_sequence": null,
- "type": "message",
- "usage": {
- "input_tokens": 10,
- "output_tokens": 25
- }
- }
- STRING
-
- parsed_body = nil
- stub_request(:post, "https://api.anthropic.com/v1/messages").with(
- body:
- proc do |req_body|
- parsed_body = JSON.parse(req_body, symbolize_names: true)
- true
- end,
- headers: {
- "Content-Type" => "application/json",
- "X-Api-Key" => "123",
- "Anthropic-Version" => "2023-06-01",
- },
- ).to_return(status: 200, body: body)
-
- result = llm.generate(prompt, user: Discourse.system_user)
- expect(result).to eq("Hello!")
-
- expected_body = {
- model: "claude-3-opus-20240229",
- max_tokens: 3000,
- messages: [{ role: "user", content: "user1: hello" }],
- system: "You are hello bot",
- }
- expect(parsed_body).to eq(expected_body)
-
- log = AiApiAuditLog.order(:id).last
- expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
- expect(log.request_tokens).to eq(10)
- expect(log.response_tokens).to eq(25)
- end
-end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index b05b79b0..d34acbf5 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -1,96 +1,395 @@
-# frozen_String_literal: true
+# frozen_string_literal: true
-require_relative "endpoint_compliance"
+RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
+ let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
-class AnthropicMock < EndpointMock
- def response(content)
+ let(:prompt) do
+ DiscourseAi::Completions::Prompt.new(
+ "You are hello bot",
+ messages: [type: :user, id: "user1", content: "hello"],
+ )
+ end
+
+ let(:echo_tool) do
{
- completion: content,
- stop: "\n\nHuman:",
- stop_reason: "stop_sequence",
- truncated: false,
- log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
- model: "claude-2",
- exception: nil,
+ name: "echo",
+ description: "echo something",
+ parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
}
end
- def stub_response(prompt, response_text, tool_call: false)
- WebMock
- .stub_request(:post, "https://api.anthropic.com/v1/complete")
- .with(body: model.default_options.merge(prompt: prompt).to_json)
- .to_return(status: 200, body: JSON.dump(response(response_text)))
+ let(:google_tool) do
+ {
+ name: "google",
+ description: "google something",
+ parameters: [
+ { name: "query", type: "string", description: "text to google", required: true },
+ ],
+ }
end
- def stream_line(delta, finish_reason: nil)
- +"data: " << {
- completion: delta,
- stop: finish_reason ? "\n\nHuman:" : nil,
- stop_reason: finish_reason,
- truncated: false,
- log_id: "12b029451c6d18094d868bc04ce83f63",
- model: "claude-2",
- exception: nil,
- }.to_json
+ let(:prompt_with_echo_tool) do
+ prompt_with_tools = prompt
+ prompt.tools = [echo_tool]
+ prompt_with_tools
end
- def stub_streamed_response(prompt, deltas, tool_call: false)
- chunks =
- deltas.each_with_index.map do |_, index|
- if index == (deltas.length - 1)
- stream_line(deltas[index], finish_reason: "stop_sequence")
- else
- stream_line(deltas[index])
- end
- end
-
- chunks = chunks.join("\n\n").split("")
-
- WebMock
- .stub_request(:post, "https://api.anthropic.com/v1/complete")
- .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
- .to_return(status: 200, body: chunks)
- end
-end
-
-RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
- subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
-
- fab!(:user)
-
- let(:anthropic_mock) { AnthropicMock.new(endpoint) }
-
- let(:compliance) do
- EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
+ let(:prompt_with_google_tool) do
+ prompt_with_tools = prompt
+ prompt.tools = [google_tool]
+ prompt_with_tools
end
- describe "#perform_completion!" do
- context "when using regular mode" do
- context "with simple prompts" do
- it "completes a trivial prompt and logs the response" do
- compliance.regular_mode_simple_prompt(anthropic_mock)
- end
- end
+ before { SiteSetting.ai_anthropic_api_key = "123" }
- context "with tools" do
- it "returns a function invocation" do
- compliance.regular_mode_tools(anthropic_mock)
- end
- end
+ it "does not eat spaces with tool calls" do
+ body = <<~STRING
+ event: message_start
+ data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
+
+ event: content_block_start
+ data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
+
+ event: ping
+ data: {"type": "ping"}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+ event: content_block_stop
+ data: {"type":"content_block_stop","index":0}
+
+ event: message_delta
+ data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":""},"usage":{"output_tokens":57}}
+
+ event: message_stop
+ data: {"type":"message_stop"}
+ STRING
+
+ stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
+
+ result = +""
+ llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
+ result << partial
end
- describe "when using streaming mode" do
- context "with simple prompts" do
- it "completes a trivial prompt and logs the response" do
- compliance.streaming_mode_simple_prompt(anthropic_mock)
- end
- end
+ expected = (<<~TEXT).strip
+ <function_calls>
+ <invoke>
+ <tool_name>google</tool_name>
+ <parameters>
+ <query>top 10 things to do in japan for tourists</query>
+ </parameters>
+ <tool_id>tool_0</tool_id>
+ </invoke>
+ </function_calls>
+ TEXT
- context "with tools" do
- it "returns a function invocation" do
- compliance.streaming_mode_tools(anthropic_mock)
- end
- end
- end
+ expect(result.strip).to eq(expected)
+ end
+
+ it "can stream a response" do
+ body = (<<~STRING).strip
+ event: message_start
+ data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
+
+ event: content_block_start
+ data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
+
+ event: ping
+ data: {"type": "ping"}
+
+ event: content_block_delta
+ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
+
+ event: content_block_delta
+ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
+
+ event: content_block_stop
+ data: {"type": "content_block_stop", "index": 0}
+
+ event: message_delta
+ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
+
+ event: message_stop
+ data: {"type": "message_stop"}
+ STRING
+
+ parsed_body = nil
+
+ stub_request(:post, "https://api.anthropic.com/v1/messages").with(
+ body:
+ proc do |req_body|
+ parsed_body = JSON.parse(req_body, symbolize_names: true)
+ true
+ end,
+ headers: {
+ "Content-Type" => "application/json",
+ "X-Api-Key" => "123",
+ "Anthropic-Version" => "2023-06-01",
+ },
+ ).to_return(status: 200, body: body)
+
+ result = +""
+ llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial }
+
+ expect(result).to eq("Hello!")
+
+ expected_body = {
+ model: "claude-3-opus-20240229",
+ max_tokens: 3000,
+ messages: [{ role: "user", content: "user1: hello" }],
+ system: "You are hello bot",
+ stream: true,
+ }
+ expect(parsed_body).to eq(expected_body)
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
+ expect(log.request_tokens).to eq(25)
+ expect(log.response_tokens).to eq(15)
+ end
+
+ it "can return multiple function calls" do
+ functions = <<~FUNCTIONS
+ <function_calls>
+ <invoke>
+ <tool_name>echo</tool_name>
+ <parameters>
+ <text>something</text>
+ </parameters>
+ </invoke>
+ <invoke>
+ <tool_name>echo</tool_name>
+ <parameters>
+ <text>something else</text>
+ </parameters>
+ </invoke>
+ FUNCTIONS
+
+ body = <<~STRING
+ {
+ "content": [
+ {
+ "text": "Hello!\n\n#{functions}\njunk",
+ "type": "text"
+ }
+ ],
+ "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
+ "model": "claude-3-opus-20240229",
+ "role": "assistant",
+ "stop_reason": "end_turn",
+ "stop_sequence": null,
+ "type": "message",
+ "usage": {
+ "input_tokens": 10,
+ "output_tokens": 25
+ }
+ }
+ STRING
+
+ stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
+
+ result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
+
+ expected = (<<~EXPECTED).strip
+ <function_calls>
+ <invoke>
+ <tool_name>echo</tool_name>
+ <parameters>
+ <text>something</text>
+ </parameters>
+ <tool_id>tool_0</tool_id>
+ </invoke>
+ <invoke>
+ <tool_name>echo</tool_name>
+ <parameters>
+ <text>something else</text>
+ </parameters>
+ <tool_id>tool_1</tool_id>
+ </invoke>
+ </function_calls>
+ EXPECTED
+
+ expect(result.strip).to eq(expected)
+ end
+
+ it "can operate in regular mode" do
+ body = <<~STRING
+ {
+ "content": [
+ {
+ "text": "Hello!",
+ "type": "text"
+ }
+ ],
+ "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
+ "model": "claude-3-opus-20240229",
+ "role": "assistant",
+ "stop_reason": "end_turn",
+ "stop_sequence": null,
+ "type": "message",
+ "usage": {
+ "input_tokens": 10,
+ "output_tokens": 25
+ }
+ }
+ STRING
+
+ parsed_body = nil
+ stub_request(:post, "https://api.anthropic.com/v1/messages").with(
+ body:
+ proc do |req_body|
+ parsed_body = JSON.parse(req_body, symbolize_names: true)
+ true
+ end,
+ headers: {
+ "Content-Type" => "application/json",
+ "X-Api-Key" => "123",
+ "Anthropic-Version" => "2023-06-01",
+ },
+ ).to_return(status: 200, body: body)
+
+ result = llm.generate(prompt, user: Discourse.system_user)
+ expect(result).to eq("Hello!")
+
+ expected_body = {
+ model: "claude-3-opus-20240229",
+ max_tokens: 3000,
+ messages: [{ role: "user", content: "user1: hello" }],
+ system: "You are hello bot",
+ }
+ expect(parsed_body).to eq(expected_body)
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
+ expect(log.request_tokens).to eq(10)
+ expect(log.response_tokens).to eq(25)
end
end
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 2461941b..7d3554cd 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -5,71 +5,6 @@ require "aws-eventstream"
require "aws-sigv4"
class BedrockMock < EndpointMock
- def response(content)
- {
- completion: content,
- stop: "\n\nHuman:",
- stop_reason: "stop_sequence",
- truncated: false,
- log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
- model: "claude",
- exception: nil,
- }
- end
-
- def stub_response(prompt, response_content, tool_call: false)
- WebMock
- .stub_request(:post, "#{base_url}/invoke")
- .with(body: model.default_options.merge(prompt: prompt).to_json)
- .to_return(status: 200, body: JSON.dump(response(response_content)))
- end
-
- def stream_line(delta, finish_reason: nil)
- encoder = Aws::EventStream::Encoder.new
-
- message =
- Aws::EventStream::Message.new(
- payload:
- StringIO.new(
- {
- bytes:
- Base64.encode64(
- {
- completion: delta,
- stop: finish_reason ? "\n\nHuman:" : nil,
- stop_reason: finish_reason,
- truncated: false,
- log_id: "12b029451c6d18094d868bc04ce83f63",
- model: "claude-2.1",
- exception: nil,
- }.to_json,
- ),
- }.to_json,
- ),
- )
-
- encoder.encode(message)
- end
-
- def stub_streamed_response(prompt, deltas, tool_call: false)
- chunks =
- deltas.each_with_index.map do |_, index|
- if index == (deltas.length - 1)
- stream_line(deltas[index], finish_reason: "stop_sequence")
- else
- stream_line(deltas[index])
- end
- end
-
- WebMock
- .stub_request(:post, "#{base_url}/invoke-with-response-stream")
- .with(body: model.default_options.merge(prompt: prompt).to_json)
- .to_return(status: 200, body: chunks)
- end
-
- def base_url
- "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
- end
end
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
@@ -89,32 +24,98 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1"
end
- describe "#perform_completion!" do
- context "when using regular mode" do
- context "with simple prompts" do
- it "completes a trivial prompt and logs the response" do
- compliance.regular_mode_simple_prompt(bedrock_mock)
- end
- end
+ describe "Claude 3 Sonnet support" do
+ it "supports the sonnet model" do
+ proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
- context "with tools" do
- it "returns a function invocation" do
- compliance.regular_mode_tools(bedrock_mock)
+ request = nil
+
+ content = {
+ content: [text: "hello sam"],
+ usage: {
+ input_tokens: 10,
+ output_tokens: 20,
+ },
+ }.to_json
+
+ stub_request(
+ :post,
+ "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
+ )
+ .with do |inner_request|
+ request = inner_request
+ true
end
- end
+ .to_return(status: 200, body: content)
+
+ response = proxy.generate("hello world", user: user)
+
+ expect(request.headers["Authorization"]).to be_present
+ expect(request.headers["X-Amz-Content-Sha256"]).to be_present
+
+ expected = {
+ "max_tokens" => 3000,
+ "anthropic_version" => "bedrock-2023-05-31",
+ "messages" => [{ "role" => "user", "content" => "hello world" }],
+ "system" => "You are a helpful bot",
+ }
+ expect(JSON.parse(request.body)).to eq(expected)
+
+ expect(response).to eq("hello sam")
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(10)
+ expect(log.response_tokens).to eq(20)
end
- describe "when using streaming mode" do
- context "with simple prompts" do
- it "completes a trivial prompt and logs the response" do
- compliance.streaming_mode_simple_prompt(bedrock_mock)
- end
- end
+ it "supports claude 3 sonnet streaming" do
+ proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
- context "with tools" do
- it "returns a function invocation" do
- compliance.streaming_mode_tools(bedrock_mock)
+ request = nil
+
+ messages =
+ [
+ { type: "message_start", message: { usage: { input_tokens: 9 } } },
+ { type: "content_block_delta", delta: { text: "hello " } },
+ { type: "content_block_delta", delta: { text: "sam" } },
+ { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
+ ].map do |message|
+ wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
+ io = StringIO.new(wrapped)
+ aws_message = Aws::EventStream::Message.new(payload: io)
+ Aws::EventStream::Encoder.new.encode(aws_message)
end
+
+ bedrock_mock.with_chunk_array_support do
+ stub_request(
+ :post,
+ "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
+ )
+ .with do |inner_request|
+ request = inner_request
+ true
+ end
+ .to_return(status: 200, body: messages)
+
+ response = +""
+ proxy.generate("hello world", user: user) { |partial| response << partial }
+
+ expect(request.headers["Authorization"]).to be_present
+ expect(request.headers["X-Amz-Content-Sha256"]).to be_present
+
+ expected = {
+ "max_tokens" => 3000,
+ "anthropic_version" => "bedrock-2023-05-31",
+ "messages" => [{ "role" => "user", "content" => "hello world" }],
+ "system" => "You are a helpful bot",
+ }
+ expect(JSON.parse(request.body)).to eq(expected)
+
+ expect(response).to eq("hello sam")
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(9)
+ expect(log.response_tokens).to eq(25)
end
end
end