FEATURE: Tools for models from Ollama provider (#819)

Adds support for Ollama function calling
This commit is contained in:
Hoa Nguyen 2024-10-11 07:25:53 +11:00 committed by GitHub
parent 6c4c96e83c
commit 94010a5f78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 404 additions and 20 deletions

View File

@ -31,6 +31,7 @@ class LlmModel < ActiveRecord::Base
}, },
ollama: { ollama: {
disable_system_prompt: :checkbox, disable_system_prompt: :checkbox,
enable_native_tool: :checkbox,
}, },
} }
end end

View File

@ -312,6 +312,7 @@ en:
region: "AWS Bedrock Region" region: "AWS Bedrock Region"
organization: "Optional OpenAI Organization ID" organization: "Optional OpenAI Organization ID"
disable_system_prompt: "Disable system message in prompts" disable_system_prompt: "Disable system message in prompts"
enable_native_tool: "Enable native tool support"
related_topics: related_topics:
title: "Related Topics" title: "Related Topics"

View File

@ -10,7 +10,9 @@ module DiscourseAi
end end
end end
# TODO: Add tool suppport def native_tool_support?
enable_native_tool?
end
def max_prompt_tokens def max_prompt_tokens
llm_model.max_prompt_tokens llm_model.max_prompt_tokens
@ -18,6 +20,14 @@ module DiscourseAi
private private
def tools_dialect
if enable_native_tool?
@tools_dialect ||= DiscourseAi::Completions::Dialects::OllamaTools.new(prompt.tools)
else
super
end
end
def tokenizer def tokenizer
llm_model.tokenizer_class llm_model.tokenizer_class
end end
@ -26,8 +36,28 @@ module DiscourseAi
{ role: "assistant", content: msg[:content] } { role: "assistant", content: msg[:content] }
end end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
end
def system_msg(msg) def system_msg(msg)
{ role: "system", content: msg[:content] } msg = { role: "system", content: msg[:content] }
if tools_dialect.instructions.present?
msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
end
msg
end
def enable_native_tool?
return @enable_native_tool if defined?(@enable_native_tool)
@enable_native_tool = llm_model.lookup_custom_param("enable_native_tool")
end end
def user_msg(msg) def user_msg(msg)

View File

@ -0,0 +1,58 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
# TODO: Define the Tool class to be inherited by all tools.
class OllamaTools
def initialize(tools)
@raw_tools = tools
end
def instructions
"" # Noop. Tools are listed separate.
end
def translated_tools
raw_tools.map do |t|
tool = t.dup
tool[:parameters] = t[:parameters]
.to_a
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
except = %i[name required item_type]
except << :enum if p[:enum].blank?
memo[:properties][name] = p.except(*except)
memo
end
{ type: "function", function: tool }
end
end
def from_raw_tool_call(raw_message)
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
call_details[:name] = raw_message[:name]
{
role: "assistant",
content: nil,
tool_calls: [{ type: "function", function: call_details }],
}
end
def from_raw_tool(raw_message)
{ role: "tool", content: raw_message[:content], name: raw_message[:name] }
end
private
attr_reader :raw_tools
end
end
end
end

View File

