FIX: switch off native tools on Anthropic Claude Opus (#659)

Native tools do not work well on Opus.

Chain of Thought prompting means it consumes enormous amounts of
tokens and has poor latency.

This commit introduce and XML stripper to remove various chain of
thought XML islands from anthropic prompts when tools are involved.

This mean Opus native tools is now functions (albeit slowly)

From local testing XML just works better now.

Also fixes enum support in Anthropic native tools
This commit is contained in:
Sam 2024-06-07 23:52:01 +10:00 committed by GitHub
parent 7a64699314
commit 8b81ff45b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 439 additions and 66 deletions

View File

@ -51,6 +51,7 @@ en:
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
ai_openai_api_key: "API key for OpenAI API"
ai_anthropic_api_key: "API key for Anthropic API"
ai_anthropic_native_tool_call_models: "List of models that will use native tool calls vs legacy XML based tools."
ai_cohere_api_key: "API key for Cohere API"
ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference"
ai_hugging_face_api_key: API key for Hugging Face API

View File

@ -111,6 +111,15 @@ discourse_ai:
ai_anthropic_api_key:
default: ""
secret: true
ai_anthropic_native_tool_call_models:
type: list
list_type: compact
default: "claude-3-sonnet|claude-3-haiku"
allow_any: false
choices:
- claude-3-opus
- claude-3-sonnet
- claude-3-haiku
ai_cohere_api_key:
default: ""
secret: true

View File

@ -22,6 +22,10 @@ module DiscourseAi
@messages = messages
@tools = tools
end
def has_tools?
tools.present?
end
end
def tokenizer
@ -33,6 +37,10 @@ module DiscourseAi
system_prompt = messages.shift[:content] if messages.first[:role] == "system"
if !system_prompt && !native_tool_support?
system_prompt = tools_dialect.instructions.presence
end
interleving_messages = []
previous_message = nil
@ -48,11 +56,10 @@ module DiscourseAi
previous_message = message
end
ClaudePrompt.new(
system_prompt.presence,
interleving_messages,
tools_dialect.translated_tools,
)
tools = nil
tools = tools_dialect.translated_tools if native_tool_support?
ClaudePrompt.new(system_prompt.presence, interleving_messages, tools)
end
def max_prompt_tokens
@ -62,18 +69,28 @@ module DiscourseAi
200_000 # Claude-3 has a 200k context window for now
end
def native_tool_support?
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name)
end
private
def tools_dialect
if native_tool_support?
@tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
else
super
end
end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
translated = tools_dialect.from_raw_tool_call(msg)
{ role: "assistant", content: translated }
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
translated = tools_dialect.from_raw_tool(msg)
{ role: "user", content: translated }
end
def model_msg(msg)

View File

@ -15,12 +15,13 @@ module DiscourseAi
required = []
if t[:parameters]
properties =
t[:parameters].each_with_object({}) do |param, h|
h[param[:name]] = {
type: param[:type],
description: param[:description],
}.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] }
properties = {}
t[:parameters].each do |param|
mapped = { type: param[:type], description: param[:description] }
mapped[:items] = { type: param[:item_type] } if param[:item_type]
mapped[:enum] = param[:enum] if param[:enum]
properties[param[:name]] = mapped
end
required =
t[:parameters].select { |param| param[:required] }.map { |param| param[:name] }
@ -39,37 +40,24 @@ module DiscourseAi
end
def instructions
"" # Noop. Tools are listed separate.
""
end
def from_raw_tool_call(raw_message)
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
tool_call_id = raw_message[:id]
{
role: "assistant",
content: [
[
{
type: "tool_use",
id: tool_call_id,
name: raw_message[:name],
input: call_details[:arguments],
},
],
}
]
end
def from_raw_tool(raw_message)
{
role: "user",
content: [
{
type: "tool_result",
tool_use_id: raw_message[:id],
content: raw_message[:content],
},
],
}
[{ type: "tool_result", tool_use_id: raw_message[:id], content: raw_message[:content] }]
end
private

View File

@ -41,7 +41,10 @@ module DiscourseAi
def instructions
return "" if raw_tools.blank?
has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
@instructions ||=
begin
has_arrays =
raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
(<<~TEXT).strip
#{tool_preamble(include_array_tip: has_arrays)}
@ -49,6 +52,7 @@ module DiscourseAi
#{translated_tools}</tools>
TEXT
end
end
def from_raw_tool(raw_message)
(<<~TEXT).strip

View File

@ -45,7 +45,12 @@ module DiscourseAi
raise "Unsupported model: #{model}"
end
{ model: mapped_model, max_tokens: 3_000 }
options = { model: mapped_model, max_tokens: 3_000 }
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
dialect.prompt.has_tools?
options
end
def provider_id
@ -54,6 +59,14 @@ module DiscourseAi
private
def xml_tags_to_strip(dialect)
if dialect.prompt.has_tools?
%w[thinking search_quality_reflection search_quality_score]
else
[]
end
end
# this is an approximation, we will update it later if request goes through
def prompt_size(prompt)
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
@ -66,11 +79,13 @@ module DiscourseAi
end
def prepare_payload(prompt, model_params, dialect)
@native_tool_support = dialect.native_tool_support?
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
payload[:tools] = prompt.tools if prompt.tools.present?
payload[:tools] = prompt.tools if prompt.has_tools?
payload
end
@ -108,7 +123,7 @@ module DiscourseAi
end
def native_tool_support?
true
@native_tool_support
end
def partials_from(decoded_chunk)

