FEATURE: Add Ollama provider (#812)

This allows our users to configure the Ollama provider and use it to serve our AI bot (completion/dialect).

In this PR, we introduce:

    DiscourseAi::Completions::Dialects::Ollama, which translates our generic prompt format into the message format expected by Completions::Endpoints::Ollama (a rough usage sketch follows this list)
    Corrected extract_completion_from and partials_from in Endpoints::Ollama

Also

    Add tests for Endpoints::Ollama
    Introduce ollama_model fabricator
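
A rough usage sketch of what this enables (the Prompt and dialect constructor calls here are assumptions for illustration, based on the diffs below rather than a verified API):

    llm_model = Fabricate(:ollama_model) # provider "ollama", see the fabricator below

    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful bot",
        messages: [{ type: :user, content: "Hello" }],
      )

    # Translate the generic prompt into the role/content hashes Ollama's chat API expects.
    DiscourseAi::Completions::Dialects::Ollama.new(prompt, llm_model).translate
    # => [{ role: "system", content: "You are a helpful bot" },
    #     { role: "user", content: "Hello" }]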
Hoa Nguyen 2024-10-01 10:45:03 +10:00 committed by GitHub
parent c7eaea48f5
commit 2063b3854f
6 changed files with 206 additions and 13 deletions

View File

@@ -15,6 +15,7 @@ module DiscourseAi
           DiscourseAi::Completions::Dialects::Gemini,
           DiscourseAi::Completions::Dialects::Claude,
           DiscourseAi::Completions::Dialects::Command,
+          DiscourseAi::Completions::Dialects::Ollama,
           DiscourseAi::Completions::Dialects::OpenAiCompatible,
         ]
       end
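
For context, this list is what the dialect lookup walks: each class's can_translate? is checked against the model's provider string, so registering the new class here is what makes "ollama" resolve to the dialect added below. A minimal sketch of that dispatch (the dialect_for name is illustrative, not necessarily the helper used in the codebase):

    # Illustrative only: return the first dialect that claims the provider string.
    def dialect_for(model_provider)
      all_dialects.find { |dialect| dialect.can_translate?(model_provider) }
    end

    dialect_for("ollama") # => DiscourseAi::Completions::Dialects::Ollama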

View File

@@ -0,0 +1,44 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Ollama < Dialect
        class << self
          def can_translate?(model_provider)
            model_provider == "ollama"
          end
        end

        # TODO: Add tool support

        def max_prompt_tokens
          llm_model.max_prompt_tokens
        end

        private

        def tokenizer
          llm_model.tokenizer_class
        end

        def model_msg(msg)
          { role: "assistant", content: msg[:content] }
        end

        def system_msg(msg)
          { role: "system", content: msg[:content] }
        end

        def user_msg(msg)
          user_message = { role: "user", content: msg[:content] }

          # TODO: Add support for user messages with embedded user ids
          # TODO: Add support for user messages with attachments

          user_message
        end
      end
    end
  end
end

View File

@@ -41,7 +41,7 @@ module DiscourseAi
           default_options
             .merge(model_params)
             .merge(messages: prompt)
-            .tap { |payload| payload[:stream] = true if @streaming_mode }
+            .tap { |payload| payload[:stream] = false if !@streaming_mode }
         end

         def prepare_request(payload)
@@ -51,23 +51,14 @@
         end

         def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
-            .compact
+          decoded_chunk.split("\n").compact
         end

         def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
+          parsed = JSON.parse(response_raw, symbolize_names: true)
           return if !parsed

-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-          response_h.dig(:content)
+          parsed.dig(:message, :content)
         end
       end
     end
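
The corrected parsing reflects that Ollama's /api/chat streams newline-delimited JSON objects rather than SSE "data:" lines with a "[DONE]" sentinel, and a non-streaming response is a single JSON object with the reply under message.content. A small standalone sketch of how a streamed chunk is split and read, using the payload shape from the mock further down (not the endpoint class itself):

    require "json"

    # Two NDJSON lines, shaped like OllamaMock's stream output below.
    chunk = <<~CHUNK
      {"model":"llama3.1","message":{"role":"assistant","content":"Hello"},"done":false}
      {"model":"llama3.1","message":{"role":"assistant","content":""},"done":true}
    CHUNK

    # partials_from: each line is already a complete JSON document.
    partials = chunk.split("\n")

    # extract_completion_from: parse a partial and dig out message.content.
    partials.each do |raw|
      parsed = JSON.parse(raw, symbolize_names: true)
      puts parsed.dig(:message, :content) # prints "Hello", then an empty string
    end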

View File

@@ -79,3 +79,12 @@ Fabricator(:samba_nova_model, from: :llm_model) do
   api_key "ABC"
   url "https://api.sambanova.ai/v1/chat/completions"
 end
+
+Fabricator(:ollama_model, from: :llm_model) do
+  display_name "Ollama llama 3.1"
+  name "llama-3.1"
+  provider "ollama"
+  api_key "ABC"
+  tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
+  url "http://api.ollama.ai/api/chat"
+end

View File

@@ -0,0 +1,36 @@
# frozen_string_literal: true

require_relative "dialect_context"

RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
  fab!(:model) { Fabricate(:ollama_model) }
  let(:context) { DialectContext.new(described_class, model) }

  describe "#translate" do
    it "translates a prompt written in our generic format to the Ollama format" do
      ollama_version = [
        { role: "system", content: context.system_insts },
        { role: "user", content: context.simple_user_input },
      ]

      translated = context.system_user_scenario

      expect(translated).to eq(ollama_version)
    end

    it "trims content if it's getting too long" do
      model.max_prompt_tokens = 5000
      translated = context.long_user_input_scenario

      expect(translated.last[:role]).to eq("user")
      expect(translated.last[:content].length).to be < context.long_message_text.length
    end
  end

  describe "#max_prompt_tokens" do
    it "returns the max_prompt_tokens from the llm_model" do
      model.max_prompt_tokens = 10_000
      expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
    end
  end
end

View File

@@ -0,0 +1,112 @@
# frozen_string_literal: true

require_relative "endpoint_compliance"

class OllamaMock < EndpointMock
  def response(content)
    message_content = { content: content }

    {
      created_at: "2024-09-25T06:47:21.283028Z",
      model: "llama3.1",
      message: { role: "assistant" }.merge(message_content),
      done: true,
      done_reason: "stop",
      total_duration: 7_639_718_541,
      load_duration: 299_886_663,
      prompt_eval_count: 18,
      prompt_eval_duration: 220_447_000,
      eval_count: 18,
      eval_duration: 220_447_000,
    }
  end

  def stub_response(prompt, response_text)
    WebMock
      .stub_request(:post, "http://api.ollama.ai/api/chat")
      .with(body: request_body(prompt))
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end

  def stream_line(delta)
    message_content = { content: delta }

    +{
      model: "llama3.1",
      created_at: "2024-09-25T06:47:21.283028Z",
      message: { role: "assistant" }.merge(message_content),
      done: false,
    }.to_json
  end

  def stub_raw(chunks)
    WebMock.stub_request(:post, "http://api.ollama.ai/api/chat").to_return(
      status: 200,
      body: chunks,
    )
  end

  def stub_streamed_response(prompt, deltas)
    chunks = deltas.each_with_index.map { |_, index| stream_line(deltas[index]) }

    chunks =
      (
        chunks.join("\n\n") << {
          model: "llama3.1",
          created_at: "2024-09-25T06:47:21.283028Z",
          message: {
            role: "assistant",
            content: "",
          },
          done: true,
          done_reason: "stop",
          total_duration: 7_639_718_541,
          load_duration: 299_886_663,
          prompt_eval_count: 18,
          prompt_eval_duration: 220_447_000,
          eval_count: 18,
          eval_duration: 220_447_000,
        }.to_json
      ).split("")

    WebMock
      .stub_request(:post, "http://api.ollama.ai/api/chat")
      .with(body: request_body(prompt, stream: true))
      .to_return(status: 200, body: chunks)

    yield if block_given?
  end

  def request_body(prompt, stream: false)
    model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json
  end
end

RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
  subject(:endpoint) { described_class.new(model) }

  fab!(:user)
  fab!(:model) { Fabricate(:ollama_model) }

  let(:ollama_mock) { OllamaMock.new(endpoint) }

  let(:compliance) do
    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Ollama, user)
  end

  describe "#perform_completion!" do
    context "when using regular mode" do
      it "completes a trivial prompt and logs the response" do
        compliance.regular_mode_simple_prompt(ollama_mock)
      end
    end
  end

  describe "when using streaming mode" do
    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
        compliance.streaming_mode_simple_prompt(ollama_mock)
      end
    end
  end
end