* FIX: Add tool support to the OpenAI-compatible dialect and vLLM. Automatic tool calling is a work in progress in vLLM (see https://github.com/vllm-project/vllm/pull/5649). Even once that lands, initial support will be uneven: only some models have native tool support, notably Mistral, which defines special tokens for tool calls. So even after that PR ships, we will still need to fall back to XML-based tools on models without native tool support.
* fix specs
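For reference, the XML tool-call format this spec exercises looks like the Ruby heredoc below. This is a minimal sketch mirroring the `expected` block in the tool-support example further down; the variable name is illustrative, and the `<tool_id>` element is added by the completion pipeline rather than emitted by the model (compare the stubbed response and the expectation in that example):

    # illustrative variable name; the format mirrors this spec's expected output
    xml_tool_call = <<~XML
      <function_calls>
      <invoke>
      <tool_name>calculate</tool_name>
      <parameters>
      <expression>1+1</expression></parameters>
      <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    XML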
# frozen_string_literal: true

require_relative "endpoint_compliance"

# Mocks vLLM's OpenAI-compatible /v1/chat/completions endpoint so the specs
# below can run against canned responses instead of a live vLLM server.
class VllmMock < EndpointMock
  def response(content)
    {
      id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
      object: "chat.completion",
      created: 1_678_464_820,
      model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
      usage: {
        prompt_tokens: 337,
        completion_tokens: 162,
        total_tokens: 499,
      },
      choices: [
        { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
      ],
    }
  end

  def stub_response(prompt, response_text, tool_call: false)
    WebMock
      .stub_request(:post, "https://test.dev/v1/chat/completions")
      .with(body: model.default_options.merge(messages: prompt).to_json)
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end

  # Builds one line of the OpenAI-style SSE stream. finish_reason is
  # serialized on the choice so the final chunk can signal why generation
  # stopped (it is nil for intermediate chunks).
  def stream_line(delta, finish_reason: nil)
    +"data: " << {
      id: "cmpl-#{SecureRandom.hex}",
      created: 1_681_283_881,
      model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
      choices: [{ delta: { content: delta }, finish_reason: finish_reason }],
      index: 0,
    }.to_json
  end

  def stub_streamed_response(prompt, deltas, tool_call: false)
    chunks =
      deltas.each_with_index.map do |_, index|
        if index == (deltas.length - 1)
          stream_line(deltas[index], finish_reason: "stop_sequence")
        else
          stream_line(deltas[index])
        end
      end

    # Split the SSE payload into single characters so the endpoint's
    # incremental stream parsing gets exercised.
    chunks = (chunks.join("\n\n") << "data: [DONE]").split("")

    WebMock
      .stub_request(:post, "https://test.dev/v1/chat/completions")
      .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
      .to_return(status: 200, body: chunks)
  end
end

RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
  subject(:endpoint) { described_class.new(llm_model) }

  fab!(:llm_model) { Fabricate(:vllm_model) }
  fab!(:user)

  let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
  let(:vllm_mock) { VllmMock.new(endpoint) }

  let(:compliance) do
    EndpointsCompliance.new(
      self,
      endpoint,
      DiscourseAi::Completions::Dialects::OpenAiCompatible,
      user,
    )
  end

  let(:dialect) do
    DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, llm_model)
  end
  let(:prompt) { dialect.translate }

  let(:request_body) { model.default_options.merge(messages: prompt).to_json }
  let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
describe "tool support" do
|
|
it "is able to invoke XML tools correctly" do
|
|
xml = <<~XML
|
|
<function_calls>
|
|
<invoke>
|
|
<tool_name>calculate</tool_name>
|
|
<parameters>
|
|
<expression>1+1</expression></parameters>
|
|
</invoke>
|
|
</function_calls>
|
|
should be ignored
|
|
XML
|
|
|
|
body = {
|
|
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
|
object: "chat.completion",
|
|
created: 1_678_464_820,
|
|
model: "gpt-3.5-turbo-0301",
|
|
usage: {
|
|
prompt_tokens: 337,
|
|
completion_tokens: 162,
|
|
total_tokens: 499,
|
|
},
|
|
choices: [
|
|
{ message: { role: "assistant", content: xml }, finish_reason: "stop", index: 0 },
|
|
],
|
|
}
|
|
tool = {
|
|
name: "calculate",
|
|
description: "calculate something",
|
|
parameters: [
|
|
{
|
|
name: "expression",
|
|
type: "string",
|
|
description: "expression to calculate",
|
|
required: true,
|
|
},
|
|
],
|
|
}
|
|
|
|
stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
|
|
status: 200,
|
|
body: body.to_json,
|
|
)
|
|
|
|
prompt =
|
|
DiscourseAi::Completions::Prompt.new(
|
|
"You a calculator",
|
|
messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
|
|
tools: [tool],
|
|
)
|
|
|
|
result = llm.generate(prompt, user: Discourse.system_user)
|
|
|
|
expected = <<~TEXT
|
|
<function_calls>
|
|
<invoke>
|
|
<tool_name>calculate</tool_name>
|
|
<parameters>
|
|
<expression>1+1</expression></parameters>
|
|
<tool_id>tool_0</tool_id>
|
|
</invoke>
|
|
</function_calls>
|
|
TEXT
|
|
|
|
expect(result.strip).to eq(expected.strip)
|
|
end
|
|
end

  describe "#perform_completion!" do
    context "when using regular mode" do
      context "with simple prompts" do
        it "completes a trivial prompt and logs the response" do
          compliance.regular_mode_simple_prompt(vllm_mock)
        end
      end

      context "with tools" do
        it "returns a function invocation" do
          compliance.regular_mode_tools(vllm_mock)
        end
      end
    end

    describe "when using streaming mode" do
      context "with simple prompts" do
        it "completes a trivial prompt and logs the response" do
          compliance.streaming_mode_simple_prompt(vllm_mock)
        end
      end

      context "with tools" do
        it "returns a function invocation" do
          compliance.streaming_mode_tools(vllm_mock)
        end
      end
    end
  end
end
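
To run just this spec from a Discourse development checkout (assuming the plugin is installed at plugins/discourse-ai and the file lives at spec/lib/completions/endpoints/vllm_spec.rb within it), something like the following should work:

    LOAD_PLUGINS=1 bundle exec rspec plugins/discourse-ai/spec/lib/completions/endpoints/vllm_spec.rb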