mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-08-25 05:47:06 +00:00
This PR adds support for disabling further tool calls by setting tool_choice to :none across all supported LLM providers:
- OpenAI: Uses the "none" tool_choice parameter
- Anthropic: Uses {type: "none"} and adds a prefill message to prevent confusion
- Gemini: Sets function_calling_config mode to "NONE"
- AWS Bedrock: Doesn't natively support tool disabling, so adds a prefill message
We previously disabled tool calls by simply removing the tool definitions, but this caused errors with some providers. This implementation uses the supported method appropriate for each provider while providing a fallback for Bedrock. Co-authored-by: Natalie Tay <natalie.tay@gmail.com> * remove stray puts * cleaner chain breaker for last tool call (works in thinking) remove unused code * improve test --------- Co-authored-by: Natalie Tay <natalie.tay@gmail.com>
437 lines
14 KiB
Ruby
437 lines
14 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
require_relative "endpoint_compliance"
|
|
|
|
# Test helper that builds Gemini-shaped payloads and registers WebMock stubs
# for the generateContent / streamGenerateContent URLs.
class GeminiMock < EndpointMock
  # Builds a complete (non-streamed) response payload.
  #
  # content   - reply text, or the part hash itself when tool_call is true.
  # tool_call - when true, content is embedded verbatim as the single part.
  #
  # Returns a Hash with :candidates and :promptFeedback keys.
  def response(content, tool_call: false)
    part = tool_call ? content : { text: content }

    {
      candidates: [
        {
          content: {
            parts: [part],
            role: "model",
          },
          finishReason: "STOP",
          index: 0,
          safetyRatings: negligible_safety_ratings,
        },
      ],
      promptFeedback: {
        safetyRatings: negligible_safety_ratings,
      },
    }
  end

  # Registers a stub for the generateContent URL that only matches when the
  # request body equals request_body(prompt, tool_call), and answers with the
  # JSON-encoded #response payload.
  def stub_response(prompt, response_text, tool_call: false)
    url =
      "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=#{SiteSetting.ai_gemini_api_key}"
    expected_body = request_body(prompt, tool_call)
    reply = JSON.dump(response(response_text, tool_call: tool_call))

    WebMock.stub_request(:post, url).with(body: expected_body).to_return(status: 200, body: reply)
  end

  # Builds one streamed chunk as a JSON string.
  #
  # delta         - text fragment, or the part hash itself when tool_call is true.
  # finish_reason - value placed in finishReason (nil unless given).
  def stream_line(delta, finish_reason: nil, tool_call: false)
    part = tool_call ? delta : { text: delta }

    candidate = {
      content: {
        parts: [part],
        role: "model",
      },
      finishReason: finish_reason,
      index: 0,
      safetyRatings: negligible_safety_ratings,
    }

    { candidates: [candidate] }.to_json
  end

  # Registers a stub for the streamGenerateContent URL. The deltas are encoded
  # with #stream_line (only the last one carries finishReason "STOP"), wrapped
  # into a JSON array literal, and handed to WebMock one character at a time.
  def stub_streamed_response(prompt, deltas, tool_call: false)
    last_index = deltas.length - 1

    lines =
      deltas.each_with_index.map do |delta, index|
        finish_reason = index == last_index ? "STOP" : nil
        stream_line(delta, finish_reason: finish_reason, tool_call: tool_call)
      end

    chunks = "[\n#{lines.join("\n,\n")}\n]".chars

    url =
      "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}"
    expected_body = request_body(prompt, tool_call)

    WebMock.stub_request(:post, url).with(body: expected_body).to_return(status: 200, body: chunks)
  end

  # Function declaration for the get_weather tool used in request bodies.
  def tool_payload
    properties = {
      "location" => {
        type: "string",
        description: "the city name",
      },
      "unit" => {
        type: "string",
        description: "the unit of measurement celcius c or fahrenheit f",
        enum: %w[c f],
      },
    }

    {
      name: "get_weather",
      description: "Get the weather in a city",
      parameters: {
        type: "object",
        required: %w[location unit],
        properties: properties,
      },
    }
  end

  # JSON body the stubs expect: the model's default options plus the prompt
  # contents, and the tool declaration when tool_call is truthy.
  def request_body(prompt, tool_call)
    body = model.default_options.merge(contents: prompt)
    body[:tools] = [{ function_declarations: [tool_payload] }] if tool_call
    body.to_json
  end

  # Successive functionCall parts, with the arguments filling in step by step.
  def tool_deltas
    [
      {},
      { location: "" },
      { location: "Sydney", unit: "c" },
    ].map { |args| { "functionCall" => { name: "get_weather", args: args } } }
  end

  # The complete functionCall part (same value as the final tool delta).
  def tool_response
    { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
  end

  private

  # One NEGLIGIBLE rating per harm category, in the order the payloads use.
  def negligible_safety_ratings
    %w[
      HARM_CATEGORY_SEXUALLY_EXPLICIT
      HARM_CATEGORY_HATE_SPEECH
      HARM_CATEGORY_HARASSMENT
      HARM_CATEGORY_DANGEROUS_CONTENT
    ].map { |category| { category: category, probability: "NEGLIGIBLE" } }
  end
end
|
|
|
|
RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
|
# Endpoint under test, built from the fabricated Gemini model.
subject(:endpoint) { described_class.new(model) }

fab!(:model) { Fabricate(:gemini_model, vision_enabled: true) }
fab!(:user)

# Fixture image and the upload created from it for the system user.
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
let(:upload100x100) do
  creator = UploadCreator.new(image100x100, "image.jpg")
  creator.create_for(Discourse.system_user.id)
end

let(:gemini_mock) { GeminiMock.new(endpoint) }

let(:compliance) do
  dialect = DiscourseAi::Completions::Dialects::Gemini
  EndpointsCompliance.new(self, endpoint, dialect, user)
end

# Tool definition with a single required string parameter.
let(:echo_tool) do
  text_param = { name: "text", type: "string", description: "text to echo", required: true }

  { name: "echo", description: "echo something", parameters: [text_param] }
end
|
|
|
|
# Gemini is meant to default to AUTO mode, but newer experimental models
# appear to require the mode to be set explicitly.
it "Explicitly specifies tool config" do
  prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
  mocked_body = gemini_mock.response("World").to_json

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:generateContent?key=123"

  # The body matcher accepts every request; it only records what was sent.
  captured_body = nil
  capture =
    proc do |body|
      captured_body = body
      true
    end

  stub_request(:post, url).with(body: capture).to_return(status: 200, body: mocked_body)

  result = llm.generate(prompt, user: user)

  expect(result).to eq("World")

  tool_config = JSON.parse(captured_body, symbolize_names: true)[:tool_config]

  expect(tool_config).to eq({ function_calling_config: { mode: "AUTO" } })
end
|
|
|
|
# A functionCall part in the reply must come back as a ToolCall object.
it "properly encodes tool calls" do
  prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:generateContent?key=123"

  function_call = { "functionCall" => { name: "echo", args: { text: "<S>ydney" } } }
  mocked_body = gemini_mock.response(function_call, tool_call: true).to_json

  stub_request(:post, url).to_return(status: 200, body: mocked_body)

  result = llm.generate(prompt, user: user)

  expected_call =
    DiscourseAi::Completions::ToolCall.new(
      id: "tool_0",
      name: "echo",
      parameters: {
        text: "<S>ydney",
      },
    )

  expect(result).to eq(expected_call)
end
|
|
|
|
# A message with an upload must be sent as a text part plus an inlineData part.
it "Supports Vision API" do
  user_message = { type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id] }
  prompt = DiscourseAi::Completions::Prompt.new("You are image bot", messages: [user_message])

  encoded = prompt.encoded_uploads(prompt.messages.last)

  mocked_body = gemini_mock.response("World").to_json

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:generateContent?key=123"

  # The body matcher accepts every request; it only records what was sent.
  captured_body = nil
  capture =
    proc do |body|
      captured_body = body
      true
    end

  stub_request(:post, url).with(body: capture).to_return(status: 200, body: mocked_body)

  result = llm.generate(prompt, user: user)

  expect(result).to eq("World")

  safety_settings = [
    { "category" => "HARM_CATEGORY_HARASSMENT", "threshold" => "BLOCK_NONE" },
    { "category" => "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold" => "BLOCK_NONE" },
    { "category" => "HARM_CATEGORY_HATE_SPEECH", "threshold" => "BLOCK_NONE" },
    { "category" => "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold" => "BLOCK_NONE" },
  ]

  user_parts = [
    { "text" => "hello" },
    { "inlineData" => { "mimeType" => "image/jpeg", "data" => encoded[0][:base64] } },
  ]

  expected_prompt = {
    "generationConfig" => {
    },
    "safetySettings" => safety_settings,
    "contents" => [{ "role" => "user", "parts" => user_parts }],
    "systemInstruction" => {
      "role" => "system",
      "parts" => [{ "text" => "You are image bot" }],
    },
  }

  expect(JSON.parse(captured_body)).to eq(expected_prompt)
