From 2063b3854ffc1b38b061962a06868dc608a6c849 Mon Sep 17 00:00:00 2001
From: Hoa Nguyen
Date: Tue, 1 Oct 2024 10:45:03 +1000
Subject: [PATCH] FEATURE: Add Ollama provider (#812)

This allows our users to add the Ollama provider and use it to serve our AI
bot (completion/dialect).

In this PR, we:

- Introduce DiscourseAi::Completions::Dialects::Ollama, which translates our
  generic prompt format into the message format consumed by
  Completions::Endpoints::Ollama
- Correct extract_completion_from and partials_from in Endpoints::Ollama
- Add tests for Endpoints::Ollama
- Introduce an ollama_model fabricator
---
 lib/completions/dialects/dialect.rb           |   1 +
 lib/completions/dialects/ollama.rb            |  44 +++++++
 lib/completions/endpoints/ollama.rb           |  17 +--
 spec/fabricators/llm_model_fabricator.rb      |   9 ++
 spec/lib/completions/dialects/ollama_spec.rb  |  36 ++++++
 spec/lib/completions/endpoints/ollama_spec.rb | 112 ++++++++++++++++++
 6 files changed, 206 insertions(+), 13 deletions(-)
 create mode 100644 lib/completions/dialects/ollama.rb
 create mode 100644 spec/lib/completions/dialects/ollama_spec.rb
 create mode 100644 spec/lib/completions/endpoints/ollama_spec.rb

diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 2a4a489e..5420e643 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -15,6 +15,7 @@ module DiscourseAi
               DiscourseAi::Completions::Dialects::Gemini,
               DiscourseAi::Completions::Dialects::Claude,
               DiscourseAi::Completions::Dialects::Command,
+              DiscourseAi::Completions::Dialects::Ollama,
               DiscourseAi::Completions::Dialects::OpenAiCompatible,
             ]
           end
diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb
new file mode 100644
index 00000000..5a32f0c3
--- /dev/null
+++ b/lib/completions/dialects/ollama.rb
@@ -0,0 +1,44 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Ollama < Dialect
+        class << self
+          def can_translate?(model_provider)
+            model_provider == "ollama"
+          end
+        end
+
+        # TODO: Add tool support
+
+        def max_prompt_tokens
+          llm_model.max_prompt_tokens
+        end
+
+        private
+
+        def tokenizer
+          llm_model.tokenizer_class
+        end
+
+        def model_msg(msg)
+          { role: "assistant", content: msg[:content] }
+        end
+
+        def system_msg(msg)
+          { role: "system", content: msg[:content] }
+        end
+
+        def user_msg(msg)
+          user_message = { role: "user", content: msg[:content] }
+
+          # TODO: Add support for user messages with embedded user ids
+          # TODO: Add support for user messages with attachments
+
+          user_message
+        end
+      end
+    end
+  end
+end
diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb
index dc55701a..4a8453db 100644
--- a/lib/completions/endpoints/ollama.rb
+++ b/lib/completions/endpoints/ollama.rb
@@ -41,7 +41,7 @@ module DiscourseAi
           default_options
             .merge(model_params)
             .merge(messages: prompt)
-            .tap { |payload| payload[:stream] = true if @streaming_mode }
+            .tap { |payload| payload[:stream] = false if !@streaming_mode }
         end
 
         def prepare_request(payload)
@@ -51,23 +51,14 @@ module DiscourseAi
         end
 
         def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
-            .compact
+          decoded_chunk.split("\n").compact
         end
 
         def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
+          parsed = JSON.parse(response_raw, symbolize_names: true)
           return if !parsed
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-          response_h.dig(:content)
+          parsed.dig(:message, :content)
         end
       end
     end
   end
diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb
index bea2d72c..c380620c 100644
--- a/spec/fabricators/llm_model_fabricator.rb
+++ b/spec/fabricators/llm_model_fabricator.rb
@@ -79,3 +79,12 @@ Fabricator(:samba_nova_model, from: :llm_model) do
   api_key "ABC"
   url "https://api.sambanova.ai/v1/chat/completions"
 end
