FIX: Add tool support to OpenAI-compatible dialect and vLLM (#734)

* FIX: Add tool support to OpenAI-compatible dialect and vLLM

Automatic tool support is in progress in vLLM; see: https://github.com/vllm-project/vllm/pull/5649

Even when that lands, initial support will be uneven: only some models have native tool support, notably Mistral, which defines special tokens for tool calls.

So even after the above PR lands in vLLM, we will still need to fall back to XML-based tools on models without native tool support.
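
For reference, the XML tool-call format the dialect falls back to looks like this (the calculate tool is the illustrative example exercised in the vLLM spec below):

<function_calls>
<invoke>
<tool_name>calculate</tool_name>
<parameters>
<expression>1+1</expression>
</parameters>
</invoke>
</function_calls>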

* fix specs
Sam 2024-08-02 22:52:33 +10:00 committed by GitHub
parent b7ac229547
commit 948cf893a9
4 changed files with 92 additions and 27 deletions

OpenAI-compatible dialect:

@@ -27,7 +27,13 @@ module DiscourseAi
       private

       def system_msg(msg)
-        { role: "system", content: msg[:content] }
+        msg = { role: "system", content: msg[:content] }
+
+        if tools_dialect.instructions.present?
+          msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+        end
+
+        msg
       end

       def model_msg(msg)
@@ -35,11 +41,13 @@ module DiscourseAi
       end

       def tool_call_msg(msg)
-        tools_dialect.from_raw_tool_call(msg)
+        translated = tools_dialect.from_raw_tool_call(msg)
+        { role: "assistant", content: translated }
       end

       def tool_msg(msg)
-        tools_dialect.from_raw_tool(msg)
+        translated = tools_dialect.from_raw_tool(msg)
+        { role: "user", content: translated }
       end

       def user_msg(msg)
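
A minimal sketch of the translated messages after this change (contents are illustrative, and the exact strings tools_dialect produces are an assumption; only the hash shapes come from the code above):

# the system prompt gains the tool instructions when any tools are present
system_msg(content: "You are a helpful bot")
# => { role: "system", content: "You are a helpful bot\n\n<tool instructions...>" }

# a recorded tool call is replayed as a plain assistant message
tool_call_msg(raw_tool_call)
# => { role: "assistant", content: "<function_calls>...</function_calls>" }

# a tool result is replayed as a user message
tool_msg(raw_tool_result)
# => { role: "user", content: "<function_results>...</function_results>" }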

vLLM endpoint:

@@ -31,7 +31,7 @@ module DiscourseAi
       def model_uri
         if llm_model.url.to_s.starts_with?("srv://")
-          record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
+          service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
           api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
         else
           api_endpoint = llm_model.url
         end
@@ -40,11 +40,11 @@ module DiscourseAi
         @uri ||= URI(api_endpoint)
       end

-      def prepare_payload(prompt, model_params, _dialect)
-        default_options
-          .merge(model_params)
-          .merge(messages: prompt)
-          .tap { |payload| payload[:stream] = true if @streaming_mode }
+      def prepare_payload(prompt, model_params, dialect)
+        payload = default_options.merge(model_params).merge(messages: prompt)
+        payload[:stream] = true if @streaming_mode
+
+        payload
       end

       def prepare_request(payload)
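
In effect, a streaming request now produces a payload like the following (a rough sketch; the exact default_options depend on the model, and the temperature value is illustrative):

prepare_payload(prompt, { temperature: 0.7 }, dialect)
# => { temperature: 0.7, messages: prompt, stream: true }
# :stream is only added when @streaming_mode is set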

HuggingFace endpoint spec:

@@ -102,12 +102,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
           compliance.regular_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(hf_mock)
-        end
-      end
     end

     describe "when using streaming mode" do
@@ -116,12 +110,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
           compliance.streaming_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.streaming_mode_tools(hf_mock)
-        end
-      end
     end
   end
 end

vLLM endpoint spec:

@@ -60,10 +60,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   subject(:endpoint) { described_class.new(llm_model) }

   fab!(:llm_model) { Fabricate(:vllm_model) }
   fab!(:user)

-  let(:anthropic_mock) { VllmMock.new(endpoint) }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
+  let(:vllm_mock) { VllmMock.new(endpoint) }

   let(:compliance) do
     EndpointsCompliance.new(
@@ -82,17 +82,86 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   let(:request_body) { model.default_options.merge(messages: prompt).to_json }
   let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }

+  describe "tool support" do
+    it "is able to invoke XML tools correctly" do
+      xml = <<~XML
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        </invoke>
+        </function_calls>
+        should be ignored
+      XML
+
+      body = {
+        id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
+        object: "chat.completion",
+        created: 1_678_464_820,
+        model: "gpt-3.5-turbo-0301",
+        usage: {
+          prompt_tokens: 337,
+          completion_tokens: 162,
+          total_tokens: 499,
+        },
+        choices: [
+          { message: { role: "assistant", content: xml }, finish_reason: "stop", index: 0 },
+        ],
+      }
+
+      tool = {
+        name: "calculate",
+        description: "calculate something",
+        parameters: [
+          {
+            name: "expression",
+            type: "string",
+            description: "expression to calculate",
+            required: true,
+          },
+        ],
+      }
+
+      stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
+        status: 200,
+        body: body.to_json,
+      )
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You a calculator",
+          messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
+          tools: [tool],
+        )
+
+      result = llm.generate(prompt, user: Discourse.system_user)
+
+      expected = <<~TEXT
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        <tool_id>tool_0</tool_id>
+        </invoke>
+        </function_calls>
+      TEXT
+
+      expect(result.strip).to eq(expected.strip)
+    end
+  end
+
describe "#perform_completion!" do describe "#perform_completion!" do
context "when using regular mode" do context "when using regular mode" do
context "with simple prompts" do context "with simple prompts" do
it "completes a trivial prompt and logs the response" do it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(anthropic_mock) compliance.regular_mode_simple_prompt(vllm_mock)
end end
end end
context "with tools" do context "with tools" do
it "returns a function invocation" do it "returns a function invocation" do
compliance.regular_mode_tools(anthropic_mock) compliance.regular_mode_tools(vllm_mock)
end end
end end
end end
@ -100,13 +169,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
describe "when using streaming mode" do describe "when using streaming mode" do
context "with simple prompts" do context "with simple prompts" do
it "completes a trivial prompt and logs the response" do it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(anthropic_mock) compliance.streaming_mode_simple_prompt(vllm_mock)
end end
end end
context "with tools" do context "with tools" do
it "returns a function invoncation" do it "returns a function invoncation" do
compliance.streaming_mode_tools(anthropic_mock) compliance.streaming_mode_tools(vllm_mock)
end end
end end
end end
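
One detail worth noting in the new vLLM spec: the endpoint normalizes the model's raw XML, injecting a <tool_id> element (tool_0) into the invocation and discarding the trailing "should be ignored" text after </function_calls>.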