This PR adds support for disabling further tool calls by setting tool_choice to :none across all supported LLM providers:

- OpenAI: uses the "none" tool_choice parameter
- Anthropic: uses {type: "none"} and adds a prefill message to prevent confusion
- Gemini: sets the function_calling_config mode to "NONE"
- AWS Bedrock: doesn't natively support disabling tools, so a prefill message is added instead

We previously disabled tool calls by simply removing the tool definitions, but that caused errors with some providers. This implementation uses the method each provider supports, with the prefill fallback for Bedrock.

Co-authored-by: Natalie Tay <natalie.tay@gmail.com>

* remove stray puts
* cleaner chain breaker for last tool call (works in thinking mode); remove unused code
* improve test

---------

Co-authored-by: Natalie Tay <natalie.tay@gmail.com>
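For reference, a minimal sketch of the tool_config this maps to on the Gemini side (mirroring prepare_payload in the file below; tool_name stands in for dialect.tool_choice when a specific tool is forced):

    # tool_choice == :none -> no further tool calls
    payload[:tool_config] = { function_calling_config: { mode: "NONE" } }

    # tool_choice == tool_name -> that tool must be called
    payload[:tool_config] = {
      function_calling_config: { mode: "ANY", allowed_function_names: [tool_name] },
    }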
237 lines · 6.6 KiB · Ruby
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class Gemini < Base
        def self.can_contact?(model_provider)
          model_provider == "google"
        end

        def default_options
          # the default setting is a problem, it blocks too much
          categories = %w[HARASSMENT SEXUALLY_EXPLICIT HATE_SPEECH DANGEROUS_CONTENT]

          safety_settings =
            categories.map do |category|
              { category: "HARM_CATEGORY_#{category}", threshold: "BLOCK_NONE" }
            end

          { generationConfig: {}, safetySettings: safety_settings }
        end
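
        # Rename generic snake_case params to the camelCase names Gemini's
        # generationConfig expects.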
        def normalize_model_params(model_params)
          model_params = model_params.dup

          if model_params[:stop_sequences]
            model_params[:stopSequences] = model_params.delete(:stop_sequences)
          end

          if model_params[:max_tokens]
            model_params[:maxOutputTokens] = model_params.delete(:max_tokens)
          end

          model_params[:topP] = model_params.delete(:top_p) if model_params[:top_p]

          # temperature already supported

          model_params
        end

        def provider_id
          AiApiAuditLog::Provider::Gemini
        end

        private
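
        # Streaming requests use the :streamGenerateContent action over SSE
        # (alt=sse); blocking requests use :generateContent.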
        def model_uri
          url = llm_model.url
          key = llm_model.api_key

          if @streaming_mode
            url = "#{url}:streamGenerateContent?key=#{key}&alt=sse"
          else
            url = "#{url}:generateContent?key=#{key}"
          end

          URI(url)
        end
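
        # Builds the request body. When the dialect reports native tool
        # support, tools are attached along with a tool_config whose
        # function_calling_config mode reflects tool_choice: NONE disables
        # further tool calls, ANY forces the named tool, AUTO is the default.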
        def prepare_payload(prompt, model_params, dialect)
          @native_tool_support = dialect.native_tool_support?

          tools = dialect.tools if @native_tool_support

          payload = default_options.merge(contents: prompt[:messages])
          payload[:systemInstruction] = {
            role: "system",
            parts: [{ text: prompt[:system_instruction].to_s }],
          } if prompt[:system_instruction].present?

          if tools.present?
            payload[:tools] = tools

            function_calling_config = { mode: "AUTO" }
            if dialect.tool_choice.present?
              if dialect.tool_choice == :none
                function_calling_config = { mode: "NONE" }
              else
                function_calling_config = {
                  mode: "ANY",
                  allowed_function_names: [dialect.tool_choice],
                }
              end
            end

            payload[:tool_config] = { function_calling_config: function_calling_config }
          end

          payload[:generationConfig].merge!(model_params) if model_params.present?

          payload
        end

        def prepare_request(payload)
          headers = { "Content-Type" => "application/json" }

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def extract_completion_from(response_raw)
          parsed =
            if @streaming_mode
              response_raw
            else
              JSON.parse(response_raw, symbolize_names: true)
            end

          response_h = parsed.dig(:candidates, 0, :content, :parts, 0)

          if response_h
            @has_function_call ||= response_h.dig(:functionCall).present?
            @has_function_call ? response_h.dig(:functionCall) : response_h.dig(:text)
          end
        end
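
        # Accumulates raw SSE bytes and yields one parsed JSON object per
        # complete `data: {...}` event; a trailing partial event is kept in
        # the buffer for the next call.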
        class GeminiStreamingDecoder
          def initialize
            @buffer = +""
          end

          def decode(str)
            @buffer << str

            lines = @buffer.split(/\r?\n\r?\n/)

            keep_last = false

            decoded =
              lines
                .map do |line|
                  if line.start_with?("data: {")
                    begin
                      JSON.parse(line[6..-1], symbolize_names: true)
                    rescue JSON::ParserError
                      keep_last = line
                      nil
                    end
                  else
                    keep_last = line
                    nil
                  end
                end
                .compact

            if keep_last
              @buffer = +(keep_last)
            else
              @buffer = +""
            end

            decoded
          end
        end

        def decode(chunk)
          json = JSON.parse(chunk, symbolize_names: true)

          idx = -1
          json
            .dig(:candidates, 0, :content, :parts)
            .map do |part|
              if part[:functionCall]
                idx += 1
                ToolCall.new(
                  id: "tool_#{idx}",
                  name: part[:functionCall][:name],
                  parameters: part[:functionCall][:args],
                )
              else
                text = part[:text]
                text != "" ? text : nil
              end
            end
        end

        def decode_chunk(chunk)
          @tool_index ||= -1

          streaming_decoder
            .decode(chunk)
            .map do |parsed|
              update_usage(parsed)

              parts = parsed.dig(:candidates, 0, :content, :parts)
              parts&.map do |part|
                if part[:text]
                  text = part[:text]
                  text != "" ? text : nil
                elsif part[:functionCall]
                  @tool_index += 1
                  ToolCall.new(
                    id: "tool_#{@tool_index}",
                    name: part[:functionCall][:name],
                    parameters: part[:functionCall][:args],
                  )
                end
              end
            end
            .flatten
            .compact
        end
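
        # Capture the latest token counts from usageMetadata so
        # final_log_update can record them on the audit log.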
        def update_usage(parsed)
          usage = parsed.dig(:usageMetadata)

          if usage
            if prompt_token_count = usage[:promptTokenCount]
              @prompt_token_count = prompt_token_count
            end

            if candidate_token_count = usage[:candidatesTokenCount]
              @candidate_token_count = candidate_token_count
            end
          end
        end

        def final_log_update(log)
          log.request_tokens = @prompt_token_count if @prompt_token_count
          log.response_tokens = @candidate_token_count if @candidate_token_count
        end

        def streaming_decoder
          @decoder ||= GeminiStreamingDecoder.new
        end

        def extract_prompt_for_tokenizer(prompt)
          prompt.to_s
        end
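
        # XML-style tool prompts are the fallback when the dialect lacks
        # native tool support.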
        def xml_tools_enabled?
          !@native_tool_support
        end
      end
    end
  end
end
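
A quick sketch of how GeminiStreamingDecoder buffers partial SSE events (assuming the classes above are loaded; return values shown in comments):

    decoder = DiscourseAi::Completions::Endpoints::Gemini::GeminiStreamingDecoder.new

    # One complete event plus the start of a second; only the complete one decodes.
    decoder.decode("data: {\"a\":1}\n\ndata: {\"b\"")
    # => [{ a: 1 }]

    # The rest of the second event arrives and completes the buffered prefix.
    decoder.decode(":2}\n\n")
    # => [{ b: 2 }]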