@ -37,11 +37,28 @@ module DiscourseAi
URI(llm_model.url) URI(llm_model.url)
end end
def prepare_payload(prompt, model_params, _dialect) def native_tool_support?
@native_tool_support
end
def has_tool?(_response_data)
@has_function_call
end
def prepare_payload(prompt, model_params, dialect)
@native_tool_support = dialect.native_tool_support?
# https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1
# Due to ollama enforce a 'stream: false' for tool calls, instead of complicating the code,
# we will just disable streaming for all ollama calls if native tool support is enabled
default_options default_options
.merge(model_params) .merge(model_params)
.merge(messages: prompt) .merge(messages: prompt)
.tap { |payload| payload[:stream] = false if !@streaming_mode } .tap { |payload| payload[:stream] = false if @native_tool_support || !@streaming_mode }
.tap do |payload|
payload[:tools] = dialect.tools if @native_tool_support && dialect.tools.present?
end
end end
def prepare_request(payload) def prepare_request(payload)
@ -58,7 +75,66 @@ module DiscourseAi
parsed = JSON.parse(response_raw, symbolize_names: true) parsed = JSON.parse(response_raw, symbolize_names: true)
return if !parsed return if !parsed
parsed.dig(:message, :content) response_h = parsed.dig(:message)
@has_function_call ||= response_h.dig(:tool_calls).present?
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
end
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
@args_buffer ||= +""
if @streaming_mode
return function_buffer if !partial
else
partial = payload
end
f_name = partial.dig(:function, :name)
@current_function ||= function_buffer.at("invoke")
if f_name
current_name = function_buffer.at("tool_name").content
if current_name.blank?
# first call
else
# we have a previous function, so we need to add a noop
@args_buffer = +""
@current_function =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
)
end
end
@current_function.at("tool_name").content = f_name if f_name
@current_function.at("tool_id").content = partial[:id] if partial[:id]
args = partial.dig(:function, :arguments)
# allow for SPACE within arguments
if args && args != ""
@args_buffer << args.to_json
begin
json_args = JSON.parse(@args_buffer, symbolize_names: true)
argument_fragments =
json_args.reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
end
argument_fragments << "\n"
@current_function.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
rescue JSON::ParserError
return function_buffer
end
end
function_buffer
end end
end end
end end

View File

@ -87,4 +87,5 @@ Fabricator(:ollama_model, from: :llm_model) do
api_key "ABC" api_key "ABC"
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer" tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
url "http://api.ollama.ai/api/chat" url "http://api.ollama.ai/api/chat"
provider_params { { enable_native_tool: true } }
end end

View File

