diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb
index c411fb14..79e065a6 100644
--- a/app/models/llm_model.rb
+++ b/app/models/llm_model.rb
@@ -31,6 +31,7 @@ class LlmModel < ActiveRecord::Base
       },
       ollama: {
         disable_system_prompt: :checkbox,
+        enable_native_tool: :checkbox,
       },
     }
   end
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 9403dc0e..ea83abb9 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -312,6 +312,7 @@ en:
             region: "AWS Bedrock Region"
             organization: "Optional OpenAI Organization ID"
             disable_system_prompt: "Disable system message in prompts"
+            enable_native_tool: "Enable native tool support"
 
       related_topics:
         title: "Related Topics"
diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb
index 5a32f0c3..541d0e73 100644
--- a/lib/completions/dialects/ollama.rb
+++ b/lib/completions/dialects/ollama.rb
@@ -10,7 +10,9 @@ module DiscourseAi
           end
         end
 
-        # TODO: Add tool suppport
+        def native_tool_support?
+          enable_native_tool?
+        end
 
         def max_prompt_tokens
           llm_model.max_prompt_tokens
@@ -18,6 +20,14 @@ module DiscourseAi
 
         private
 
+        def tools_dialect
+          if enable_native_tool?
+            @tools_dialect ||= DiscourseAi::Completions::Dialects::OllamaTools.new(prompt.tools)
+          else
+            super
+          end
+        end
+
         def tokenizer
           llm_model.tokenizer_class
         end
@@ -26,8 +36,28 @@ module DiscourseAi
           { role: "assistant", content: msg[:content] }
         end
 
+        def tool_call_msg(msg)
+          tools_dialect.from_raw_tool_call(msg)
+        end
+
+        def tool_msg(msg)
+          tools_dialect.from_raw_tool(msg)
+        end
+
         def system_msg(msg)
-          { role: "system", content: msg[:content] }
+          msg = { role: "system", content: msg[:content] }
+
+          if tools_dialect.instructions.present?
+            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+          end
+
+          msg
+        end
+
+        def enable_native_tool?
+          return @enable_native_tool if defined?(@enable_native_tool)
+
+          @enable_native_tool = llm_model.lookup_custom_param("enable_native_tool")
         end
 
         def user_msg(msg)
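One note on the dialect change above: `tools_dialect` now switches between the new `OllamaTools` translation and the inherited XML fallback, driven by the `enable_native_tool` checkbox registered in `LlmModel.provider_params`. A minimal sketch of how the toggle resolves, assuming an Ollama `LlmModel` record already exists (the `find_by` lookup is illustrative, not from this diff):

```ruby
# Illustrative only: resolving the per-model toggle the dialect memoizes.
model = LlmModel.find_by(provider: "ollama") # hypothetical way to grab the record
model.provider_params # => { "enable_native_tool" => true }
model.lookup_custom_param("enable_native_tool") # => true, so OllamaTools is used
```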
diff --git a/lib/completions/dialects/ollama_tools.rb b/lib/completions/dialects/ollama_tools.rb
new file mode 100644
index 00000000..6f4d669d
--- /dev/null
+++ b/lib/completions/dialects/ollama_tools.rb
@@ -0,0 +1,58 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      # TODO: Define the Tool class to be inherited by all tools.
+      class OllamaTools
+        def initialize(tools)
+          @raw_tools = tools
+        end
+
+        def instructions
+          "" # Noop. Tools are listed separately.
+        end
+
+        def translated_tools
+          raw_tools.map do |t|
+            tool = t.dup
+
+            tool[:parameters] = t[:parameters]
+              .to_a
+              .reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
+                name = p[:name]
+                memo[:required] << name if p[:required]
+
+                except = %i[name required item_type]
+                except << :enum if p[:enum].blank?
+
+                memo[:properties][name] = p.except(*except)
+                memo
+              end
+
+            { type: "function", function: tool }
+          end
+        end
+
+        def from_raw_tool_call(raw_message)
+          call_details = JSON.parse(raw_message[:content], symbolize_names: true)
+          call_details[:name] = raw_message[:name]
+
+          {
+            role: "assistant",
+            content: nil,
+            tool_calls: [{ type: "function", function: call_details }],
+          }
+        end
+
+        def from_raw_tool(raw_message)
+          { role: "tool", content: raw_message[:content], name: raw_message[:name] }
+        end
+
+        private
+
+        attr_reader :raw_tools
+      end
+    end
+  end
+end
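For reference, `translated_tools` above folds the generic flat parameter list into the JSON-schema-like object Ollama expects. A small usage sketch (the `get_weather` definition is borrowed from the endpoint spec later in this diff):

```ruby
tools = [
  {
    name: "get_weather",
    description: "Get the weather in a city",
    parameters: [
      { name: "location", description: "the city name", type: "string", required: true },
    ],
  },
]

DiscourseAi::Completions::Dialects::OllamaTools.new(tools).translated_tools
# => [
#      {
#        type: "function",
#        function: {
#          name: "get_weather",
#          description: "Get the weather in a city",
#          parameters: {
#            type: "object",
#            properties: { "location" => { description: "the city name", type: "string" } },
#            required: ["location"],
#          },
#        },
#      },
#    ]
```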
diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb
index 4a8453db..cc58006a 100644
--- a/lib/completions/endpoints/ollama.rb
+++ b/lib/completions/endpoints/ollama.rb
@@ -37,11 +37,28 @@ module DiscourseAi
           URI(llm_model.url)
         end
 
-        def prepare_payload(prompt, model_params, _dialect)
+        def native_tool_support?
+          @native_tool_support
+        end
+
+        def has_tool?(_response_data)
+          @has_function_call
+        end
+
+        def prepare_payload(prompt, model_params, dialect)
+          @native_tool_support = dialect.native_tool_support?
+
+          # https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1
+          # Ollama enforces 'stream: false' for tool calls, so rather than complicating the
+          # code, we just disable streaming for all Ollama calls when native tool support is enabled.
+
           default_options
             .merge(model_params)
             .merge(messages: prompt)
-            .tap { |payload| payload[:stream] = false if !@streaming_mode }
+            .tap { |payload| payload[:stream] = false if @native_tool_support || !@streaming_mode }
+            .tap do |payload|
+              payload[:tools] = dialect.tools if @native_tool_support && dialect.tools.present?
+            end
         end
 
         def prepare_request(payload)
@@ -58,7 +75,66 @@ module DiscourseAi
           parsed = JSON.parse(response_raw, symbolize_names: true)
           return if !parsed
 
-          parsed.dig(:message, :content)
+          response_h = parsed.dig(:message)
+
+          @has_function_call ||= response_h.dig(:tool_calls).present?
+          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
+        end
+
+        def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
+          @args_buffer ||= +""
+
+          if @streaming_mode
+            return function_buffer if !partial
+          else
+            partial = payload
+          end
+
+          f_name = partial.dig(:function, :name)
+
+          @current_function ||= function_buffer.at("invoke")
+
+          if f_name
+            current_name = function_buffer.at("tool_name").content
+
+            if current_name.blank?
+              # first call
+            else
+              # we have a previous function, so we need to add a noop
+              @args_buffer = +""
+              @current_function =
+                function_buffer.at("function_calls").add_child(
+                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
+                )
+            end
+          end
+
+          @current_function.at("tool_name").content = f_name if f_name
+          @current_function.at("tool_id").content = partial[:id] if partial[:id]
+
+          args = partial.dig(:function, :arguments)
+
+          # buffer arguments until they parse as complete JSON (streamed fragments may be partial)
+          if args && args != ""
+            @args_buffer << args.to_json
+
+            begin
+              json_args = JSON.parse(@args_buffer, symbolize_names: true)
+
+              argument_fragments =
+                json_args.reduce(+"") do |memo, (arg_name, value)|
+                  memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
+                end
+              argument_fragments << "\n"
+
+              @current_function.at("parameters").children =
+                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
+            rescue JSON::ParserError
+              return function_buffer
+            end
+          end
+
+          function_buffer
        end
      end
    end
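Since `prepare_payload` above forces non-streamed requests whenever native tools are on, the body sent to Ollama ends up with roughly this shape. The model name and the exact contents of `default_options` are assumptions here, not taken from this diff:

```ruby
# Illustrative only: shape of the body built by prepare_payload with native tools enabled.
{
  model: "llama3.1", # whatever the LlmModel is configured with (assumed name)
  messages: [{ role: "user", content: "What is the weather in Sydney?" }],
  stream: false, # forced off whenever native tool support is enabled
  tools: [
    { type: "function", function: { name: "get_weather", parameters: { type: "object" } } },
  ], # dialect.tools, present only when the prompt carries tools
}
```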
diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb
index c380620c..64f17725 100644
--- a/spec/fabricators/llm_model_fabricator.rb
+++ b/spec/fabricators/llm_model_fabricator.rb
@@ -87,4 +87,5 @@ Fabricator(:ollama_model, from: :llm_model) do
   api_key "ABC"
   tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
   url "http://api.ollama.ai/api/chat"
+  provider_params { { enable_native_tool: true } }
 end
diff --git a/spec/lib/completions/dialects/ollama_spec.rb b/spec/lib/completions/dialects/ollama_spec.rb
index d8f1b250..a8ee35fe 100644
--- a/spec/lib/completions/dialects/ollama_spec.rb
+++ b/spec/lib/completions/dialects/ollama_spec.rb
@@ -7,15 +7,37 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
   let(:context) { DialectContext.new(described_class, model) }
 
   describe "#translate" do
-    it "translates a prompt written in our generic format to the Ollama format" do
-      ollama_version = [
-        { role: "system", content: context.system_insts },
-        { role: "user", content: context.simple_user_input },
-      ]
+    context "when native tool support is enabled" do
+      it "translates a prompt written in our generic format to the Ollama format" do
+        ollama_version = [
+          { role: "system", content: context.system_insts },
+          { role: "user", content: context.simple_user_input },
+        ]
 
-      translated = context.system_user_scenario
+        translated = context.system_user_scenario
 
-      expect(translated).to eq(ollama_version)
+        expect(translated).to eq(ollama_version)
+      end
+    end
+
+    context "when native tool support is disabled - XML tools" do
+      it "includes the instructions in the system message" do
+        allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
+
+        DiscourseAi::Completions::Dialects::XmlTools
+          .any_instance
+          .stubs(:instructions)
+          .returns("Instructions")
+
+        ollama_version = [
+          { role: "system", content: "#{context.system_insts}\n\nInstructions" },
+          { role: "user", content: context.simple_user_input },
+        ]
+
+        translated = context.system_user_scenario
+
+        expect(translated).to eq(ollama_version)
+      end
     end
 
     it "trims content if it's getting too long" do
@@ -33,4 +55,40 @@ RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
       expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
     end
   end
+
+  describe "#tools" do
+    context "when native tools are enabled" do
+      it "returns the translated tools from the OllamaTools class" do
+        tool = instance_double(DiscourseAi::Completions::Dialects::OllamaTools)
+
+        allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(true)
+        allow(tool).to receive(:translated_tools)
+        allow(DiscourseAi::Completions::Dialects::OllamaTools).to receive(:new).and_return(tool)
+
+        context.dialect_tools
+
+        expect(DiscourseAi::Completions::Dialects::OllamaTools).to have_received(:new).with(
+          context.prompt.tools,
+        )
+        expect(tool).to have_received(:translated_tools)
+      end
+    end
+
+    context "when native tools are disabled" do
+      it "returns the translated tools from the XmlTools class" do
+        tool = instance_double(DiscourseAi::Completions::Dialects::XmlTools)
+
+        allow(model).to receive(:lookup_custom_param).with("enable_native_tool").and_return(false)
+        allow(tool).to receive(:translated_tools)
+        allow(DiscourseAi::Completions::Dialects::XmlTools).to receive(:new).and_return(tool)
+
+        context.dialect_tools
+
+        expect(DiscourseAi::Completions::Dialects::XmlTools).to have_received(:new).with(
+          context.prompt.tools,
+        )
+        expect(tool).to have_received(:translated_tools)
+      end
+    end
+  end
 end
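When `enable_native_tool` is off, the dialect falls back to the XML tool protocol, and on the endpoint side `add_to_function_buffer` fills a Nokogiri buffer tag by tag. For orientation, the envelope it assembles ends up shaped like this (tag names come from the endpoint code above; the tool name, id, and argument values are invented for the example):

```ruby
# Illustrative only: the XML tool-call envelope the endpoint accumulates
# when native tools are disabled.
expected_buffer = <<~XML
  <function_calls>
  <invoke>
  <tool_name>get_weather</tool_name>
  <tool_id>tool_0</tool_id>
  <parameters>
  <location>Sydney</location>
  <unit>c</unit>
  </parameters>
  </invoke>
  </function_calls>
XML
```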
diff --git a/spec/lib/completions/dialects/ollama_tools_spec.rb b/spec/lib/completions/dialects/ollama_tools_spec.rb
new file mode 100644
index 00000000..a1f9f235
--- /dev/null
+++ b/spec/lib/completions/dialects/ollama_tools_spec.rb
@@ -0,0 +1,112 @@
+# frozen_string_literal: true
+
+require_relative "dialect_context"
+
+RSpec.describe DiscourseAi::Completions::Dialects::OllamaTools do
+  describe "#translated_tools" do
+    it "translates a tool from our generic format to the Ollama format" do
+      tools = [
+        {
+          name: "github_file_content",
+          description: "Retrieves the content of specified GitHub files",
+          parameters: [
+            {
+              name: "repo_name",
+              description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
+              type: "string",
+              required: true,
+            },
+            {
+              name: "file_paths",
+              description: "The paths of the files to retrieve within the repository",
+              type: "array",
+              item_type: "string",
+              required: true,
+            },
+            {
+              name: "branch",
+              description: "The branch or commit SHA to retrieve the files from (default: 'main')",
+              type: "string",
+              required: false,
+            },
+          ],
+        },
+      ]
+
+      ollama_tools = described_class.new(tools)
+
+      translated_tools = ollama_tools.translated_tools
+
+      expect(translated_tools).to eq(
+        [
+          {
+            type: "function",
+            function: {
+              name: "github_file_content",
+              description: "Retrieves the content of specified GitHub files",
+              parameters: {
+                type: "object",
+                properties: {
+                  "repo_name" => {
+                    description: "The name of the GitHub repository (e.g., 'discourse/discourse')",
+                    type: "string",
+                  },
+                  "file_paths" => {
+                    description: "The paths of the files to retrieve within the repository",
+                    type: "array",
+                  },
+                  "branch" => {
+                    description:
+                      "The branch or commit SHA to retrieve the files from (default: 'main')",
+                    type: "string",
+                  },
+                },
+                required: %w[repo_name file_paths],
+              },
+            },
+          },
+        ],
+      )
+    end
+  end
+
+  describe "#from_raw_tool_call" do
+    it "converts a raw tool call to the Ollama tool format" do
+      raw_message = {
+        content: '{"repo_name":"discourse/discourse","file_paths":["README.md"],"branch":"main"}',
+      }
+
+      ollama_tools = described_class.new([])
+      tool_call = ollama_tools.from_raw_tool_call(raw_message)
+
+      expect(tool_call).to eq(
+        {
+          role: "assistant",
+          content: nil,
+          tool_calls: [
+            {
+              type: "function",
+              function: {
+                repo_name: "discourse/discourse",
+                file_paths: ["README.md"],
+                branch: "main",
+                name: nil,
+              },
+            },
+          ],
+        },
+      )
+    end
+  end
+
+  describe "#from_raw_tool" do
+    it "converts a raw tool to the Ollama tool format" do
+      raw_message = { content: "Hello, world!", name: "github_file_content" }
+
+      ollama_tools = described_class.new([])
+      tool = ollama_tools.from_raw_tool(raw_message)
+
+      expect(tool).to eq({ role: "tool", content: "Hello, world!", name: "github_file_content" })
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/ollama_spec.rb b/spec/lib/completions/endpoints/ollama_spec.rb
index 9dc99a04..eb6bc63c 100644
--- a/spec/lib/completions/endpoints/ollama_spec.rb
+++ b/spec/lib/completions/endpoints/ollama_spec.rb
@@ -3,8 +3,13 @@
 require_relative "endpoint_compliance"
 
 class OllamaMock < EndpointMock
-  def response(content)
-    message_content = { content: content }
+  def response(content, tool_call: false)
+    message_content =
+      if tool_call
+        { content: "", tool_calls: [content] }
+      else
+        { content: content }
+      end
 
     {
       created_at: "2024-09-25T06:47:21.283028Z",
@@ -21,11 +26,11 @@
     }
   end
 
-  def stub_response(prompt, response_text)
+  def stub_response(prompt, response_text, tool_call: false)
     WebMock
       .stub_request(:post, "http://api.ollama.ai/api/chat")
-      .with(body: request_body(prompt))
-      .to_return(status: 200, body: JSON.dump(response(response_text)))
+      .with(body: request_body(prompt, tool_call: tool_call))
+      .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
   end
 
   def stream_line(delta)
@@ -71,14 +76,50 @@
     WebMock
       .stub_request(:post, "http://api.ollama.ai/api/chat")
-      .with(body: request_body(prompt, stream: true))
+      .with(body: request_body(prompt))
       .to_return(status: 200, body: chunks)
 
     yield if block_given?
   end
 
-  def request_body(prompt, stream: false)
-    model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json
+  def tool_response
+    { function: { name: "get_weather", arguments: { location: "Sydney", unit: "c" } } }
+  end
+
+  def tool_payload
+    {
+      type: "function",
+      function: {
+        name: "get_weather",
+        description: "Get the weather in a city",
+        parameters: {
+          type: "object",
+          properties: {
+            location: {
+              type: "string",
+              description: "the city name",
+            },
+            unit: {
+              type: "string",
+              description: "the unit of measurement celcius c or fahrenheit f",
+              enum: %w[c f],
+            },
+          },
+          required: %w[location unit],
+        },
+      },
+    }
+  end
+
+  def request_body(prompt, tool_call: false)
+    model
+      .default_options
+      .merge(messages: prompt)
+      .tap do |b|
+        b[:stream] = false
+        b[:tools] = [tool_payload] if tool_call
+      end
+      .to_json
   end
 end
@@ -100,6 +141,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
         compliance.regular_mode_simple_prompt(ollama_mock)
       end
     end
+
+    context "with tools" do
+      it "returns a function invocation" do
+        compliance.regular_mode_tools(ollama_mock)
+      end
+    end
   end
 
   describe "when using streaming mode" do
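Putting the endpoint pieces together: a non-streamed Ollama tool-call response carries the invocation under `message.tool_calls`, which is what `extract_completion_from` picks out and `has_tool?` reports on. A sketch of the shape, mirroring `OllamaMock#response` and `#tool_response` above (field values are illustrative, and the model name is an assumption):

```ruby
# Illustrative response fragment, shaped like the spec's OllamaMock payload.
{
  model: "llama3.1", # whichever model the LlmModel points at (assumed name)
  message: {
    role: "assistant",
    content: "",
    tool_calls: [
      { function: { name: "get_weather", arguments: { location: "Sydney", unit: "c" } } },
    ],
  },
  done: true,
}
```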