FIX: switch off native tools on Anthropic Claude Opus (#659)
Native tools do not work well on Opus. Chain of Thought prompting means it consumes enormous amounts of tokens and has poor latency. This commit introduce and XML stripper to remove various chain of thought XML islands from anthropic prompts when tools are involved. This mean Opus native tools is now functions (albeit slowly) From local testing XML just works better now. Also fixes enum support in Anthropic native tools
This commit is contained in:
parent
7a64699314
commit
8b81ff45b8
|
@ -51,6 +51,7 @@ en:
|
|||
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
|
||||
ai_openai_api_key: "API key for OpenAI API"
|
||||
ai_anthropic_api_key: "API key for Anthropic API"
|
||||
ai_anthropic_native_tool_call_models: "List of models that will use native tool calls vs legacy XML based tools."
|
||||
ai_cohere_api_key: "API key for Cohere API"
|
||||
ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference"
|
||||
ai_hugging_face_api_key: API key for Hugging Face API
|
||||
|
|
|
@ -111,6 +111,15 @@ discourse_ai:
|
|||
ai_anthropic_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_anthropic_native_tool_call_models:
|
||||
type: list
|
||||
list_type: compact
|
||||
default: "claude-3-sonnet|claude-3-haiku"
|
||||
allow_any: false
|
||||
choices:
|
||||
- claude-3-opus
|
||||
- claude-3-sonnet
|
||||
- claude-3-haiku
|
||||
ai_cohere_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
|
|
|
@ -22,6 +22,10 @@ module DiscourseAi
|
|||
@messages = messages
|
||||
@tools = tools
|
||||
end
|
||||
|
||||
def has_tools?
|
||||
tools.present?
|
||||
end
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
|
@ -33,6 +37,10 @@ module DiscourseAi
|
|||
|
||||
system_prompt = messages.shift[:content] if messages.first[:role] == "system"
|
||||
|
||||
if !system_prompt && !native_tool_support?
|
||||
system_prompt = tools_dialect.instructions.presence
|
||||
end
|
||||
|
||||
interleving_messages = []
|
||||
previous_message = nil
|
||||
|
||||
|
@ -48,11 +56,10 @@ module DiscourseAi
|
|||
previous_message = message
|
||||
end
|
||||
|
||||
ClaudePrompt.new(
|
||||
system_prompt.presence,
|
||||
interleving_messages,
|
||||
tools_dialect.translated_tools,
|
||||
)
|
||||
tools = nil
|
||||
tools = tools_dialect.translated_tools if native_tool_support?
|
||||
|
||||
ClaudePrompt.new(system_prompt.presence, interleving_messages, tools)
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
|
@ -62,18 +69,28 @@ module DiscourseAi
|
|||
200_000 # Claude-3 has a 200k context window for now
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def tools_dialect
|
||||
if native_tool_support?
|
||||
@tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
|
||||
else
|
||||
super
|
||||
end
|
||||
end
|
||||
|
||||
def tool_call_msg(msg)
|
||||
tools_dialect.from_raw_tool_call(msg)
|
||||
translated = tools_dialect.from_raw_tool_call(msg)
|
||||
{ role: "assistant", content: translated }
|
||||
end
|
||||
|
||||
def tool_msg(msg)
|
||||
tools_dialect.from_raw_tool(msg)
|
||||
translated = tools_dialect.from_raw_tool(msg)
|
||||
{ role: "user", content: translated }
|
||||
end
|
||||
|
||||
def model_msg(msg)
|
||||
|
|
|
@ -15,12 +15,13 @@ module DiscourseAi
|
|||
required = []
|
||||
|
||||
if t[:parameters]
|
||||
properties =
|
||||
t[:parameters].each_with_object({}) do |param, h|
|
||||
h[param[:name]] = {
|
||||
type: param[:type],
|
||||
description: param[:description],
|
||||
}.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] }
|
||||
properties = {}
|
||||
|
||||
t[:parameters].each do |param|
|
||||
mapped = { type: param[:type], description: param[:description] }
|
||||
mapped[:items] = { type: param[:item_type] } if param[:item_type]
|
||||
mapped[:enum] = param[:enum] if param[:enum]
|
||||
properties[param[:name]] = mapped
|
||||
end
|
||||
required =
|
||||
t[:parameters].select { |param| param[:required] }.map { |param| param[:name] }
|
||||
|
@ -39,37 +40,24 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def instructions
|
||||
"" # Noop. Tools are listed separate.
|
||||
""
|
||||
end
|
||||
|
||||
def from_raw_tool_call(raw_message)
|
||||
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
|
||||
tool_call_id = raw_message[:id]
|
||||
|
||||
{
|
||||
role: "assistant",
|
||||
content: [
|
||||
[
|
||||
{
|
||||
type: "tool_use",
|
||||
id: tool_call_id,
|
||||
name: raw_message[:name],
|
||||
input: call_details[:arguments],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
end
|
||||
|
||||
def from_raw_tool(raw_message)
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: raw_message[:id],
|
||||
content: raw_message[:content],
|
||||
},
|
||||
],
|
||||
}
|
||||
[{ type: "tool_result", tool_use_id: raw_message[:id], content: raw_message[:content] }]
|
||||
end
|
||||
|
||||
private
|
||||
|
|
|
@ -41,7 +41,10 @@ module DiscourseAi
|
|||
def instructions
|
||||
return "" if raw_tools.blank?
|
||||
|
||||
has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
|
||||
@instructions ||=
|
||||
begin
|
||||
has_arrays =
|
||||
raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
|
||||
|
||||
(<<~TEXT).strip
|
||||
#{tool_preamble(include_array_tip: has_arrays)}
|
||||
|
@ -49,6 +52,7 @@ module DiscourseAi
|
|||
#{translated_tools}</tools>
|
||||
TEXT
|
||||
end
|
||||
end
|
||||
|
||||
def from_raw_tool(raw_message)
|
||||
(<<~TEXT).strip
|
||||
|
|
|
@ -45,7 +45,12 @@ module DiscourseAi
|
|||
raise "Unsupported model: #{model}"
|
||||
end
|
||||
|
||||
{ model: mapped_model, max_tokens: 3_000 }
|
||||
options = { model: mapped_model, max_tokens: 3_000 }
|
||||
|
||||
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
|
||||
dialect.prompt.has_tools?
|
||||
|
||||
options
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -54,6 +59,14 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
def xml_tags_to_strip(dialect)
|
||||
if dialect.prompt.has_tools?
|
||||
%w[thinking search_quality_reflection search_quality_score]
|
||||
else
|
||||
[]
|
||||
end
|
||||
end
|
||||
|
||||
# this is an approximation, we will update it later if request goes through
|
||||
def prompt_size(prompt)
|
||||
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
|
||||
|
@ -66,11 +79,13 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
@native_tool_support = dialect.native_tool_support?
|
||||
|
||||
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
||||
|
||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||
payload[:stream] = true if @streaming_mode
|
||||
payload[:tools] = prompt.tools if prompt.tools.present?
|
||||
payload[:tools] = prompt.tools if prompt.has_tools?
|
||||
|
||||
payload
|
||||
end
|
||||
|
@ -108,7 +123,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def native_tool_support?
|
||||
true
|
||||
@native_tool_support
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
|
|
|
@ -36,6 +36,9 @@ module DiscourseAi
|
|||
|
||||
def default_options(dialect)
|
||||
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
|
||||
|
||||
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
|
||||
dialect.prompt.has_tools?
|
||||
options
|
||||
end
|
||||
|
||||
|
@ -43,6 +46,14 @@ module DiscourseAi
|
|||
AiApiAuditLog::Provider::Anthropic
|
||||
end
|
||||
|
||||
def xml_tags_to_strip(dialect)
|
||||
if dialect.prompt.has_tools?
|
||||
%w[thinking search_quality_reflection search_quality_score]
|
||||
else
|
||||
[]
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def prompt_size(prompt)
|
||||
|
@ -79,9 +90,11 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
@native_tool_support = dialect.native_tool_support?
|
||||
|
||||
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||
payload[:tools] = prompt.tools if prompt.tools.present?
|
||||
payload[:tools] = prompt.tools if prompt.has_tools?
|
||||
|
||||
payload
|
||||
end
|
||||
|
@ -169,7 +182,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def native_tool_support?
|
||||
true
|
||||
@native_tool_support
|
||||
end
|
||||
|
||||
def chunk_to_string(chunk)
|
||||
|
|
|
@ -78,11 +78,27 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def xml_tags_to_strip(dialect)
|
||||
[]
|
||||
end
|
||||
|
||||
def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
|
||||
allow_tools = dialect.prompt.has_tools?
|
||||
model_params = normalize_model_params(model_params)
|
||||
orig_blk = blk
|
||||
|
||||
@streaming_mode = block_given?
|
||||
to_strip = xml_tags_to_strip(dialect)
|
||||
@xml_stripper =
|
||||
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
|
||||
|
||||
if @streaming_mode && @xml_stripper
|
||||
blk =
|
||||
lambda do |partial, cancel|
|
||||
partial = @xml_stripper << partial
|
||||
orig_blk.call(partial, cancel) if partial
|
||||
end
|
||||
end
|
||||
|
||||
prompt = dialect.translate
|
||||
|
||||
|
@ -270,6 +286,11 @@ module DiscourseAi
|
|||
blk.call(function_calls, cancel)
|
||||
end
|
||||
|
||||
if @xml_stripper
|
||||
leftover = @xml_stripper.finish
|
||||
orig_blk.call(leftover, cancel) if leftover.present?
|
||||
end
|
||||
|
||||
return response_data
|
||||
ensure
|
||||
if log
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
class XmlTagStripper
|
||||
def initialize(tags_to_strip)
|
||||
@tags_to_strip = tags_to_strip
|
||||
@longest_tag = tags_to_strip.map(&:length).max
|
||||
@parsed = []
|
||||
end
|
||||
|
||||
def <<(text)
|
||||
if node = @parsed[-1]
|
||||
if node[:type] == :maybe_tag
|
||||
@parsed.pop
|
||||
text = node[:content] + text
|
||||
end
|
||||
end
|
||||
@parsed.concat(parse_tags(text))
|
||||
@parsed, result = process_parsed(@parsed)
|
||||
result
|
||||
end
|
||||
|
||||
def finish
|
||||
@parsed.map { |node| node[:content] }.join
|
||||
end
|
||||
|
||||
def process_parsed(parsed)
|
||||
output = []
|
||||
buffer = []
|
||||
stack = []
|
||||
|
||||
parsed.each do |node|
|
||||
case node[:type]
|
||||
when :text
|
||||
if stack.empty?
|
||||
output << node[:content]
|
||||
else
|
||||
buffer << node
|
||||
end
|
||||
when :open_tag
|
||||
stack << node[:name]
|
||||
buffer << node
|
||||
when :close_tag
|
||||
if stack.empty?
|
||||
output << node[:content]
|
||||
else
|
||||
if stack[0] == node[:name]
|
||||
buffer = []
|
||||
stack = []
|
||||
else
|
||||
buffer << node
|
||||
end
|
||||
end
|
||||
when :maybe_tag
|
||||
buffer << node
|
||||
end
|
||||
end
|
||||
|
||||
result = output.join
|
||||
result = nil if result.empty?
|
||||
|
||||
[buffer, result]
|
||||
end
|
||||
|
||||
def parse_tags(text)
|
||||
parsed = []
|
||||
|
||||
while true
|
||||
before, after = text.split("<", 2)
|
||||
|
||||
parsed << { type: :text, content: before }
|
||||
|
||||
break if after.nil?
|
||||
|
||||
tag, after = after.split(">", 2)
|
||||
|
||||
is_end_tag = tag[0] == "/"
|
||||
tag_name = tag
|
||||
tag_name = tag[1..-1] || "" if is_end_tag
|
||||
|
||||
if !after
|
||||
found = false
|
||||
if tag_name.length <= @longest_tag
|
||||
@tags_to_strip.each do |tag_to_strip|
|
||||
if tag_to_strip.start_with?(tag_name)
|
||||
parsed << { type: :maybe_tag, content: "<" + tag }
|
||||
found = true
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
parsed << { type: :text, content: "<" + tag } if !found
|
||||
break
|
||||
end
|
||||
|
||||
raw_tag = "<" + tag + ">"
|
||||
|
||||
if @tags_to_strip.include?(tag_name)
|
||||
parsed << {
|
||||
type: is_end_tag ? :close_tag : :open_tag,
|
||||
content: raw_tag,
|
||||
name: tag_name,
|
||||
}
|
||||
else
|
||||
parsed << { type: :text, content: raw_tag }
|
||||
end
|
||||
text = after
|
||||
end
|
||||
|
||||
parsed
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,6 +1,10 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||
let :opus_dialect_klass do
|
||||
DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
end
|
||||
|
||||
describe "#translate" do
|
||||
it "can insert OKs to make stuff interleve properly" do
|
||||
messages = [
|
||||
|
@ -13,8 +17,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
|
||||
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
|
||||
|
||||
dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
dialect = dialectKlass.new(prompt, "claude-3-opus")
|
||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
||||
translated = dialect.translate
|
||||
|
||||
expected_messages = [
|
||||
|
@ -29,8 +32,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
expect(translated.messages).to eq(expected_messages)
|
||||
end
|
||||
|
||||
it "can properly translate a prompt" do
|
||||
dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
it "can properly translate a prompt (legacy tools)" do
|
||||
SiteSetting.ai_anthropic_native_tool_call_models = ""
|
||||
|
||||
tools = [
|
||||
{
|
||||
|
@ -59,7 +62,59 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
tools: tools,
|
||||
)
|
||||
|
||||
dialect = dialect.new(prompt, "claude-3-opus")
|
||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||
|
||||
expected = [
|
||||
{ role: "user", content: "user1: echo something" },
|
||||
{
|
||||
role: "assistant",
|
||||
content:
|
||||
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content:
|
||||
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
|
||||
},
|
||||
{ role: "assistant", content: "I did it" },
|
||||
{ role: "user", content: "user1: echo something else" },
|
||||
]
|
||||
expect(translated.messages).to eq(expected)
|
||||
end
|
||||
|
||||
it "can properly translate a prompt (native tools)" do
|
||||
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
|
||||
|
||||
tools = [
|
||||
{
|
||||
name: "echo",
|
||||
description: "echo a string",
|
||||
parameters: [
|
||||
{ name: "text", type: "string", description: "string to echo", required: true },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
tool_call_prompt = { name: "echo", arguments: { text: "something" } }
|
||||
|
||||
messages = [
|
||||
{ type: :user, id: "user1", content: "echo something" },
|
||||
{ type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json },
|
||||
{ type: :tool, id: "tool_id", content: "something".to_json },
|
||||
{ type: :model, content: "I did it" },
|
||||
{ type: :user, id: "user1", content: "echo something else" },
|
||||
]
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are a helpful bot",
|
||||
messages: messages,
|
||||
tools: tools,
|
||||
)
|
||||
dialect = opus_dialect_klass.new(prompt, "claude-3-opus")
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||
|
|
|
@ -48,6 +48,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
before { SiteSetting.ai_anthropic_api_key = "123" }
|
||||
|
||||
it "does not eat spaces with tool calls" do
|
||||
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
|
||||
body = <<~STRING
|
||||
event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }
|
||||
|
|
|
@ -18,6 +18,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
|
||||
end
|
||||
|
||||
def encode_message(message)
|
||||
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
|
||||
io = StringIO.new(wrapped)
|
||||
aws_message = Aws::EventStream::Message.new(payload: io)
|
||||
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||
end
|
||||
|
||||
before do
|
||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
||||
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
||||
|
@ -25,6 +32,85 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
end
|
||||
|
||||
describe "function calling" do
|
||||
it "supports old school xml function calls" do
|
||||
SiteSetting.ai_anthropic_native_tool_call_models = ""
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
|
||||
incomplete_tool_call = <<~XML.strip
|
||||
<thinking>I should be ignored</thinking>
|
||||
<search_quality_reflection>also ignored</search_quality_reflection>
|
||||
<search_quality_score>0</search_quality_score>
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>google</tool_name>
|
||||
<parameters><query>sydney weather today</query></parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
messages =
|
||||
[
|
||||
{ type: "message_start", message: { usage: { input_tokens: 9 } } },
|
||||
{ type: "content_block_delta", delta: { text: "hello\n" } },
|
||||
{ type: "content_block_delta", delta: { text: incomplete_tool_call } },
|
||||
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
|
||||
].map { |message| encode_message(message) }
|
||||
|
||||
request = nil
|
||||
bedrock_mock.with_chunk_array_support do
|
||||
stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
|
||||
)
|
||||
.with do |inner_request|
|
||||
request = inner_request
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: messages)
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
messages: [{ type: :user, content: "what is the weather in sydney" }],
|
||||
)
|
||||
|
||||
tool = {
|
||||
name: "google",
|
||||
description: "Will search using Google",
|
||||
parameters: [
|
||||
{ name: "query", description: "The search query", type: "string", required: true },
|
||||
],
|
||||
}
|
||||
|
||||
prompt.tools = [tool]
|
||||
response = +""
|
||||
proxy.generate(prompt, user: user) { |partial| response << partial }
|
||||
|
||||
expect(request.headers["Authorization"]).to be_present
|
||||
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
|
||||
|
||||
parsed_body = JSON.parse(request.body)
|
||||
expect(parsed_body["system"]).to include("<function_calls>")
|
||||
expect(parsed_body["tools"]).to eq(nil)
|
||||
expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])
|
||||
|
||||
# note we now have a tool_id cause we were normalized
|
||||
function_call = <<~XML.strip
|
||||
hello
|
||||
|
||||
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>google</tool_name>
|
||||
<parameters><query>sydney weather today</query></parameters>
|
||||
<tool_id>tool_0</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
expect(response.strip).to eq(function_call)
|
||||
end
|
||||
end
|
||||
|
||||
it "supports streaming function calls" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
|
||||
|
@ -48,6 +134,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
stop_reason: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
delta: {
|
||||
text: "<thinking>I should be ignored</thinking>",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
|
@ -111,12 +204,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
firstByteLatency: 402,
|
||||
},
|
||||
},
|
||||
].map do |message|
|
||||
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
|
||||
io = StringIO.new(wrapped)
|
||||
aws_message = Aws::EventStream::Message.new(payload: io)
|
||||
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||
end
|
||||
].map { |message| encode_message(message) }
|
||||
|
||||
messages = messages.join("").split
|
||||
|
||||
|
@ -248,12 +336,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
{ type: "content_block_delta", delta: { text: "hello " } },
|
||||
{ type: "content_block_delta", delta: { text: "sam" } },
|
||||
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
|
||||
].map do |message|
|
||||
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
|
||||
io = StringIO.new(wrapped)
|
||||
aws_message = Aws::EventStream::Message.new(payload: io)
|
||||
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||
end
|
||||
].map { |message| encode_message(message) }
|
||||
|
||||
# stream 1 letter at a time
|
||||
# cause we need to handle this case
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
describe DiscourseAi::Completions::PromptMessagesBuilder do
|
||||
let(:tag_stripper) { DiscourseAi::Completions::XmlTagStripper.new(%w[thinking results]) }
|
||||
|
||||
it "should strip tags correctly in simple cases" do
|
||||
result = tag_stripper << "x<thinking>hello</thinki"
|
||||
expect(result).to eq("x")
|
||||
|
||||
result = tag_stripper << "ng>z"
|
||||
expect(result).to eq("z")
|
||||
|
||||
result = tag_stripper << "king>hello</thinking>"
|
||||
expect(result).to eq("king>hello</thinking>")
|
||||
|
||||
result = tag_stripper << "123"
|
||||
expect(result).to eq("123")
|
||||
end
|
||||
|
||||
it "supports odd nesting" do
|
||||
text = <<~TEXT
|
||||
<thinking>
|
||||
well lets see what happens if I say <results> here...
|
||||
</thinking>
|
||||
hello
|
||||
TEXT
|
||||
|
||||
result = tag_stripper << text
|
||||
expect(result).to eq("\nhello\n")
|
||||
end
|
||||
|
||||
it "works when nesting unrelated tags it strips correctly" do
|
||||
text = <<~TEXT
|
||||
<thinking>
|
||||
well lets see what happens if I say <p> here...
|
||||
</thinking>
|
||||
abc <b>hello</b>
|
||||
TEXT
|
||||
|
||||
result = tag_stripper << text
|
||||
|
||||
expect(result).to eq("\nabc <b>hello</b>\n")
|
||||
end
|
||||
|
||||
it "handles maybe tags correctly" do
|
||||
result = tag_stripper << "<thinking"
|
||||
expect(result).to eq(nil)
|
||||
|
||||
expect(tag_stripper.finish).to eq("<thinking")
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue