From 948cf893a9d7b9293ba665c6368c3e981499ae90 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 2 Aug 2024 22:52:33 +1000 Subject: [PATCH] FIX: Add tool support to open ai compatible dialect and vllm (#734) * FIX: Add tool support to open ai compatible dialect and vllm Automatic tools are in progress in vllm see: https://github.com/vllm-project/vllm/pull/5649 Even when they are supported, initial support will be uneven, only some models have native tool support notably mistral which has some special tokens for tool support. After the above PR lands in vllm we will still need to swap to XML based tools on models without native tool support. * fix specs --- .../dialects/open_ai_compatible.rb | 14 +++- lib/completions/endpoints/vllm.rb | 12 +-- .../endpoints/hugging_face_spec.rb | 12 --- spec/lib/completions/endpoints/vllm_spec.rb | 81 +++++++++++++++++-- 4 files changed, 92 insertions(+), 27 deletions(-) diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb index ec91c49a..33af1a14 100644 --- a/lib/completions/dialects/open_ai_compatible.rb +++ b/lib/completions/dialects/open_ai_compatible.rb @@ -27,7 +27,13 @@ module DiscourseAi private def system_msg(msg) - { role: "system", content: msg[:content] } + msg = { role: "system", content: msg[:content] } + + if tools_dialect.instructions.present? + msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}" + end + + msg end def model_msg(msg) @@ -35,11 +41,13 @@ module DiscourseAi end def tool_call_msg(msg) - tools_dialect.from_raw_tool_call(msg) + translated = tools_dialect.from_raw_tool_call(msg) + { role: "assistant", content: translated } end def tool_msg(msg) - tools_dialect.from_raw_tool(msg) + translated = tools_dialect.from_raw_tool(msg) + { role: "user", content: translated } end def user_msg(msg) diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 61abf5ad..57fcf051 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -31,7 +31,7 @@ module DiscourseAi def model_uri if llm_model.url.to_s.starts_with?("srv://") - record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", "")) + service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", "")) api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions" else api_endpoint = llm_model.url @@ -40,11 +40,11 @@ module DiscourseAi @uri ||= URI(api_endpoint) end - def prepare_payload(prompt, model_params, _dialect) - default_options - .merge(model_params) - .merge(messages: prompt) - .tap { |payload| payload[:stream] = true if @streaming_mode } + def prepare_payload(prompt, model_params, dialect) + payload = default_options.merge(model_params).merge(messages: prompt) + payload[:stream] = true if @streaming_mode + + payload end def prepare_request(payload) diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb index 150db5e9..df1ee481 100644 --- a/spec/lib/completions/endpoints/hugging_face_spec.rb +++ b/spec/lib/completions/endpoints/hugging_face_spec.rb @@ -102,12 +102,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do compliance.regular_mode_simple_prompt(hf_mock) end end - - context "with tools" do - it "returns a function invocation" do - compliance.regular_mode_tools(hf_mock) - end - end end describe "when using streaming mode" do @@ -116,12 +110,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do compliance.streaming_mode_simple_prompt(hf_mock) end end - - context "with tools" do - it "returns a function invocation" do - compliance.streaming_mode_tools(hf_mock) - end - end end end end diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb index 1b89c3e1..6f5387c0 100644 --- a/spec/lib/completions/endpoints/vllm_spec.rb +++ b/spec/lib/completions/endpoints/vllm_spec.rb @@ -60,10 +60,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do subject(:endpoint) { described_class.new(llm_model) } fab!(:llm_model) { Fabricate(:vllm_model) } - fab!(:user) - let(:anthropic_mock) { VllmMock.new(endpoint) } + let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") } + let(:vllm_mock) { VllmMock.new(endpoint) } let(:compliance) do EndpointsCompliance.new( @@ -82,17 +82,86 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do let(:request_body) { model.default_options.merge(messages: prompt).to_json } let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json } + describe "tool support" do + it "is able to invoke XML tools correctly" do + xml = <<~XML + + + calculate + + 1+1 + + + should be ignored + XML + + body = { + id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", + object: "chat.completion", + created: 1_678_464_820, + model: "gpt-3.5-turbo-0301", + usage: { + prompt_tokens: 337, + completion_tokens: 162, + total_tokens: 499, + }, + choices: [ + { message: { role: "assistant", content: xml }, finish_reason: "stop", index: 0 }, + ], + } + tool = { + name: "calculate", + description: "calculate something", + parameters: [ + { + name: "expression", + type: "string", + description: "expression to calculate", + required: true, + }, + ], + } + + stub_request(:post, "https://test.dev/v1/chat/completions").to_return( + status: 200, + body: body.to_json, + ) + + prompt = + DiscourseAi::Completions::Prompt.new( + "You a calculator", + messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }], + tools: [tool], + ) + + result = llm.generate(prompt, user: Discourse.system_user) + + expected = <<~TEXT + + + calculate + + 1+1 + tool_0 + + + TEXT + + expect(result.strip).to eq(expected.strip) + end + end + describe "#perform_completion!" do context "when using regular mode" do context "with simple prompts" do it "completes a trivial prompt and logs the response" do - compliance.regular_mode_simple_prompt(anthropic_mock) + compliance.regular_mode_simple_prompt(vllm_mock) end end context "with tools" do it "returns a function invocation" do - compliance.regular_mode_tools(anthropic_mock) + compliance.regular_mode_tools(vllm_mock) end end end @@ -100,13 +169,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do describe "when using streaming mode" do context "with simple prompts" do it "completes a trivial prompt and logs the response" do - compliance.streaming_mode_simple_prompt(anthropic_mock) + compliance.streaming_mode_simple_prompt(vllm_mock) end end context "with tools" do it "returns a function invoncation" do - compliance.streaming_mode_tools(anthropic_mock) + compliance.streaming_mode_tools(vllm_mock) end end end