View File

@ -36,6 +36,9 @@ module DiscourseAi
def default_options(dialect)
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
dialect.prompt.has_tools?
options
end
@ -43,6 +46,14 @@ module DiscourseAi
AiApiAuditLog::Provider::Anthropic
end
def xml_tags_to_strip(dialect)
if dialect.prompt.has_tools?
%w[thinking search_quality_reflection search_quality_score]
else
[]
end
end
private
def prompt_size(prompt)
@ -79,9 +90,11 @@ module DiscourseAi
end
def prepare_payload(prompt, model_params, dialect)
@native_tool_support = dialect.native_tool_support?
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:tools] = prompt.tools if prompt.tools.present?
payload[:tools] = prompt.tools if prompt.has_tools?
payload
end
@ -169,7 +182,7 @@ module DiscourseAi
end
def native_tool_support?
true
@native_tool_support
end
def chunk_to_string(chunk)

View File

@ -78,11 +78,27 @@ module DiscourseAi
end
end
def xml_tags_to_strip(dialect)
[]
end
def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
orig_blk = blk
@streaming_mode = block_given?
to_strip = xml_tags_to_strip(dialect)
@xml_stripper =
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
if @streaming_mode && @xml_stripper
blk =
lambda do |partial, cancel|
partial = @xml_stripper << partial
orig_blk.call(partial, cancel) if partial
end
end
prompt = dialect.translate
@ -270,6 +286,11 @@ module DiscourseAi
blk.call(function_calls, cancel)
end
if @xml_stripper
leftover = @xml_stripper.finish
orig_blk.call(leftover, cancel) if leftover.present?
end
return response_data
ensure
if log

View File

@ -0,0 +1,115 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
class XmlTagStripper
def initialize(tags_to_strip)
@tags_to_strip = tags_to_strip
@longest_tag = tags_to_strip.map(&:length).max
@parsed = []
end
def <<(text)
if node = @parsed[-1]
if node[:type] == :maybe_tag
@parsed.pop
text = node[:content] + text
end
end
@parsed.concat(parse_tags(text))
@parsed, result = process_parsed(@parsed)
result
end
def finish
@parsed.map { |node| node[:content] }.join
end
def process_parsed(parsed)
output = []
buffer = []
stack = []
parsed.each do |node|
case node[:type]
when :text
if stack.empty?
output << node[:content]
else
buffer << node
end
when :open_tag
stack << node[:name]
buffer << node
when :close_tag
if stack.empty?
output << node[:content]
else
if stack[0] == node[:name]
buffer = []
stack = []
else
buffer << node
end
end
when :maybe_tag
buffer << node
end
end
result = output.join
result = nil if result.empty?
[buffer, result]
end
def parse_tags(text)
parsed = []
while true
before, after = text.split("<", 2)
parsed << { type: :text, content: before }
break if after.nil?
tag, after = after.split(">", 2)
is_end_tag = tag[0] == "/"
tag_name = tag
tag_name = tag[1..-1] || "" if is_end_tag
if !after
found = false
if tag_name.length <= @longest_tag
@tags_to_strip.each do |tag_to_strip|
if tag_to_strip.start_with?(tag_name)
parsed << { type: :maybe_tag, content: "<" + tag }
found = true
break
end
end
end
parsed << { type: :text, content: "<" + tag } if !found
break
end
raw_tag = "<" + tag + ">"
if @tags_to_strip.include?(tag_name)
parsed << {
type: is_end_tag ? :close_tag : :open_tag,
content: raw_tag,
name: tag_name,
}
else
parsed << { type: :text, content: raw_tag }
end
text = after
end
parsed
end
end
end
end

View File