@ -7,6 +7,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
let(:context) { DialectContext.new(described_class, model) } let(:context) { DialectContext.new(described_class, model) }
describe "#translate" do describe "#translate" do
context "when native tool support is enabled" do
it "translates a prompt written in our generic format to the Ollama format" do it "translates a prompt written in our generic format to the Ollama format" do
ollama_version = [ ollama_version = [
{ role: "system", content: context.system_insts }, { role: "system", content: context.system_insts },
@ -17,6 +18,27 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
expect(translated).to eq(ollama_version) expect(translated).to eq(ollama_version)
end end
end
context "when native tool support is disabled - XML tools" do
it "includes the instructions in the system message" do
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
DiscourseAi::Completions::Dialects::XmlTools
.any_instance
.stubs(:instructions)
.returns("Instructions")
ollama_version = [
{ role: "system", content: "#{context.system_insts}\n\nInstructions" },
{ role: "user", content: context.simple_user_input },
]
translated = context.system_user_scenario
expect(translated).to eq(ollama_version)
end
end
it "trims content if it's getting too long" do it "trims content if it's getting too long" do
model.max_prompt_tokens = 5000 model.max_prompt_tokens = 5000
@ -33,4 +55,40 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
expect(context.dialect(nil).max_prompt_tokens).to eq(10_000) expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
end end
end end
describe "#tools" do
context "when native tools are enabled" do
it "returns the translated tools from the OllamaTools class" do
tool = instance_double(DiscourseAi::Completions::Dialects::OllamaTools)
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(true)
allow(tool).to receive(:translated_tools)
allow(DiscourseAi::Completions::Dialects::OllamaTools).to receive(:new).and_return(tool)
context.dialect_tools
expect(DiscourseAi::Completions::Dialects::OllamaTools).to have_received(:new).with(
context.prompt.tools,
)
expect(tool).to have_received(:translated_tools)
end
end
context "when native tools are disabled" do
it "returns the translated tools from the XmlTools class" do
tool = instance_double(DiscourseAi::Completions::Dialects::XmlTools)
allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
allow(tool).to receive(:translated_tools)
allow(DiscourseAi::Completions::Dialects::XmlTools).to receive(:new).and_return(tool)
context.dialect_tools
expect(DiscourseAi::Completions::Dialects::XmlTools).to have_received(:new).with(
context.prompt.tools,
)
expect(tool).to have_received(:translated_tools)
end
end
end
end end

View File

@ -0,0 +1,112 @@
# frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::OllamaTools do
describe "#translated_tools" do
it "translates a tool from our generic format to the Ollama format" do
tools = [
{
name: "github_file_content",
description: "Retrieves the content of specified GitHub files",
parameters: [
{
name: "repo_name",
description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
type: "string",
required: true,
},
{
name: "file_paths",
description: "The paths of the files to retrieve within the repository",
type: "array",
item_type: "string",
required: true,
},
{
name: "branch",
description: "The branch or commit SHA to retrieve the files from (default: 'main')",
type: "string",
required: false,
},
],
},
]
ollama_tools = described_class.new(tools)
translated_tools = ollama_tools.translated_tools
expect(translated_tools).to eq(
[
{
type: "function",
function: {
name: "github_file_content",
description: "Retrieves the content of specified GitHub files",
parameters: {
type: "object",
properties: {
"repo_name" => {
description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
type: "string",
},
"file_paths" => {
description: "The paths of the files to retrieve within the repository",
type: "array",
},
"branch" => {
description:
"The branch or commit SHA to retrieve the files from (default: 'main')",
type: "string",
},
},
required: %w[repo_name file_paths],
},
},
},
],
)
end
end
describe "#from_raw_tool_call" do
it "converts a raw tool call to the Ollama tool format" do
raw_message = {
content: '{"repo_name":"discourse/discourse","file_paths":["README.md"],"branch":"main"}',
}
ollama_tools = described_class.new([])
tool_call = ollama_tools.from_raw_tool_call(raw_message)
expect(tool_call).to eq(
{
role: "assistant",
content: nil,
tool_calls: [
{
type: "function",
function: {
repo_name: "discourse/discourse",
file_paths: ["README.md"],
branch: "main",
name: nil,
},
},
],
},
)
end
end
describe "#from_raw_tool" do
it "converts a raw tool to the Ollama tool format" do
raw_message = { content: "Hello, world!", name: "github_file_content" }
ollama_tools = described_class.new([])
tool = ollama_tools.from_raw_tool(raw_message)
expect(tool).to eq({ role: "tool", content: "Hello, world!", name: "github_file_content" })
end
end
end

View File

@ -3,8 +3,13 @@
require_relative "endpoint_compliance" require_relative "endpoint_compliance"
class OllamaMock < EndpointMock class OllamaMock < EndpointMock
def response(content) def response(content, tool_call: false)
message_content = { content: content } message_content =
if tool_call
{ content: "", tool_calls: [content] }
else
{ content: content }
end
{ {
created_at: "2024-09-25T06:47:21.283028Z", created_at: "2024-09-25T06:47:21.283028Z",
@ -21,11 +26,11 @@ class OllamaMock < EndpointMock
} }
end end
def stub_response(prompt, response_text) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "http://api.ollama.ai/api/chat") .stub_request(:post, "http://api.ollama.ai/api/chat")
.with(body: request_body(prompt)) .with(body: request_body(prompt, tool_call: tool_call))
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end end
def stream_line(delta) def stream_line(delta)
@ -71,14 +76,50 @@ class OllamaMock < EndpointMock
WebMock WebMock
.stub_request(:post, "http://api.ollama.ai/api/chat") .stub_request(:post, "http://api.ollama.ai/api/chat")
.with(body: request_body(prompt, stream: true)) .with(body: request_body(prompt))
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
yield if block_given? yield if block_given?
end end
def request_body(prompt, stream: false) def tool_response
model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json { function: { name: "get_weather", arguments: { location: "Sydney", unit: "c" } } }
end
def tool_payload
{
type: "function",
function: {
name: "get_weather",
description: "Get the weather in a city",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "the city name",
},
unit: {
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
},
required: %w[location unit],
},
},
}
end
def request_body(prompt, tool_call: false)
model
.default_options
.merge(messages: prompt)
.tap do |b|
b[:stream] = false
b[:tools] = [tool_payload] if tool_call
end
.to_json
end end
end end
@ -100,6 +141,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
compliance.regular_mode_simple_prompt(ollama_mock) compliance.regular_mode_simple_prompt(ollama_mock)
end end
end end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(ollama_mock)
end
end
end end
describe "when using streaming mode" do describe "when using streaming mode" do