FIX: Add tool support to OpenAI-compatible dialect and vLLM (#734)

* FIX: Add tool support to OpenAI-compatible dialect and vLLM

Automatic tool support is in progress in vLLM; see https://github.com/vllm-project/vllm/pull/5649

Even once that lands, initial support will be uneven: only some models have
native tool support, notably Mistral, which uses special tokens for tool calls.

So even after the above PR lands in vLLM, we will still need to fall back to
XML-based tools for models without native tool support.
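
For context, the XML convention in question (exercised by the vLLM spec below)
has the model emit tool calls in this shape:

    <function_calls>
    <invoke>
    <tool_name>calculate</tool_name>
    <parameters>
    <expression>1+1</expression>
    </parameters>
    </invoke>
    </function_calls>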

* fix specs
Sam 2024-08-02 22:52:33 +10:00 committed by GitHub
parent b7ac229547
commit 948cf893a9
4 changed files with 92 additions and 27 deletions


@@ -27,7 +27,13 @@ module DiscourseAi
       private

       def system_msg(msg)
-        { role: "system", content: msg[:content] }
+        msg = { role: "system", content: msg[:content] }
+
+        if tools_dialect.instructions.present?
+          msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+        end
+
+        msg
       end

       def model_msg(msg)
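
For illustration only, a minimal sketch of what system_msg now returns when
tool instructions are present (both strings below are hypothetical):

    msg = { content: "You are a helpful bot" }
    # assuming tools_dialect.instructions == "Call tools with <function_calls>..."
    system_msg(msg)
    # => { role: "system",
    #      content: "You are a helpful bot\n\nCall tools with <function_calls>..." }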
@@ -35,11 +41,13 @@ module DiscourseAi
       end

       def tool_call_msg(msg)
-        tools_dialect.from_raw_tool_call(msg)
+        translated = tools_dialect.from_raw_tool_call(msg)
+        { role: "assistant", content: translated }
       end

       def tool_msg(msg)
-        tools_dialect.from_raw_tool(msg)
+        translated = tools_dialect.from_raw_tool(msg)
+        { role: "user", content: translated }
       end

       def user_msg(msg)
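
Since these OpenAI-compatible backends get no native tool-call message types,
a translated call and its result now travel as ordinary chat messages;
schematically (XML bodies abbreviated, exact tags assumed from the dialect's
XML tools):

    tool_call_msg(raw_call)  # => { role: "assistant", content: "<function_calls>...</function_calls>" }
    tool_msg(raw_result)     # => { role: "user", content: "<function_results>...</function_results>" }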


@@ -31,7 +31,7 @@ module DiscourseAi
       def model_uri
         if llm_model.url.to_s.starts_with?("srv://")
-          record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
+          service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
           api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
         else
           api_endpoint = llm_model.url
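
As a worked example of the srv:// branch (hostname and port hypothetical):
with llm_model.url set to "srv://_vllm._tcp.example.com", DnsSrv.lookup
resolves the SRV record, and a result of target "vllm-1.example.com" with
port 8443 yields the endpoint https://vllm-1.example.com:8443/v1/chat/completions.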
@@ -40,11 +40,11 @@ module DiscourseAi
         @uri ||= URI(api_endpoint)
       end

-      def prepare_payload(prompt, model_params, _dialect)
-        default_options
-          .merge(model_params)
-          .merge(messages: prompt)
-          .tap { |payload| payload[:stream] = true if @streaming_mode }
+      def prepare_payload(prompt, model_params, dialect)
+        payload = default_options.merge(model_params).merge(messages: prompt)
+        payload[:stream] = true if @streaming_mode
+
+        payload
       end

       def prepare_request(payload)
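
A minimal sketch of the request body prepare_payload now produces for a
streaming call, with hypothetical default_options and model_params:

    # hypothetical values for illustration
    # default_options => { model: "my-vllm-model" }
    # model_params    => { max_tokens: 256 }
    prepare_payload(prompt, { max_tokens: 256 }, dialect)
    # => {
    #      model: "my-vllm-model",
    #      max_tokens: 256,
    #      messages: [{ role: "user", content: "calculate 1+1" }],
    #      stream: true, # only set in streaming mode
    #    }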


@@ -102,12 +102,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
           compliance.regular_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(hf_mock)
-        end
-      end
     end

     describe "when using streaming mode" do
@@ -116,12 +110,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
           compliance.streaming_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.streaming_mode_tools(hf_mock)
-        end
-      end
     end
   end
 end


@@ -60,10 +60,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   subject(:endpoint) { described_class.new(llm_model) }

   fab!(:llm_model) { Fabricate(:vllm_model) }
   fab!(:user)

-  let(:anthropic_mock) { VllmMock.new(endpoint) }
   let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
+  let(:vllm_mock) { VllmMock.new(endpoint) }

   let(:compliance) do
     EndpointsCompliance.new(
@@ -82,17 +82,86 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   let(:request_body) { model.default_options.merge(messages: prompt).to_json }
   let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }

+  describe "tool support" do
+    it "is able to invoke XML tools correctly" do
+      xml = <<~XML
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        </invoke>
+        </function_calls>
+        should be ignored
+      XML
+
+      body = {
+        id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
+        object: "chat.completion",
+        created: 1_678_464_820,
+        model: "gpt-3.5-turbo-0301",
+        usage: {
+          prompt_tokens: 337,
+          completion_tokens: 162,
+          total_tokens: 499,
+        },
+        choices: [
+          { message: { role: "assistant", content: xml }, finish_reason: "stop", index: 0 },
+        ],
+      }
+
+      tool = {
+        name: "calculate",
+        description: "calculate something",
+        parameters: [
+          {
+            name: "expression",
+            type: "string",
+            description: "expression to calculate",
+            required: true,
+          },
+        ],
+      }
+
+      stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
+        status: 200,
+        body: body.to_json,
+      )
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a calculator",
+          messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
+          tools: [tool],
+        )
+
+      result = llm.generate(prompt, user: Discourse.system_user)
+
+      expected = <<~TEXT
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        <tool_id>tool_0</tool_id>
+        </invoke>
+        </function_calls>
+      TEXT
+
+      expect(result.strip).to eq(expected.strip)
+    end
+  end
+
   describe "#perform_completion!" do
     context "when using regular mode" do
       context "with simple prompts" do
         it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(anthropic_mock)
+          compliance.regular_mode_simple_prompt(vllm_mock)
         end
       end

       context "with tools" do
         it "returns a function invocation" do
-          compliance.regular_mode_tools(anthropic_mock)
+          compliance.regular_mode_tools(vllm_mock)
         end
       end
     end
@@ -100,13 +169,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
     describe "when using streaming mode" do
       context "with simple prompts" do
         it "completes a trivial prompt and logs the response" do
-          compliance.streaming_mode_simple_prompt(anthropic_mock)
+          compliance.streaming_mode_simple_prompt(vllm_mock)
         end
       end

       context "with tools" do
         it "returns a function invocation" do
-          compliance.streaming_mode_tools(anthropic_mock)
+          compliance.streaming_mode_tools(vllm_mock)
         end
       end
     end