FEATURE: Add Ollama provider (#812)
This allows users to add the Ollama provider and use it to serve our AI bot (completion/dialect). In this PR we:

- Introduce DiscourseAi::Completions::Dialects::Ollama, which translates our generic prompt format into the format consumed by Completions::Endpoints::Ollama
- Correct extract_completion_from and partials_from in Endpoints::Ollama
- Add tests for Endpoints::Ollama
- Introduce an ollama_model fabricator
parent c7eaea48f5
commit 2063b3854f
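For orientation, a rough end-to-end usage sketch (not part of this commit): it assumes an LlmModel record saved with provider "ollama", like the ollama_model fabricator added below, and the plugin's existing Llm.proxy / #generate entry points.

# Hypothetical sketch, not from this commit. `model` is assumed to be a saved
# LlmModel with provider "ollama" and a url pointing at an Ollama /api/chat endpoint.
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")

# Dialects::Ollama translates the prompt; Endpoints::Ollama posts it to the
# model's URL and pulls the reply out of message.content.
llm.generate("Say hello", user: Discourse.system_user)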
@@ -15,6 +15,7 @@ module DiscourseAi
         DiscourseAi::Completions::Dialects::Gemini,
         DiscourseAi::Completions::Dialects::Claude,
         DiscourseAi::Completions::Dialects::Command,
+        DiscourseAi::Completions::Dialects::Ollama,
         DiscourseAi::Completions::Dialects::OpenAiCompatible,
       ]
     end
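The list above is the plugin's dialect registry. Roughly, the first dialect whose can_translate? accepts the model's provider is used; a sketch of that lookup (illustrative, the local `dialects` variable stands in for the registry shown above):

# Rough sketch, not from the diff: how a registry like the one above is consumed.
dialects = [
  DiscourseAi::Completions::Dialects::Gemini,
  DiscourseAi::Completions::Dialects::Claude,
  DiscourseAi::Completions::Dialects::Command,
  DiscourseAi::Completions::Dialects::Ollama,
  DiscourseAi::Completions::Dialects::OpenAiCompatible,
]

dialects.find { |dialect| dialect.can_translate?("ollama") }
# => DiscourseAi::Completions::Dialects::Ollama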
@@ -0,0 +1,44 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Ollama < Dialect
        class << self
          def can_translate?(model_provider)
            model_provider == "ollama"
          end
        end

        # TODO: Add tool support

        def max_prompt_tokens
          llm_model.max_prompt_tokens
        end

        private

        def tokenizer
          llm_model.tokenizer_class
        end

        def model_msg(msg)
          { role: "assistant", content: msg[:content] }
        end

        def system_msg(msg)
          { role: "system", content: msg[:content] }
        end

        def user_msg(msg)
          user_message = { role: "user", content: msg[:content] }

          # TODO: Add support for user messages with embedded user ids
          # TODO: Add support for user messages with attachments

          user_message
        end
      end
    end
  end
end
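A quick illustration of what the new dialect produces (not part of the diff; it assumes the generic Prompt class and the Dialect#translate method exercised by the specs below):

# Illustrative only. `model` is assumed to be an LlmModel with provider "ollama".
prompt =
  DiscourseAi::Completions::Prompt.new(
    "You are a helpful bot",
    messages: [{ type: :user, content: "Hello" }],
  )

DiscourseAi::Completions::Dialects::Ollama.new(prompt, model).translate
# => [
#   { role: "system", content: "You are a helpful bot" },
#   { role: "user", content: "Hello" },
# ]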
@@ -41,7 +41,7 @@ module DiscourseAi
         default_options
           .merge(model_params)
           .merge(messages: prompt)
-          .tap { |payload| payload[:stream] = true if @streaming_mode }
+          .tap { |payload| payload[:stream] = false if !@streaming_mode }
       end

       def prepare_request(payload)
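The tap now sets stream: false explicitly because Ollama's /api/chat streams by default, so a non-streaming request has to opt out. The resulting payload looks roughly like this (illustrative values only):

# Illustrative non-streaming request body.
{
  model: "llama3.1",
  messages: [
    { role: "system", content: "You are a helpful bot" },
    { role: "user", content: "Hello" },
  ],
  stream: false, # Ollama streams unless told otherwise
}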
@@ -51,23 +51,14 @@ module DiscourseAi
       end

       def partials_from(decoded_chunk)
-        decoded_chunk
-          .split("\n")
-          .map do |line|
-            data = line.split("data: ", 2)[1]
-            data == "[DONE]" ? nil : data
-          end
-          .compact
+        decoded_chunk.split("\n").compact
       end

       def extract_completion_from(response_raw)
-        parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-        # half a line sent here
+        parsed = JSON.parse(response_raw, symbolize_names: true)
         return if !parsed

-        response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-        response_h.dig(:content)
+        parsed.dig(:message, :content)
       end
     end
   end
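The rewrite reflects Ollama's wire format: /api/chat streams newline-delimited JSON objects (no SSE "data: " prefix and no "[DONE]" sentinel), and each object carries its text under message.content, which is also where a non-streamed response keeps it. A small illustration with example data:

require "json"

# Example streamed chunk: one JSON object per line.
chunk = <<~CHUNK
  {"model":"llama3.1","message":{"role":"assistant","content":"Hel"},"done":false}
  {"model":"llama3.1","message":{"role":"assistant","content":"lo"},"done":true}
CHUNK

chunk.split("\n").each do |line|
  parsed = JSON.parse(line, symbolize_names: true)
  print parsed.dig(:message, :content) # prints "Hel" then "lo"
end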
@@ -79,3 +79,12 @@ Fabricator(:samba_nova_model, from: :llm_model) do
   api_key "ABC"
   url "https://api.sambanova.ai/v1/chat/completions"
 end
+
+Fabricator(:ollama_model, from: :llm_model) do
+  display_name "Ollama llama 3.1"
+  name "llama-3.1"
+  provider "ollama"
+  api_key "ABC"
+  tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
+  url "http://api.ollama.ai/api/chat"
+end
@@ -0,0 +1,36 @@
# frozen_string_literal: true

require_relative "dialect_context"

RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
  fab!(:model) { Fabricate(:ollama_model) }
  let(:context) { DialectContext.new(described_class, model) }

  describe "#translate" do
    it "translates a prompt written in our generic format to the Ollama format" do
      ollama_version = [
        { role: "system", content: context.system_insts },
        { role: "user", content: context.simple_user_input },
      ]

      translated = context.system_user_scenario

      expect(translated).to eq(ollama_version)
    end

    it "trims content if it's getting too long" do
      model.max_prompt_tokens = 5000
      translated = context.long_user_input_scenario

      expect(translated.last[:role]).to eq("user")
      expect(translated.last[:content].length).to be < context.long_message_text.length
    end
  end

  describe "#max_prompt_tokens" do
    it "returns the max_prompt_tokens from the llm_model" do
      model.max_prompt_tokens = 10_000
      expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
    end
  end
end
@@ -0,0 +1,112 @@
# frozen_string_literal: true

require_relative "endpoint_compliance"

class OllamaMock < EndpointMock
  def response(content)
    message_content = { content: content }

    {
      created_at: "2024-09-25T06:47:21.283028Z",
      model: "llama3.1",
      message: { role: "assistant" }.merge(message_content),
      done: true,
      done_reason: "stop",
      total_duration: 7_639_718_541,
      load_duration: 299_886_663,
      prompt_eval_count: 18,
      prompt_eval_duration: 220_447_000,
      eval_count: 18,
      eval_duration: 220_447_000,
    }
  end

  def stub_response(prompt, response_text)
    WebMock
      .stub_request(:post, "http://api.ollama.ai/api/chat")
      .with(body: request_body(prompt))
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end

  def stream_line(delta)
    message_content = { content: delta }

    +{
      model: "llama3.1",
      created_at: "2024-09-25T06:47:21.283028Z",
      message: { role: "assistant" }.merge(message_content),
      done: false,
    }.to_json
  end

  def stub_raw(chunks)
    WebMock.stub_request(:post, "http://api.ollama.ai/api/chat").to_return(
      status: 200,
      body: chunks,
    )
  end

  def stub_streamed_response(prompt, deltas)
    chunks = deltas.each_with_index.map { |_, index| stream_line(deltas[index]) }

    chunks =
      (
        chunks.join("\n\n") << {
          model: "llama3.1",
          created_at: "2024-09-25T06:47:21.283028Z",
          message: {
            role: "assistant",
            content: "",
          },
          done: true,
          done_reason: "stop",
          total_duration: 7_639_718_541,
          load_duration: 299_886_663,
          prompt_eval_count: 18,
          prompt_eval_duration: 220_447_000,
          eval_count: 18,
          eval_duration: 220_447_000,
        }.to_json
      ).split("")

    WebMock
      .stub_request(:post, "http://api.ollama.ai/api/chat")
      .with(body: request_body(prompt, stream: true))
      .to_return(status: 200, body: chunks)

    yield if block_given?
  end

  def request_body(prompt, stream: false)
    model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json
  end
end

RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
  subject(:endpoint) { described_class.new(model) }

  fab!(:user)
  fab!(:model) { Fabricate(:ollama_model) }

  let(:ollama_mock) { OllamaMock.new(endpoint) }

  let(:compliance) do
    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Ollama, user)
  end

  describe "#perform_completion!" do
    context "when using regular mode" do
      it "completes a trivial prompt and logs the response" do
        compliance.regular_mode_simple_prompt(ollama_mock)
      end
    end
  end

  describe "when using streaming mode" do
    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
        compliance.streaming_mode_simple_prompt(ollama_mock)
      end
    end
  end
end