end
|
|
|
|
# Streams a functionCall event followed by a STOP event and checks both the
# yielded ToolCall and the token counts written to the audit log.
it "Can stream tool calls correctly" do
  safety_ratings =
    %w[
      HARM_CATEGORY_HATE_SPEECH
      HARM_CATEGORY_DANGEROUS_CONTENT
      HARM_CATEGORY_HARASSMENT
      HARM_CATEGORY_SEXUALLY_EXPLICIT
    ].map { |category| { category: category, probability: "NEGLIGIBLE" } }

  tool_row = {
    candidates: [
      {
        content: {
          parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
          role: "model",
        },
        safetyRatings: safety_ratings,
      },
    ],
    usageMetadata: {
      promptTokenCount: 625,
      totalTokenCount: 625,
    },
    modelVersion: "gemini-1.5-pro-002",
  }

  finish_row = {
    candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
    usageMetadata: {
      promptTokenCount: 625,
      candidatesTokenCount: 4,
      totalTokenCount: 629,
    },
    modelVersion: "gemini-1.5-pro-002",
  }

  # Each row becomes one "data: ..." event followed by a blank line.
  payload = [tool_row, finish_row].map { |row| "data: #{row.to_json}\n\n" }.join

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

  prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])

  partials = []

  stub_request(:post, url).to_return(status: 200, body: payload)
  llm.generate(prompt, user: user) { |partial| partials << partial }

  expected_call =
    DiscourseAi::Completions::ToolCall.new(
      id: "tool_0",
      name: "echo",
      parameters: {
        text: "sam<>wh!s",
      },
    )

  expect(partials).to eq([expected_call])

  audit_log = AiApiAuditLog.order(:id).last
  expect(audit_log.request_tokens).to eq(625)
  expect(audit_log.response_tokens).to eq(4)
