FIX: Add tool support to open ai compatible dialect and vllm (#734)
* FIX: Add tool support to open ai compatible dialect and vllm

  Automatic tool calling is in progress in vllm, see:
  https://github.com/vllm-project/vllm/pull/5649

  Even once that lands, initial support will be uneven: only some models
  have native tool support, notably Mistral, which has special tokens for
  tool calls. After the above PR lands in vllm we will still need to swap
  to XML-based tools on models without native tool support.

* fix specs
parent b7ac229547
commit 948cf893a9
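For context, the XML fallback expects models without native tool support to emit tool calls in the shape exercised by the vllm spec below; a call to a `calculate` tool looks like:

    <function_calls>
    <invoke>
    <tool_name>calculate</tool_name>
    <parameters>
    <expression>1+1</expression></parameters>
    </invoke>
    </function_calls>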
@@ -27,7 +27,13 @@ module DiscourseAi
         private
 
         def system_msg(msg)
-          { role: "system", content: msg[:content] }
+          msg = { role: "system", content: msg[:content] }
+
+          if tools_dialect.instructions.present?
+            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+          end
+
+          msg
         end
 
         def model_msg(msg)
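A rough sketch of what the new system_msg produces when tools are present; the instructions string here is made up, the real text comes from tools_dialect.instructions:

    msg = { role: "system", content: "You are a helpful bot" }
    instructions = "Reply with <function_calls>...</function_calls> to call a tool" # illustrative
    msg[:content] = msg[:content].dup << "\n\n#{instructions}"
    msg # => { role: "system", content: "You are a helpful bot\n\nReply with ..." }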
@@ -35,11 +41,13 @@ module DiscourseAi
         end
 
         def tool_call_msg(msg)
-          tools_dialect.from_raw_tool_call(msg)
+          translated = tools_dialect.from_raw_tool_call(msg)
+          { role: "assistant", content: translated }
         end
 
         def tool_msg(msg)
-          tools_dialect.from_raw_tool(msg)
+          translated = tools_dialect.from_raw_tool(msg)
+          { role: "user", content: translated }
         end
 
         def user_msg(msg)
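Illustratively, a raw tool-call message now round-trips into a plain assistant message instead of bare XML; the raw message shape below is hypothetical, the XML string comes from from_raw_tool_call:

    raw = { type: :tool_call, id: "tool_0", name: "calculate", content: { expression: "1+1" }.to_json } # hypothetical shape
    tool_call_msg(raw)
    # => { role: "assistant", content: "<function_calls>\n<invoke>\n..." }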
@@ -31,7 +31,7 @@ module DiscourseAi
 
       def model_uri
         if llm_model.url.to_s.starts_with?("srv://")
-          record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
+          service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
           api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
         else
           api_endpoint = llm_model.url
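A minimal sketch of what the srv:// branch resolves, using Ruby's stdlib Resolv (DiscourseAi::Utils::DnsSrv presumably wraps something similar; the hostname is made up):

    require "resolv"

    # resolve the SRV record to a concrete host and port
    srv =
      Resolv::DNS.open do |dns|
        dns.getresource("_vllm._tcp.example.com", Resolv::DNS::Resource::IN::SRV)
      end

    "https://#{srv.target}:#{srv.port}/v1/chat/completions"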
@@ -40,11 +40,11 @@ module DiscourseAi
         @uri ||= URI(api_endpoint)
       end
 
-      def prepare_payload(prompt, model_params, _dialect)
-        default_options
-          .merge(model_params)
-          .merge(messages: prompt)
-          .tap { |payload| payload[:stream] = true if @streaming_mode }
+      def prepare_payload(prompt, model_params, dialect)
+        payload = default_options.merge(model_params).merge(messages: prompt)
+        payload[:stream] = true if @streaming_mode
+
+        payload
       end
 
       def prepare_request(payload)
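With streaming enabled, the merged payload comes out roughly like this (the model name and default values are illustrative, supplied by default_options and model_params):

    {
      model: "mistralai/Mixtral-8x7B-Instruct-v0.1", # illustrative
      max_tokens: 2000, # illustrative
      messages: [{ role: "user", content: "calculate 1+1" }],
      stream: true,
    }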
@@ -102,12 +102,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
           compliance.regular_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(hf_mock)
-        end
-      end
     end
 
     describe "when using streaming mode" do
@@ -116,12 +110,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
           compliance.streaming_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.streaming_mode_tools(hf_mock)
-        end
-      end
     end
   end
 end
@@ -60,10 +60,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   subject(:endpoint) { described_class.new(llm_model) }
 
   fab!(:llm_model) { Fabricate(:vllm_model) }
   fab!(:user)
 
-  let(:anthropic_mock) { VllmMock.new(endpoint) }
   let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
+  let(:vllm_mock) { VllmMock.new(endpoint) }
 
   let(:compliance) do
     EndpointsCompliance.new(
@@ -82,17 +82,86 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   let(:request_body) { model.default_options.merge(messages: prompt).to_json }
   let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
 
+  describe "tool support" do
+    it "is able to invoke XML tools correctly" do
+      xml = <<~XML
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        </invoke>
+        </function_calls>
+        should be ignored
+      XML
+
+      body = {
+        id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
+        object: "chat.completion",
+        created: 1_678_464_820,
+        model: "gpt-3.5-turbo-0301",
+        usage: {
+          prompt_tokens: 337,
+          completion_tokens: 162,
+          total_tokens: 499,
+        },
+        choices: [
+          { message: { role: "assistant", content: xml }, finish_reason: "stop", index: 0 },
+        ],
+      }
+
+      tool = {
+        name: "calculate",
+        description: "calculate something",
+        parameters: [
+          {
+            name: "expression",
+            type: "string",
+            description: "expression to calculate",
+            required: true,
+          },
+        ],
+      }
+
+      stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
+        status: 200,
+        body: body.to_json,
+      )
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You a calculator",
+          messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
+          tools: [tool],
+        )
+
+      result = llm.generate(prompt, user: Discourse.system_user)
+
+      expected = <<~TEXT
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        <tool_id>tool_0</tool_id>
+        </invoke>
+        </function_calls>
+      TEXT
+
+      expect(result.strip).to eq(expected.strip)
+    end
+  end
+
   describe "#perform_completion!" do
     context "when using regular mode" do
       context "with simple prompts" do
         it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(anthropic_mock)
+          compliance.regular_mode_simple_prompt(vllm_mock)
         end
       end
 
       context "with tools" do
         it "returns a function invocation" do
-          compliance.regular_mode_tools(anthropic_mock)
+          compliance.regular_mode_tools(vllm_mock)
         end
       end
     end
@@ -100,13 +169,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
     describe "when using streaming mode" do
       context "with simple prompts" do
         it "completes a trivial prompt and logs the response" do
-          compliance.streaming_mode_simple_prompt(anthropic_mock)
+          compliance.streaming_mode_simple_prompt(vllm_mock)
         end
       end
 
       context "with tools" do
         it "returns a function invocation" do
-          compliance.streaming_mode_tools(anthropic_mock)
+          compliance.streaming_mode_tools(vllm_mock)
         end
       end
     end