diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb
index ec91c49a..33af1a14 100644
--- a/lib/completions/dialects/open_ai_compatible.rb
+++ b/lib/completions/dialects/open_ai_compatible.rb
@@ -27,7 +27,13 @@ module DiscourseAi
         private
 
         def system_msg(msg)
-          { role: "system", content: msg[:content] }
+          msg = { role: "system", content: msg[:content] }
+
+          if tools_dialect.instructions.present?
+            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+          end
+
+          msg
         end
 
         def model_msg(msg)
@@ -35,11 +41,13 @@ module DiscourseAi
         end
 
         def tool_call_msg(msg)
-          tools_dialect.from_raw_tool_call(msg)
+          translated = tools_dialect.from_raw_tool_call(msg)
+          { role: "assistant", content: translated }
         end
 
         def tool_msg(msg)
-          tools_dialect.from_raw_tool(msg)
+          translated = tools_dialect.from_raw_tool(msg)
+          { role: "user", content: translated }
         end
 
         def user_msg(msg)
diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb
index 61abf5ad..57fcf051 100644
--- a/lib/completions/endpoints/vllm.rb
+++ b/lib/completions/endpoints/vllm.rb
@@ -31,7 +31,7 @@ module DiscourseAi
 
       def model_uri
         if llm_model.url.to_s.starts_with?("srv://")
-          record = service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
+          service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
           api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
         else
           api_endpoint = llm_model.url
@@ -40,11 +40,11 @@ module DiscourseAi
         @uri ||= URI(api_endpoint)
       end
 
-      def prepare_payload(prompt, model_params, _dialect)
-        default_options
-          .merge(model_params)
-          .merge(messages: prompt)
-          .tap { |payload| payload[:stream] = true if @streaming_mode }
+      def prepare_payload(prompt, model_params, dialect)
+        payload = default_options.merge(model_params).merge(messages: prompt)
+        payload[:stream] = true if @streaming_mode
+
+        payload
       end
 
       def prepare_request(payload)
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index 150db5e9..df1ee481 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -102,12 +102,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
           compliance.regular_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(hf_mock)
-        end
-      end
     end
 
     describe "when using streaming mode" do
@@ -116,12 +110,6 @@
           compliance.streaming_mode_simple_prompt(hf_mock)
         end
       end
-
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.streaming_mode_tools(hf_mock)
-        end
-      end
     end
   end
end
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 1b89c3e1..6f5387c0 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -60,10 +60,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   subject(:endpoint) { described_class.new(llm_model) }
 
   fab!(:llm_model) { Fabricate(:vllm_model) }
-
   fab!(:user)
 
-  let(:anthropic_mock) { VllmMock.new(endpoint) }
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") }
+  let(:vllm_mock) { VllmMock.new(endpoint) }
 
   let(:compliance) do
     EndpointsCompliance.new(
@@ -82,17 +82,86 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
   let(:request_body) { model.default_options.merge(messages: prompt).to_json }
   let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
 
+  describe "tool support" do
+    it "is able to invoke XML tools correctly" do
+      xml = <<~XML
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        </invoke>
+        </function_calls>
+        should be ignored
+      XML
+
+      body = {
+        id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
+        object: "chat.completion",
+        created: 1_678_464_820,
+        model: "gpt-3.5-turbo-0301",
+        usage: {
+          prompt_tokens: 337,
+          completion_tokens: 162,
+          total_tokens: 499,
+        },
+        choices: [
+          { message: { role: "assistant", content: xml }, finish_reason: "stop", index: 0 },
+        ],
+      }
+      tool = {
+        name: "calculate",
+        description: "calculate something",
+        parameters: [
+          {
+            name: "expression",
+            type: "string",
+            description: "expression to calculate",
+            required: true,
+          },
+        ],
+      }
+
+      stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
+        status: 200,
+        body: body.to_json,
+      )
+
+      prompt =
+        DiscourseAi::Completions::Prompt.new(
+          "You are a calculator",
+          messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
+          tools: [tool],
+        )
+
+      result = llm.generate(prompt, user: Discourse.system_user)
+
+      expected = <<~TEXT
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>1+1</expression></parameters>
+        <tool_id>tool_0</tool_id>
+        </invoke>
+        </function_calls>
+      TEXT
+
+      expect(result.strip).to eq(expected.strip)
+    end
+  end
+
   describe "#perform_completion!" do
     context "when using regular mode" do
       context "with simple prompts" do
         it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(anthropic_mock)
+          compliance.regular_mode_simple_prompt(vllm_mock)
         end
       end
 
       context "with tools" do
         it "returns a function invocation" do
-          compliance.regular_mode_tools(anthropic_mock)
+          compliance.regular_mode_tools(vllm_mock)
         end
       end
     end
@@ -100,13 +169,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
     describe "when using streaming mode" do
       context "with simple prompts" do
         it "completes a trivial prompt and logs the response" do
-          compliance.streaming_mode_simple_prompt(anthropic_mock)
+          compliance.streaming_mode_simple_prompt(vllm_mock)
         end
       end
 
       context "with tools" do
         it "returns a function invocation" do
-          compliance.streaming_mode_tools(anthropic_mock)
+          compliance.streaming_mode_tools(vllm_mock)
         end
       end
     end
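
A note on the dialect change above: generic OpenAI-compatible backends such as vLLM only accept plain { role:, content: } chat messages, so tool plumbing has to travel inside ordinary turns. The updated OpenAiCompatible dialect appends the XML tool instructions to the system prompt, serializes tool calls as XML text in assistant turns, and serializes tool results as XML text in user turns. Below is a minimal, self-contained sketch of those shapes in plain Ruby: the method bodies mirror the diff, but the tools_dialect collaborator is replaced by hypothetical stand-in strings, so this is an illustration rather than the plugin's real classes.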
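    # Sketch only: stand-ins for the plugin's dialect methods. The XML strings
    # below are placeholders for what tools_dialect would actually render.

    def system_msg(content, tool_instructions)
      msg = { role: "system", content: content }
      # Tool usage instructions ride along in the system prompt, since a
      # generic OpenAI-compatible backend may lack a native tools parameter.
      if tool_instructions && !tool_instructions.empty?
        msg[:content] = msg[:content].dup << "\n\n#{tool_instructions}"
      end
      msg
    end

    # A tool call becomes plain XML text in an assistant turn...
    def tool_call_msg(translated_xml)
      { role: "assistant", content: translated_xml }
    end

    # ...and the tool's result becomes plain XML text in a user turn, so the
    # endpoint only ever sends ordinary { role:, content: } pairs on the wire.
    def tool_msg(translated_xml)
      { role: "user", content: translated_xml }
    end

    messages = [
      system_msg("You are a calculator", "<tools>...usage instructions...</tools>"),
      { role: "user", content: "calculate 2758975 + 21.11" },
      tool_call_msg("<function_calls>...</function_calls>"),
      tool_msg("<function_results>...</function_results>"),
    ]

    # Every turn is a plain role/content pair, ready for a chat/completions payload.
    messages.each { |m| puts "#{m[:role]}: #{m[:content].lines.first}" }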