end
|
|
|
|
# The stream ends with a MALFORMED_FUNCTION_CALL event; only the three text
# fragments sent before it are expected to be yielded.
it "Can correctly handle malformed responses" do
  sse_body = <<~TEXT
    data: {"candidates": [{"content": {"parts": [{"text": "Certainly"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}

    data: {"candidates": [{"content": {"parts": [{"text": "! I'll create a simple \\"Hello, World!\\" page where each letter"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}

    data: {"candidates": [{"content": {"parts": [{"text": " has a different color using inline styles for simplicity. Each letter will be wrapped"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"}

    data: {"candidates": [{"content": {"parts": [{"text": ""}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}

    data: {"candidates": [{"finishReason": "MALFORMED_FUNCTION_CALL"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"}

  TEXT

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

  partials = []

  stub_request(:post, url).to_return(status: 200, body: sse_body)
  llm.generate("Hello", user: user) { |partial| partials << partial }

  expected_partials = [
    "Certainly",
    "! I'll create a simple \"Hello, World!\" page where each letter",
    " has a different color using inline styles for simplicity. Each letter will be wrapped",
  ]

  expect(partials).to eq(expected_partials)
end
|
|
|
|
# The "|" characters mark where the stream is cut into chunks, including in
# the middle of the "data: " prefix and inside a JSON string value.
it "Can correctly handle streamed responses even if they are chunked badly" do
  hello, world, sam =
    ["Hello", " |World", " Sam"].map { |text| gemini_mock.response(text).to_json }

  stream = "da|ta: |#{hello}\r\n\r\ndata: #{world}\r\n\r\ndata: #{sam}"
  chunks = stream.split("|")

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

  partials = []
  gemini_mock.with_chunk_array_support do
    stub_request(:post, url).to_return(status: 200, body: chunks)
    llm.generate("Hello", user: user) { |partial| partials << partial }
  end

  expect(partials.join).to eq("Hello World Sam")
end
|
|
|
|
# With tool_choice: :none the request must carry function calling mode "NONE".
it "can properly disable tool use with :none" do
  prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool], tool_choice: :none)
  mocked_body = gemini_mock.response("I won't use any tools").to_json

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:generateContent?key=123"

  # The body matcher accepts every request; it only records what was sent.
  captured_body = nil
  capture =
    proc do |body|
      captured_body = body
      true
    end

  stub_request(:post, url).with(body: capture).to_return(status: 200, body: mocked_body)

  result = llm.generate(prompt, user: user)

  expect(result).to eq("I won't use any tools")

  tool_config = JSON.parse(captured_body, symbolize_names: true)[:tool_config]

  expect(tool_config).to eq({ function_calling_config: { mode: "NONE" } })
end
|
|
|
|
# With tool_choice set to a tool name the request must use mode "ANY" and
# list only that tool in allowed_function_names.
it "can properly force specific tool use" do
  prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool], tool_choice: "echo")
  mocked_body = gemini_mock.response("World").to_json

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  url = "#{model.url}:generateContent?key=123"

  # The body matcher accepts every request; it only records what was sent.
  captured_body = nil
  capture =
    proc do |body|
      captured_body = body
      true
    end

  stub_request(:post, url).with(body: capture).to_return(status: 200, body: mocked_body)

  # Only the outgoing request is inspected here; the reply is not asserted on.
  llm.generate(prompt, user: user)

  tool_config = JSON.parse(captured_body, symbolize_names: true)[:tool_config]

  expected_config = { function_calling_config: { mode: "ANY", allowed_function_names: ["echo"] } }

  expect(tool_config).to eq(expected_config)
end
|
|
end
|