+
+Fabricator(:ollama_model, from: :llm_model) do
+  display_name "Ollama llama 3.1"
+  name "llama-3.1"
+  provider "ollama"
+  api_key "ABC"
+  tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
+  url "http://api.ollama.ai/api/chat"
+end
diff --git a/spec/lib/completions/dialects/ollama_spec.rb b/spec/lib/completions/dialects/ollama_spec.rb
new file mode 100644
index 00000000..d8f1b250
--- /dev/null
+++ b/spec/lib/completions/dialects/ollama_spec.rb
@@ -0,0 +1,36 @@
+# frozen_string_literal: true
+
+require_relative "dialect_context"
+
+RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
+  fab!(:model) { Fabricate(:ollama_model) }
+  let(:context) { DialectContext.new(described_class, model) }
+
+  describe "#translate" do
+    it "translates a prompt written in our generic format to the Ollama format" do
+      ollama_version = [
+        { role: "system", content: context.system_insts },
+        { role: "user", content: context.simple_user_input },
+      ]
+
+      translated = context.system_user_scenario
+
+      expect(translated).to eq(ollama_version)
+    end
+
+    it "trims content if it's getting too long" do
+      model.max_prompt_tokens = 5000
+      translated = context.long_user_input_scenario
+
+      expect(translated.last[:role]).to eq("user")
+      expect(translated.last[:content].length).to be < context.long_message_text.length
+    end
+  end
+
+  describe "#max_prompt_tokens" do
+    it "returns the max_prompt_tokens from the llm_model" do
+      model.max_prompt_tokens = 10_000
+      expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/ollama_spec.rb b/spec/lib/completions/endpoints/ollama_spec.rb
new file mode 100644
index 00000000..9dc99a04
--- /dev/null
+++ b/spec/lib/completions/endpoints/ollama_spec.rb
@@ -0,0 +1,112 @@
+# frozen_string_literal: true
+
+require_relative "endpoint_compliance"
+
+class OllamaMock < EndpointMock
+  def response(content)
+    message_content = { content: content }
+
+    {
+      created_at: "2024-09-25T06:47:21.283028Z",
+      model: "llama3.1",
+      message: { role: "assistant" }.merge(message_content),
+      done: true,
+      done_reason: "stop",
+      total_duration: 7_639_718_541,
+      load_duration: 299_886_663,
+      prompt_eval_count: 18,
+      prompt_eval_duration: 220_447_000,
+      eval_count: 18,
+      eval_duration: 220_447_000,
+    }
+  end
+
+  def stub_response(prompt, response_text)
+    WebMock
+      .stub_request(:post, "http://api.ollama.ai/api/chat")
+      .with(body: request_body(prompt))
+      .to_return(status: 200, body: JSON.dump(response(response_text)))
+  end
+
+  def stream_line(delta)
+    message_content = { content: delta }
+
+    +{
+      model: "llama3.1",
+      created_at: "2024-09-25T06:47:21.283028Z",
+      message: { role: "assistant" }.merge(message_content),
+      done: false,
+    }.to_json
+  end
+
+  def stub_raw(chunks)
+    WebMock.stub_request(:post, "http://api.ollama.ai/api/chat").to_return(
+      status: 200,
+      body: chunks,
+    )
+  end
+
+  def stub_streamed_response(prompt, deltas)
+    chunks = deltas.each_with_index.map { |_, index| stream_line(deltas[index]) }
+
+    chunks =
+      (
+        chunks.join("\n\n") << {
+          model: "llama3.1",
"llama3.1", + created_at: "2024-09-25T06:47:21.283028Z", + message: { + role: "assistant", + content: "", + }, + done: true, + done_reason: "stop", + total_duration: 7_639_718_541, + load_duration: 299_886_663, + prompt_eval_count: 18, + prompt_eval_duration: 220_447_000, + eval_count: 18, + eval_duration: 220_447_000, + }.to_json + ).split("") + + WebMock + .stub_request(:post, "http://api.ollama.ai/api/chat") + .with(body: request_body(prompt, stream: true)) + .to_return(status: 200, body: chunks) + + yield if block_given? + end + + def request_body(prompt, stream: false) + model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json + end +end + +RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do + subject(:endpoint) { described_class.new(model) } + + fab!(:user) + fab!(:model) { Fabricate(:ollama_model) } + + let(:ollama_mock) { OllamaMock.new(endpoint) } + + let(:compliance) do + EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Ollama, user) + end + + describe "#perform_completion!" do + context "when using regular mode" do + it "completes a trivial prompt and logs the response" do + compliance.regular_mode_simple_prompt(ollama_mock) + end + end + end + + describe "when using streaming mode" do + context "with simpel prompts" do + it "completes a trivial prompt and logs the response" do + compliance.streaming_mode_simple_prompt(ollama_mock) + end + end + end +end