@ -1,6 +1,10 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
let :opus_dialect_klass do
DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
end
describe "#translate" do
it "can insert OKs to make stuff interleve properly" do
messages = [
@ -13,8 +17,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
dialect = dialectKlass.new(prompt, "claude-3-opus")
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
translated = dialect.translate
expected_messages = [
@ -29,8 +32,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
expect(translated.messages).to eq(expected_messages)
end
it "can properly translate a prompt" do
dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
it "can properly translate a prompt (legacy tools)" do
SiteSetting.ai_anthropic_native_tool_call_models = ""
tools = [
{
@ -59,7 +62,59 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
tools: tools,
)
dialect = dialect.new(prompt, "claude-3-opus")
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot")
expected = [
{ role: "user", content: "user1: echo something" },
{
role: "assistant",
content:
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
},
{
role: "user",
content:
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
},
{ role: "assistant", content: "I did it" },
{ role: "user", content: "user1: echo something else" },
]
expect(translated.messages).to eq(expected)
end
it "can properly translate a prompt (native tools)" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
tools = [
{
name: "echo",
description: "echo a string",
parameters: [
{ name: "text", type: "string", description: "string to echo", required: true },
],
},
]
tool_call_prompt = { name: "echo", arguments: { text: "something" } }
messages = [
{ type: :user, id: "user1", content: "echo something" },
{ type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json },
{ type: :tool, id: "tool_id", content: "something".to_json },
{ type: :model, content: "I did it" },
{ type: :user, id: "user1", content: "echo something else" },
]
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a helpful bot",
messages: messages,
tools: tools,
)
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot")

View File

@ -48,6 +48,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
before { SiteSetting.ai_anthropic_api_key = "123" }
it "does not eat spaces with tool calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
body = <<~STRING
event: message_start
data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }

View File

@ -18,6 +18,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
end
def encode_message(message)
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
io = StringIO.new(wrapped)
aws_message = Aws::EventStream::Message.new(payload: io)
Aws::EventStream::Encoder.new.encode(aws_message)
end
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
@ -25,6 +32,85 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
end
describe "function calling" do
it "supports old school xml function calls" do
SiteSetting.ai_anthropic_native_tool_call_models = ""
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
incomplete_tool_call = <<~XML.strip
<thinking>I should be ignored</thinking>
<search_quality_reflection>also ignored</search_quality_reflection>
<search_quality_score>0</search_quality_score>
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters><query>sydney weather today</query></parameters>
</invoke>
</function_calls>
XML
messages =
[
{ type: "message_start", message: { usage: { input_tokens: 9 } } },
{ type: "content_block_delta", delta: { text: "hello\n" } },
{ type: "content_block_delta", delta: { text: incomplete_tool_call } },
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
].map { |message| encode_message(message) }
request = nil
bedrock_mock.with_chunk_array_support do
stub_request(
:post,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
)
.with do |inner_request|
request = inner_request
true
end
.to_return(status: 200, body: messages)
prompt =
DiscourseAi::Completions::Prompt.new(
messages: [{ type: :user, content: "what is the weather in sydney" }],
)
tool = {
name: "google",
description: "Will search using Google",
parameters: [
{ name: "query", description: "The search query", type: "string", required: true },
],
}
prompt.tools = [tool]
response = +""
proxy.generate(prompt, user: user) { |partial| response << partial }
expect(request.headers["Authorization"]).to be_present
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
parsed_body = JSON.parse(request.body)
expect(parsed_body["system"]).to include("<function_calls>")
expect(parsed_body["tools"]).to eq(nil)
expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])
# note we now have a tool_id cause we were normalized
function_call = <<~XML.strip
hello
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters><query>sydney weather today</query></parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(response.strip).to eq(function_call)
end
end
it "supports streaming function calls" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
@ -48,6 +134,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
stop_reason: nil,
},
},
{
type: "content_block_start",
index: 0,
delta: {
text: "<thinking>I should be ignored</thinking>",
},
},
{
type: "content_block_start",
index: 0,
@ -111,12 +204,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
firstByteLatency: 402,
},
},
].map do |message|
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
io = StringIO.new(wrapped)
aws_message = Aws::EventStream::Message.new(payload: io)
Aws::EventStream::Encoder.new.encode(aws_message)
end
].map { |message| encode_message(message) }
messages = messages.join("").split
@ -248,12 +336,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
{ type: "content_block_delta", delta: { text: "hello " } },
{ type: "content_block_delta", delta: { text: "sam" } },
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
].map do |message|
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
io = StringIO.new(wrapped)
aws_message = Aws::EventStream::Message.new(payload: io)
Aws::EventStream::Encoder.new.encode(aws_message)
end
].map { |message| encode_message(message) }
# stream 1 letter at a time
# cause we need to handle this case

View File

@ -0,0 +1,51 @@
# frozen_string_literal: true
describe DiscourseAi::Completions::PromptMessagesBuilder do
let(:tag_stripper) { DiscourseAi::Completions::XmlTagStripper.new(%w[thinking results]) }
it "should strip tags correctly in simple cases" do
result = tag_stripper << "x<thinking>hello</thinki"
expect(result).to eq("x")
result = tag_stripper << "ng>z"
expect(result).to eq("z")
result = tag_stripper << "king>hello</thinking>"
expect(result).to eq("king>hello</thinking>")
result = tag_stripper << "123"
expect(result).to eq("123")
end
it "supports odd nesting" do
text = <<~TEXT
<thinking>
well lets see what happens if I say <results> here...
</thinking>
hello
TEXT
result = tag_stripper << text
expect(result).to eq("\nhello\n")
end
it "works when nesting unrelated tags it strips correctly" do
text = <<~TEXT
<thinking>
well lets see what happens if I say <p> here...
</thinking>
abc <b>hello</b>
TEXT
result = tag_stripper << text
expect(result).to eq("\nabc <b>hello</b>\n")
end
it "handles maybe tags correctly" do
result = tag_stripper << "<thinking"
expect(result).to eq(nil)
expect(tag_stripper.finish).to eq("<thinking")
end
end