FEATURE: Tools for models from Ollama provider (#819)
Adds support for Ollama function calling
This commit is contained in:
parent
6c4c96e83c
commit
94010a5f78
|
@ -31,6 +31,7 @@ class LlmModel < ActiveRecord::Base
|
||||||
},
|
},
|
||||||
ollama: {
|
ollama: {
|
||||||
disable_system_prompt: :checkbox,
|
disable_system_prompt: :checkbox,
|
||||||
|
enable_native_tool: :checkbox,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
|
@ -312,6 +312,7 @@ en:
|
||||||
region: "AWS Bedrock Region"
|
region: "AWS Bedrock Region"
|
||||||
organization: "Optional OpenAI Organization ID"
|
organization: "Optional OpenAI Organization ID"
|
||||||
disable_system_prompt: "Disable system message in prompts"
|
disable_system_prompt: "Disable system message in prompts"
|
||||||
|
enable_native_tool: "Enable native tool support"
|
||||||
|
|
||||||
related_topics:
|
related_topics:
|
||||||
title: "Related Topics"
|
title: "Related Topics"
|
||||||
|
|
|
@ -10,7 +10,9 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: Add tool suppport
|
def native_tool_support?
|
||||||
|
enable_native_tool?
|
||||||
|
end
|
||||||
|
|
||||||
def max_prompt_tokens
|
def max_prompt_tokens
|
||||||
llm_model.max_prompt_tokens
|
llm_model.max_prompt_tokens
|
||||||
|
@ -18,6 +20,14 @@ module DiscourseAi
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
|
def tools_dialect
|
||||||
|
if enable_native_tool?
|
||||||
|
@tools_dialect ||= DiscourseAi::Completions::Dialects::OllamaTools.new(prompt.tools)
|
||||||
|
else
|
||||||
|
super
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
llm_model.tokenizer_class
|
llm_model.tokenizer_class
|
||||||
end
|
end
|
||||||
|
@ -26,8 +36,28 @@ module DiscourseAi
|
||||||
{ role: "assistant", content: msg[:content] }
|
{ role: "assistant", content: msg[:content] }
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def tool_call_msg(msg)
|
||||||
|
tools_dialect.from_raw_tool_call(msg)
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_msg(msg)
|
||||||
|
tools_dialect.from_raw_tool(msg)
|
||||||
|
end
|
||||||
|
|
||||||
def system_msg(msg)
|
def system_msg(msg)
|
||||||
{ role: "system", content: msg[:content] }
|
msg = { role: "system", content: msg[:content] }
|
||||||
|
|
||||||
|
if tools_dialect.instructions.present?
|
||||||
|
msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
|
||||||
|
end
|
||||||
|
|
||||||
|
msg
|
||||||
|
end
|
||||||
|
|
||||||
|
def enable_native_tool?
|
||||||
|
return @enable_native_tool if defined?(@enable_native_tool)
|
||||||
|
|
||||||
|
@enable_native_tool = llm_model.lookup_custom_param("enable_native_tool")
|
||||||
end
|
end
|
||||||
|
|
||||||
def user_msg(msg)
|
def user_msg(msg)
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Completions
|
||||||
|
module Dialects
|
||||||
|
# TODO: Define the Tool class to be inherited by all tools.
|
||||||
|
class OllamaTools
|
||||||
|
def initialize(tools)
|
||||||
|
@raw_tools = tools
|
||||||
|
end
|
||||||
|
|
||||||
|
def instructions
|
||||||
|
"" # Noop. Tools are listed separate.
|
||||||
|
end
|
||||||
|
|
||||||
|
def translated_tools
|
||||||
|
raw_tools.map do |t|
|
||||||
|
tool = t.dup
|
||||||
|
|
||||||
|
tool[:parameters] = t[:parameters]
|
||||||
|
.to_a
|
||||||
|
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
|
||||||
|
name = p[:name]
|
||||||
|
memo[:required] << name if p[:required]
|
||||||
|
|
||||||
|
except = %i[name required item_type]
|
||||||
|
except << :enum if p[:enum].blank?
|
||||||
|
|
||||||
|
memo[:properties][name] = p.except(*except)
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
|
||||||
|
{ type: "function", function: tool }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def from_raw_tool_call(raw_message)
|
||||||
|
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
|
||||||
|
call_details[:name] = raw_message[:name]
|
||||||
|
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: nil,
|
||||||
|
tool_calls: [{ type: "function", function: call_details }],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def from_raw_tool(raw_message)
|
||||||
|
{ role: "tool", content: raw_message[:content], name: raw_message[:name] }
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
attr_reader :raw_tools
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -37,11 +37,28 @@ module DiscourseAi
|
||||||
URI(llm_model.url)
|
URI(llm_model.url)
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_payload(prompt, model_params, _dialect)
|
def native_tool_support?
|
||||||
|
@native_tool_support
|
||||||
|
end
|
||||||
|
|
||||||
|
def has_tool?(_response_data)
|
||||||
|
@has_function_call
|
||||||
|
end
|
||||||
|
|
||||||
|
def prepare_payload(prompt, model_params, dialect)
|
||||||
|
@native_tool_support = dialect.native_tool_support?
|
||||||
|
|
||||||
|
# https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1
|
||||||
|
# Due to ollama enforce a 'stream: false' for tool calls, instead of complicating the code,
|
||||||
|
# we will just disable streaming for all ollama calls if native tool support is enabled
|
||||||
|
|
||||||
default_options
|
default_options
|
||||||
.merge(model_params)
|
.merge(model_params)
|
||||||
.merge(messages: prompt)
|
.merge(messages: prompt)
|
||||||
.tap { |payload| payload[:stream] = false if !@streaming_mode }
|
.tap { |payload| payload[:stream] = false if @native_tool_support || !@streaming_mode }
|
||||||
|
.tap do |payload|
|
||||||
|
payload[:tools] = dialect.tools if @native_tool_support && dialect.tools.present?
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_request(payload)
|
def prepare_request(payload)
|
||||||
|
@ -58,7 +75,66 @@ module DiscourseAi
|
||||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||||
return if !parsed
|
return if !parsed
|
||||||
|
|
||||||
parsed.dig(:message, :content)
|
response_h = parsed.dig(:message)
|
||||||
|
|
||||||
|
@has_function_call ||= response_h.dig(:tool_calls).present?
|
||||||
|
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
|
||||||
|
@args_buffer ||= +""
|
||||||
|
|
||||||
|
if @streaming_mode
|
||||||
|
return function_buffer if !partial
|
||||||
|
else
|
||||||
|
partial = payload
|
||||||
|
end
|
||||||
|
|
||||||
|
f_name = partial.dig(:function, :name)
|
||||||
|
|
||||||
|
@current_function ||= function_buffer.at("invoke")
|
||||||
|
|
||||||
|
if f_name
|
||||||
|
current_name = function_buffer.at("tool_name").content
|
||||||
|
|
||||||
|
if current_name.blank?
|
||||||
|
# first call
|
||||||
|
else
|
||||||
|
# we have a previous function, so we need to add a noop
|
||||||
|
@args_buffer = +""
|
||||||
|
@current_function =
|
||||||
|
function_buffer.at("function_calls").add_child(
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@current_function.at("tool_name").content = f_name if f_name
|
||||||
|
@current_function.at("tool_id").content = partial[:id] if partial[:id]
|
||||||
|
|
||||||
|
args = partial.dig(:function, :arguments)
|
||||||
|
|
||||||
|
# allow for SPACE within arguments
|
||||||
|
if args && args != ""
|
||||||
|
@args_buffer << args.to_json
|
||||||
|
|
||||||
|
begin
|
||||||
|
json_args = JSON.parse(@args_buffer, symbolize_names: true)
|
||||||
|
|
||||||
|
argument_fragments =
|
||||||
|
json_args.reduce(+"") do |memo, (arg_name, value)|
|
||||||
|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
|
||||||
|
end
|
||||||
|
argument_fragments << "\n"
|
||||||
|
|
||||||
|
@current_function.at("parameters").children =
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||||
|
rescue JSON::ParserError
|
||||||
|
return function_buffer
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function_buffer
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -87,4 +87,5 @@ Fabricator(:ollama_model, from: :llm_model) do
|
||||||
api_key "ABC"
|
api_key "ABC"
|
||||||
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
|
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
|
||||||
url "http://api.ollama.ai/api/chat"
|
url "http://api.ollama.ai/api/chat"
|
||||||
|
provider_params { { enable_native_tool: true } }
|
||||||
end
|
end
|
||||||
|
|
|
@ -7,6 +7,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
|
||||||
let(:context) { DialectContext.new(described_class, model) }
|
let(:context) { DialectContext.new(described_class, model) }
|
||||||
|
|
||||||
describe "#translate" do
|
describe "#translate" do
|
||||||
|
context "when native tool support is enabled" do
|
||||||
it "translates a prompt written in our generic format to the Ollama format" do
|
it "translates a prompt written in our generic format to the Ollama format" do
|
||||||
ollama_version = [
|
ollama_version = [
|
||||||
{ role: "system", content: context.system_insts },
|
{ role: "system", content: context.system_insts },
|
||||||
|
@ -17,6 +18,27 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
|
||||||
|
|
||||||
expect(translated).to eq(ollama_version)
|
expect(translated).to eq(ollama_version)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when native tool support is disabled - XML tools" do
|
||||||
|
it "includes the instructions in the system message" do
|
||||||
|
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
|
||||||
|
|
||||||
|
DiscourseAi::Completions::Dialects::XmlTools
|
||||||
|
.any_instance
|
||||||
|
.stubs(:instructions)
|
||||||
|
.returns("Instructions")
|
||||||
|
|
||||||
|
ollama_version = [
|
||||||
|
{ role: "system", content: "#{context.system_insts}\n\nInstructions" },
|
||||||
|
{ role: "user", content: context.simple_user_input },
|
||||||
|
]
|
||||||
|
|
||||||
|
translated = context.system_user_scenario
|
||||||
|
|
||||||
|
expect(translated).to eq(ollama_version)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
it "trims content if it's getting too long" do
|
it "trims content if it's getting too long" do
|
||||||
model.max_prompt_tokens = 5000
|
model.max_prompt_tokens = 5000
|
||||||
|
@ -33,4 +55,40 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
|
||||||
expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
|
expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "#tools" do
|
||||||
|
context "when native tools are enabled" do
|
||||||
|
it "returns the translated tools from the OllamaTools class" do
|
||||||
|
tool = instance_double(DiscourseAi::Completions::Dialects::OllamaTools)
|
||||||
|
|
||||||
|
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(true)
|
||||||
|
allow(tool).to receive(:translated_tools)
|
||||||
|
allow(DiscourseAi::Completions::Dialects::OllamaTools).to receive(:new).and_return(tool)
|
||||||
|
|
||||||
|
context.dialect_tools
|
||||||
|
|
||||||
|
expect(DiscourseAi::Completions::Dialects::OllamaTools).to have_received(:new).with(
|
||||||
|
context.prompt.tools,
|
||||||
|
)
|
||||||
|
expect(tool).to have_received(:translated_tools)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when native tools are disabled" do
|
||||||
|
it "returns the translated tools from the XmlTools class" do
|
||||||
|
tool = instance_double(DiscourseAi::Completions::Dialects::XmlTools)
|
||||||
|
|
||||||
|
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
|
||||||
|
allow(tool).to receive(:translated_tools)
|
||||||
|
allow(DiscourseAi::Completions::Dialects::XmlTools).to receive(:new).and_return(tool)
|
||||||
|
|
||||||
|
context.dialect_tools
|
||||||
|
|
||||||
|
expect(DiscourseAi::Completions::Dialects::XmlTools).to have_received(:new).with(
|
||||||
|
context.prompt.tools,
|
||||||
|
)
|
||||||
|
expect(tool).to have_received(:translated_tools)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "dialect_context"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Completions::Dialects::OllamaTools do
|
||||||
|
describe "#translated_tools" do
|
||||||
|
it "translates a tool from our generic format to the Ollama format" do
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
name: "github_file_content",
|
||||||
|
description: "Retrieves the content of specified GitHub files",
|
||||||
|
parameters: [
|
||||||
|
{
|
||||||
|
name: "repo_name",
|
||||||
|
description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
|
||||||
|
type: "string",
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file_paths",
|
||||||
|
description: "The paths of the files to retrieve within the repository",
|
||||||
|
type: "array",
|
||||||
|
item_type: "string",
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "branch",
|
||||||
|
description: "The branch or commit SHA to retrieve the files from (default: 'main')",
|
||||||
|
type: "string",
|
||||||
|
required: false,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
ollama_tools = described_class.new(tools)
|
||||||
|
|
||||||
|
translated_tools = ollama_tools.translated_tools
|
||||||
|
|
||||||
|
expect(translated_tools).to eq(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "github_file_content",
|
||||||
|
description: "Retrieves the content of specified GitHub files",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
"repo_name" => {
|
||||||
|
description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
"file_paths" => {
|
||||||
|
description: "The paths of the files to retrieve within the repository",
|
||||||
|
type: "array",
|
||||||
|
},
|
||||||
|
"branch" => {
|
||||||
|
description:
|
||||||
|
"The branch or commit SHA to retrieve the files from (default: 'main')",
|
||||||
|
type: "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: %w[repo_name file_paths],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#from_raw_tool_call" do
|
||||||
|
it "converts a raw tool call to the Ollama tool format" do
|
||||||
|
raw_message = {
|
||||||
|
content: '{"repo_name":"discourse/discourse","file_paths":["README.md"],"branch":"main"}',
|
||||||
|
}
|
||||||
|
|
||||||
|
ollama_tools = described_class.new([])
|
||||||
|
tool_call = ollama_tools.from_raw_tool_call(raw_message)
|
||||||
|
|
||||||
|
expect(tool_call).to eq(
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: nil,
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
repo_name: "discourse/discourse",
|
||||||
|
file_paths: ["README.md"],
|
||||||
|
branch: "main",
|
||||||
|
name: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#from_raw_tool" do
|
||||||
|
it "converts a raw tool to the Ollama tool format" do
|
||||||
|
raw_message = { content: "Hello, world!", name: "github_file_content" }
|
||||||
|
|
||||||
|
ollama_tools = described_class.new([])
|
||||||
|
tool = ollama_tools.from_raw_tool(raw_message)
|
||||||
|
|
||||||
|
expect(tool).to eq({ role: "tool", content: "Hello, world!", name: "github_file_content" })
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -3,8 +3,13 @@
|
||||||
require_relative "endpoint_compliance"
|
require_relative "endpoint_compliance"
|
||||||
|
|
||||||
class OllamaMock < EndpointMock
|
class OllamaMock < EndpointMock
|
||||||
def response(content)
|
def response(content, tool_call: false)
|
||||||
message_content = { content: content }
|
message_content =
|
||||||
|
if tool_call
|
||||||
|
{ content: "", tool_calls: [content] }
|
||||||
|
else
|
||||||
|
{ content: content }
|
||||||
|
end
|
||||||
|
|
||||||
{
|
{
|
||||||
created_at: "2024-09-25T06:47:21.283028Z",
|
created_at: "2024-09-25T06:47:21.283028Z",
|
||||||
|
@ -21,11 +26,11 @@ class OllamaMock < EndpointMock
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_response(prompt, response_text)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "http://api.ollama.ai/api/chat")
|
.stub_request(:post, "http://api.ollama.ai/api/chat")
|
||||||
.with(body: request_body(prompt))
|
.with(body: request_body(prompt, tool_call: tool_call))
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
||||||
end
|
end
|
||||||
|
|
||||||
def stream_line(delta)
|
def stream_line(delta)
|
||||||
|
@ -71,14 +76,50 @@ class OllamaMock < EndpointMock
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "http://api.ollama.ai/api/chat")
|
.stub_request(:post, "http://api.ollama.ai/api/chat")
|
||||||
.with(body: request_body(prompt, stream: true))
|
.with(body: request_body(prompt))
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
|
|
||||||
yield if block_given?
|
yield if block_given?
|
||||||
end
|
end
|
||||||
|
|
||||||
def request_body(prompt, stream: false)
|
def tool_response
|
||||||
model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json
|
{ function: { name: "get_weather", arguments: { location: "Sydney", unit: "c" } } }
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_payload
|
||||||
|
{
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
location: {
|
||||||
|
type: "string",
|
||||||
|
description: "the city name",
|
||||||
|
},
|
||||||
|
unit: {
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: %w[location unit],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def request_body(prompt, tool_call: false)
|
||||||
|
model
|
||||||
|
.default_options
|
||||||
|
.merge(messages: prompt)
|
||||||
|
.tap do |b|
|
||||||
|
b[:stream] = false
|
||||||
|
b[:tools] = [tool_payload] if tool_call
|
||||||
|
end
|
||||||
|
.to_json
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -100,6 +141,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
|
||||||
compliance.regular_mode_simple_prompt(ollama_mock)
|
compliance.regular_mode_simple_prompt(ollama_mock)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
context "with tools" do
|
||||||
|
it "returns a function invocation" do
|
||||||
|
compliance.regular_mode_tools(ollama_mock)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "when using streaming mode" do
|
describe "when using streaming mode" do
|
||||||
|
|
Loading…
Reference in New Issue