FEATURE: Tools for models from Ollama provider (#819)
Adds support for Ollama function calling
This commit is contained in:
parent
6c4c96e83c
commit
94010a5f78
|
@ -31,6 +31,7 @@ class LlmModel < ActiveRecord::Base
|
|||
},
|
||||
ollama: {
|
||||
disable_system_prompt: :checkbox,
|
||||
enable_native_tool: :checkbox,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
|
|
@ -312,6 +312,7 @@ en:
|
|||
region: "AWS Bedrock Region"
|
||||
organization: "Optional OpenAI Organization ID"
|
||||
disable_system_prompt: "Disable system message in prompts"
|
||||
enable_native_tool: "Enable native tool support"
|
||||
|
||||
related_topics:
|
||||
title: "Related Topics"
|
||||
|
|
|
@ -10,7 +10,9 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
# TODO: Add tool support
|
||||
# Reports whether this dialect uses native tool calling.
# Pure delegation to enable_native_tool?.
def native_tool_support?
|
||||
enable_native_tool?
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
llm_model.max_prompt_tokens
|
||||
|
@ -18,6 +20,14 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
# Chooses the object that renders tools for this prompt: the parent
# dialect's choice when native tools are off, otherwise a memoized
# OllamaTools built from the prompt's tools.
def tools_dialect
  return super unless enable_native_tool?

  @tools_dialect ||= DiscourseAi::Completions::Dialects::OllamaTools.new(prompt.tools)
end
|
||||
|
||||
# Returns whatever llm_model.tokenizer_class yields for the configured model.
def tokenizer
|
||||
llm_model.tokenizer_class
|
||||
end
|
||||
|
@ -26,8 +36,28 @@ module DiscourseAi
|
|||
{ role: "assistant", content: msg[:content] }
|
||||
end
|
||||
|
||||
# Translates a tool-call message by handing it to
# tools_dialect.from_raw_tool_call and returning the result unchanged.
def tool_call_msg(msg)
|
||||
tools_dialect.from_raw_tool_call(msg)
|
||||
end
|
||||
|
||||
# Translates a tool-result message by handing it to
# tools_dialect.from_raw_tool and returning the result unchanged.
def tool_msg(msg)
|
||||
tools_dialect.from_raw_tool(msg)
|
||||
end
|
||||
|
||||
# Builds the system message for the translated prompt.
#
# When the active tools dialect supplies instructions, they are appended to
# the system text after a blank line. The caller's message hash and its
# content string are left untouched.
#
# @param msg [Hash] generic message; only msg[:content] is read
# @return [Hash] { role: "system", content: String }
def system_msg(msg)
  content = msg[:content]

  # Read the instructions once; they are only appended when non-blank.
  instructions = tools_dialect.instructions
  content = "#{content}\n\n#{instructions}" if instructions.present?

  { role: "system", content: content }
end
|
||||
|
||||
# Memoized read of the model's "enable_native_tool" custom parameter.
# The presence of the instance variable (not its truthiness) decides whether
# the lookup has already happened, so a false/nil answer is cached too.
def enable_native_tool?
  unless instance_variable_defined?(:@enable_native_tool)
    @enable_native_tool = llm_model.lookup_custom_param("enable_native_tool")
  end

  @enable_native_tool
end
|
||||
|
||||
def user_msg(msg)
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
# TODO: Define the Tool class to be inherited by all tools.
|
||||
# Converts the generic tool definitions and tool messages used by prompts
# into the hashes placed in an Ollama chat request.
class OllamaTools
  # @param tools [Array<Hash>, nil] generic tool definitions. Each tool may
  #   carry a :parameters array of hashes (:name, :required, :enum, ...).
  def initialize(tools)
    @raw_tools = tools
  end

  # No-op: tools are listed separately in the payload, so no extra
  # instructions are added to the prompt.
  #
  # @return [String] always an empty string
  def instructions
    ""
  end

  # Wraps every tool as { type: "function", function: tool }, replacing its
  # :parameters list with an object schema. The input hashes are not
  # modified. A nil tool list is treated like an empty one, matching how a
  # nil :parameters list is already handled.
  #
  # @return [Array<Hash>]
  def translated_tools
    raw_tools.to_a.map do |raw_tool|
      tool = raw_tool.dup
      tool[:parameters] = parameters_schema(raw_tool[:parameters])

      { type: "function", function: tool }
    end
  end

  # Builds the assistant message that records a tool invocation.
  # The JSON in raw_message[:content] is parsed with symbolized keys and
  # becomes the function hash, with :name set from raw_message[:name].
  #
  # @param raw_message [Hash] reads :content (JSON string) and :name
  # @return [Hash]
  # @raise [JSON::ParserError] when :content is not valid JSON
  def from_raw_tool_call(raw_message)
    call_details = JSON.parse(raw_message[:content], symbolize_names: true)
    call_details[:name] = raw_message[:name]

    {
      role: "assistant",
      content: nil,
      tool_calls: [{ type: "function", function: call_details }],
    }
  end

  # Builds the "tool" role message carrying a tool's result.
  #
  # @param raw_message [Hash] reads :content and :name
  # @return [Hash]
  def from_raw_tool(raw_message)
    { role: "tool", content: raw_message[:content], name: raw_message[:name] }
  end

  private

  attr_reader :raw_tools

  # Turns a list of parameter hashes into
  # { type: "object", properties: {...}, required: [...] }.
  # Each property keeps the parameter's keys except :name, :required and
  # :item_type; :enum is dropped as well when it is blank.
  def parameters_schema(parameters)
    schema = { type: "object", properties: {}, required: [] }

    parameters.to_a.each do |param|
      name = param[:name]
      schema[:required] << name if param[:required]

      dropped_keys = %i[name required item_type]
      dropped_keys << :enum if param[:enum].blank?

      schema[:properties][name] = param.except(*dropped_keys)
    end

    schema
  end
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -37,11 +37,28 @@ module DiscourseAi
|
|||
URI(llm_model.url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
# Returns the flag that prepare_payload stores from
# dialect.native_tool_support?. Reads as nil until prepare_payload has run.
def native_tool_support?
|
||||
@native_tool_support
|
||||
end
|
||||
|
||||
# Reports whether a tool call has been detected in the response.
# The argument is ignored; the answer is the @has_function_call flag, which
# the response-parsing code sets when the message carries tool_calls.
def has_tool?(_response_data)
|
||||
@has_function_call
|
||||
end
|
||||
|
||||
# Builds the request body for Ollama's chat endpoint.
#
# Also records on the instance whether the dialect uses native tools, so
# native_tool_support? can answer later.
#
# @param prompt [Array<Hash>] translated messages, sent as :messages
# @param model_params [Hash] per-request overrides merged over default_options
# @param dialect [#native_tool_support?, #tools] dialect that produced the prompt
# @return [Hash] a new hash; default_options and model_params are not modified
def prepare_payload(prompt, model_params, dialect)
  @native_tool_support = dialect.native_tool_support?

  payload = default_options.merge(model_params).merge(messages: prompt)

  # https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1
  # Ollama enforces `stream: false` for tool calls. Rather than complicating
  # the code, streaming is disabled for every Ollama call whenever native
  # tool support is enabled.
  payload[:stream] = false if @native_tool_support || !@streaming_mode

  if @native_tool_support
    # Ask the dialect once; only attach a non-empty tool list.
    tools = dialect.tools
    payload[:tools] = tools if tools.present?
  end

  payload
end
|
||||
|
||||
def prepare_request(payload)
|
||||
|
@ -58,7 +75,66 @@ module DiscourseAi
|
|||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
return if !parsed
|
||||
|
||||
parsed.dig(:message, :content)
|
||||
response_h = parsed.dig(:message)
|
||||
|
||||
@has_function_call ||= response_h.dig(:tool_calls).present?
|
||||
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
|
||||
end
|
||||
|
||||
# Copies a tool call from an Ollama response into function_buffer.
#
# In streaming mode the call is read from `partial` (nothing happens when it
# is missing); otherwise it is read from `payload`. The function name and id
# are written into the current invoke node, and the call's arguments are
# rendered as one <name>value</name> element per argument under "parameters".
#
# NOTE(review): function_buffer is presumably a Nokogiri fragment (it is
# queried with #at and extended with #add_child) - confirm with the caller.
#
# @return function_buffer, always
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
  @args_buffer ||= +""

  return function_buffer if @streaming_mode && !partial
  partial = payload unless @streaming_mode

  function_name = partial.dig(:function, :name)

  @current_function ||= function_buffer.at("invoke")

  if function_name
    previous_name = function_buffer.at("tool_name").content

    # A name is already recorded, so this is a further call: start a fresh
    # argument buffer and append a new empty invocation to fill in.
    unless previous_name.blank?
      @args_buffer = +""
      @current_function =
        function_buffer.at("function_calls").add_child(
          Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
        )
    end

    @current_function.at("tool_name").content = function_name
  end

  call_id = partial[:id]
  @current_function.at("tool_id").content = call_id if call_id

  raw_args = partial.dig(:function, :arguments)

  # allow for SPACE within arguments
  return function_buffer unless raw_args && raw_args != ""

  @args_buffer << raw_args.to_json

  begin
    parsed_args = JSON.parse(@args_buffer, symbolize_names: true)

    xml_arguments =
      parsed_args.each_with_object(+"") do |(arg_name, value), xml|
        xml << "\n<#{arg_name}>#{value}</#{arg_name}>"
      end
    xml_arguments << "\n"

    @current_function.at("parameters").children =
      Nokogiri::HTML5::DocumentFragment.parse(xml_arguments)
  rescue JSON::ParserError
    # The buffer does not hold a complete JSON document; leave the
    # parameters as they are.
  end

  function_buffer
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -87,4 +87,5 @@ Fabricator(:ollama_model, from: :llm_model) do
|
|||
api_key "ABC"
|
||||
tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
|
||||
url "http://api.ollama.ai/api/chat"
|
||||
provider_params { { enable_native_tool: true } }
|
||||
end
|
||||
|
|
|
@ -7,6 +7,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
|
|||
let(:context) { DialectContext.new(described_class, model) }
|
||||
|
||||
describe "#translate" do
|
||||
context "when native tool support is enabled" do
|
||||
it "translates a prompt written in our generic format to the Ollama format" do
|
||||
ollama_version = [
|
||||
{ role: "system", content: context.system_insts },
|
||||
|
@ -17,6 +18,27 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
|
|||
|
||||
expect(translated).to eq(ollama_version)
|
||||
end
|
||||
end
|
||||
|
||||
# With native tools switched off, the XML tool instructions must be appended
# to the system message.
context "when native tool support is disabled - XML tools" do
  it "includes the instructions in the system message" do
    xml_tools = DiscourseAi::Completions::Dialects::XmlTools
    xml_tools.any_instance.stubs(:instructions).returns("Instructions")
    allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)

    expected_messages = [
      { role: "system", content: "#{context.system_insts}\n\nInstructions" },
      { role: "user", content: context.simple_user_input },
    ]

    expect(context.system_user_scenario).to eq(expected_messages)
  end
end
|
||||
|
||||
it "trims content if it's getting too long" do
|
||||
model.max_prompt_tokens = 5000
|
||||
|
@ -33,4 +55,40 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
|
|||
expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
|
||||
end
|
||||
end
|
||||
|
||||
# The dialect must build its tool translator from the prompt's tools, using
# OllamaTools when native tools are enabled and XmlTools otherwise.
describe "#tools" do
  context "when native tools are enabled" do
    it "returns the translated tools from the OllamaTools class" do
      translator_class = DiscourseAi::Completions::Dialects::OllamaTools
      translator = instance_double(translator_class)

      allow(translator).to receive(:translated_tools)
      allow(translator_class).to receive(:new).and_return(translator)
      allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(true)

      context.dialect_tools

      expect(translator_class).to have_received(:new).with(context.prompt.tools)
      expect(translator).to have_received(:translated_tools)
    end
  end

  context "when native tools are disabled" do
    it "returns the translated tools from the XmlTools class" do
      translator_class = DiscourseAi::Completions::Dialects::XmlTools
      translator = instance_double(translator_class)

      allow(translator).to receive(:translated_tools)
      allow(translator_class).to receive(:new).and_return(translator)
      allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)

      context.dialect_tools

      expect(translator_class).to have_received(:new).with(context.prompt.tools)
      expect(translator).to have_received(:translated_tools)
    end
  end
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "dialect_context"
|
||||
|
||||
# Unit specs for the translator that turns generic tools and tool messages
# into the hashes sent to Ollama.
RSpec.describe DiscourseAi::Completions::Dialects::OllamaTools do
  describe "#translated_tools" do
    it "translates a tool from our generic format to the Ollama format" do
      tools = [
        {
          name: "github_file_content",
          description: "Retrieves the content of specified GitHub files",
          parameters: [
            {
              name: "repo_name",
              description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
              type: "string",
              required: true,
            },
            {
              name: "file_paths",
              description: "The paths of the files to retrieve within the repository",
              type: "array",
              item_type: "string",
              required: true,
            },
            {
              name: "branch",
              description: "The branch or commit SHA to retrieve the files from (default: 'main')",
              type: "string",
              required: false,
            },
          ],
        },
      ]

      ollama_tools = described_class.new(tools)

      translated_tools = ollama_tools.translated_tools

      expect(translated_tools).to eq(
        [
          {
            type: "function",
            function: {
              name: "github_file_content",
              description: "Retrieves the content of specified GitHub files",
              parameters: {
                type: "object",
                properties: {
                  "repo_name" => {
                    description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
                    type: "string",
                  },
                  "file_paths" => {
                    description: "The paths of the files to retrieve within the repository",
                    type: "array",
                  },
                  "branch" => {
                    description:
                      "The branch or commit SHA to retrieve the files from (default: 'main')",
                    type: "string",
                  },
                },
                required: %w[repo_name file_paths],
              },
            },
          },
        ],
      )
    end
  end

  describe "#from_raw_tool_call" do
    it "converts a raw tool call to the Ollama tool format" do
      # The message carries both the call's JSON payload and the tool name, so
      # the example checks that the name is copied into the function hash.
      raw_message = {
        content:
          '{"arguments":{"repo_name":"discourse/discourse","file_paths":["README.md"],"branch":"main"}}',
        name: "github_file_content",
      }

      ollama_tools = described_class.new([])
      tool_call = ollama_tools.from_raw_tool_call(raw_message)

      expect(tool_call).to eq(
        {
          role: "assistant",
          content: nil,
          tool_calls: [
            {
              type: "function",
              function: {
                arguments: {
                  repo_name: "discourse/discourse",
                  file_paths: ["README.md"],
                  branch: "main",
                },
                name: "github_file_content",
              },
            },
          ],
        },
      )
    end
  end

  describe "#from_raw_tool" do
    it "converts a raw tool to the Ollama tool format" do
      raw_message = { content: "Hello, world!", name: "github_file_content" }

      ollama_tools = described_class.new([])
      tool = ollama_tools.from_raw_tool(raw_message)

      expect(tool).to eq({ role: "tool", content: "Hello, world!", name: "github_file_content" })
    end
  end
end
|
|
@ -3,8 +3,13 @@
|
|||
require_relative "endpoint_compliance"
|
||||
|
||||
class OllamaMock < EndpointMock
|
||||
def response(content)
|
||||
message_content = { content: content }
|
||||
def response(content, tool_call: false)
|
||||
message_content =
|
||||
if tool_call
|
||||
{ content: "", tool_calls: [content] }
|
||||
else
|
||||
{ content: content }
|
||||
end
|
||||
|
||||
{
|
||||
created_at: "2024-09-25T06:47:21.283028Z",
|
||||
|
@ -21,11 +26,11 @@ class OllamaMock < EndpointMock
|
|||
}
|
||||
end
|
||||
|
||||
# Registers a WebMock stub for a non-streaming chat request.
#
# The stub matches a POST whose body equals request_body(prompt, tool_call:)
# and answers 200 with the JSON dump of response(response_text, tool_call:).
#
# @param prompt [Array<Hash>] messages expected in the request body
# @param response_text [Object] content (or tool call) placed in the reply
# @param tool_call [Boolean] stub a tool-calling exchange instead of plain text
def stub_response(prompt, response_text, tool_call: false)
  WebMock
    .stub_request(:post, "http://api.ollama.ai/api/chat")
    .with(body: request_body(prompt, tool_call: tool_call))
    .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
|
||||
|
||||
def stream_line(delta)
|
||||
|
@ -71,14 +76,50 @@ class OllamaMock < EndpointMock
|
|||
|
||||
WebMock
|
||||
.stub_request(:post, "http://api.ollama.ai/api/chat")
|
||||
.with(body: request_body(prompt, stream: true))
|
||||
.with(body: request_body(prompt))
|
||||
.to_return(status: 200, body: chunks)
|
||||
|
||||
yield if block_given?
|
||||
end
|
||||
|
||||
def request_body(prompt, stream: false)
|
||||
model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json
|
||||
# Sample tool call: get_weather with location "Sydney" and unit "c".
# NOTE(review): presumably consumed by the shared compliance helper as the
# mocked model output for tool scenarios - confirm against that helper.
def tool_response
|
||||
{ function: { name: "get_weather", arguments: { location: "Sydney", unit: "c" } } }
|
||||
end
|
||||
|
||||
# The get_weather tool definition that request_body places in :tools when a
# tool call is being stubbed. Keys are kept in the same order as before so
# the JSON produced by request_body does not change.
def tool_payload
  location = { type: "string", description: "the city name" }
  unit = {
    type: "string",
    description: "the unit of measurement celcius c or fahrenheit f",
    enum: %w[c f],
  }
  parameters = {
    type: "object",
    properties: {
      location: location,
      unit: unit,
    },
    required: %w[location unit],
  }

  {
    type: "function",
    function: {
      name: "get_weather",
      description: "Get the weather in a city",
      parameters: parameters,
    },
  }
end
|
||||
|
||||
# JSON body the mock expects Ollama requests to carry: the model's default
# options plus the messages, always with stream: false, and with the
# get_weather tool attached when tool_call is true.
def request_body(prompt, tool_call: false)
  body = model.default_options.merge(messages: prompt)
  body[:stream] = false
  body[:tools] = [tool_payload] if tool_call

  body.to_json
end
|
||||
end
|
||||
|
||||
|
@ -100,6 +141,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
|
|||
compliance.regular_mode_simple_prompt(ollama_mock)
|
||||
end
|
||||
end
|
||||
|
||||
# Runs the shared regular_mode_tools compliance check against the Ollama mock.
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(ollama_mock)
|
||||
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
|
|
Loading…
Reference in New Issue