diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb
index ec91c49a..33af1a14 100644
--- a/lib/completions/dialects/open_ai_compatible.rb
+++ b/lib/completions/dialects/open_ai_compatible.rb
@@ -27,7 +27,13 @@ module DiscourseAi
private
def system_msg(msg)
- { role: "system", content: msg[:content] }
+ msg = { role: "system", content: msg[:content] }
+
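+ # Append the dialect's tool-use instructions to the system prompt when tools are in play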
+ if tools_dialect.instructions.present?
+ msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+ end
+
+ msg
end
def model_msg(msg)
@@ -35,11 +41,13 @@ module DiscourseAi
end
def tool_call_msg(msg)
- tools_dialect.from_raw_tool_call(msg)
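+ # Serialize the tool call to text and relay it as a plain assistant message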
+ translated = tools_dialect.from_raw_tool_call(msg)
+ { role: "assistant", content: translated }
end
def tool_msg(msg)
- tools_dialect.from_raw_tool(msg)
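+ # Serialize the tool result to text and relay it as a user message, since a
+ # native tool role is not assumed for OpenAI-compatible endpoints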
+ translated = tools_dialect.from_raw_tool(msg)
+ { role: "user", content: translated }
end
def user_msg(msg)
diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb
index 61abf5ad..57fcf051 100644
--- a/lib/completions/endpoints/vllm.rb
+++ b/lib/completions/endpoints/vllm.rb
@@ -31,7 +31,7 @@ module DiscourseAi
def model_uri
if llm_model.url.to_s.starts_with?("srv://")
- record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
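+ # Resolve the SRV record to a concrete target host and port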
+ service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
else
api_endpoint = llm_model.url
@@ -40,11 +40,11 @@ module DiscourseAi
@uri ||= URI(api_endpoint)
end
- def prepare_payload(prompt, model_params, _dialect)
- default_options
- .merge(model_params)
- .merge(messages: prompt)
- .tap { |payload| payload[:stream] = true if @streaming_mode }
+ def prepare_payload(prompt, model_params, dialect)
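+ # Merge per-request params into the defaults and attach the translated messages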
+ payload = default_options.merge(model_params).merge(messages: prompt)
+ payload[:stream] = true if @streaming_mode
+
+ payload
end
def prepare_request(payload)
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index 150db5e9..df1ee481 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -102,12 +102,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
compliance.regular_mode_simple_prompt(hf_mock)
end
end
-
- context "with tools" do
- it "returns a function invocation" do
- compliance.regular_mode_tools(hf_mock)
- end
- end
end
describe "when using streaming mode" do
@@ -116,12 +110,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
compliance.streaming_mode_simple_prompt(hf_mock)
end
end
-
- context "with tools" do
- it "returns a function invocation" do
- compliance.streaming_mode_tools(hf_mock)
- end
- end
end
end
end
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 1b89c3e1..6f5387c0 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -60,10 +60,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
subject(:endpoint) { described_class.new(llm_model) }
fab!(:llm_model) { Fabricate(:vllm_model) }
-
fab!(:user)
- let(:anthropic_mock) { VllmMock.new(endpoint) }
+ let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
+ let(:vllm_mock) { VllmMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(
@@ -82,17 +82,86 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
+ describe "tool support" do
+ it "is able to invoke XML tools correctly" do
+ xml = <<~XML
+ <function_calls>
+ <invoke>
+ <tool_name>calculate</tool_name>
+ <parameters>
+ <expression>1+1</expression>
+ </parameters>
+ </invoke>
+ </function_calls>
+ should be ignored
+ XML
+
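+ # Canned OpenAI-style chat completion response wrapping the XML above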
+ body = {
+ id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
+ object: "chat.completion",
+ created: 1_678_464_820,
+ model: "gpt-3.5-turbo-0301",
+ usage: {
+ prompt_tokens: 337,
+ completion_tokens: 162,
+ total_tokens: 499,
+ },
+ choices: [
+ { message: { role: "assistant", content: xml }, finish_reason: "stop", index: 0 },
+ ],
+ }
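+ # Tool definition handed to the prompt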
+ tool = {
+ name: "calculate",
+ description: "calculate something",
+ parameters: [
+ {
+ name: "expression",
+ type: "string",
+ description: "expression to calculate",
+ required: true,
+ },
+ ],
+ }
+
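+ # Stub the vLLM chat completions endpoint with the canned response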
+ stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
+ status: 200,
+ body: body.to_json,
+ )
+
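+ # Register the tool on the prompt so the dialect injects its instructions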
+ prompt =
+ DiscourseAi::Completions::Prompt.new(
+ "You a calculator",
+ messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
+ tools: [tool],
+ )
+
+ result = llm.generate(prompt, user: Discourse.system_user)
+
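+ # The normalized tool call: a tool_id is generated and the trailing text is dropped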
+ expected = <<~TEXT
+ <function_calls>
+ <invoke>
+ <tool_name>calculate</tool_name>
+ <parameters>
+ <expression>1+1</expression>
+ </parameters>
+ <tool_id>tool_0</tool_id>
+ </invoke>
+ </function_calls>
+ TEXT
+
+ expect(result.strip).to eq(expected.strip)
+ end
+ end
+
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
- compliance.regular_mode_simple_prompt(anthropic_mock)
+ compliance.regular_mode_simple_prompt(vllm_mock)
end
end
context "with tools" do
it "returns a function invocation" do
- compliance.regular_mode_tools(anthropic_mock)
+ compliance.regular_mode_tools(vllm_mock)
end
end
end
@@ -100,13 +169,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
describe "when using streaming mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
- compliance.streaming_mode_simple_prompt(anthropic_mock)
+ compliance.streaming_mode_simple_prompt(vllm_mock)
end
end
context "with tools" do
it "returns a function invoncation" do
- compliance.streaming_mode_tools(anthropic_mock)
+ compliance.streaming_mode_tools(vllm_mock)
end
end
end