Mirror of https://github.com/discourse/discourse-ai.git (synced 2025-03-06 17:30:20 +00:00)
FEATURE: Add Ollama provider (#812)
This allows our users to add the Ollama provider and use it to serve our AI bot (completion/dialect).

In this PR, we introduce:
- DiscourseAi::Completions::Dialects::Ollama, which translates our generic prompt format into the format expected by Completions::Endpoints::Ollama
- Corrected extract_completion_from and partials_from in Endpoints::Ollama
- Tests for Endpoints::Ollama
- An ollama_model fabricator
This commit is contained in:
parent c7eaea48f5
commit 2063b3854f
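As a rough illustration of what this feature enables (not part of the diff, values are made up): an Ollama-backed LLM is just an LlmModel record whose provider is "ollama", mirroring the fields of the ollama_model fabricator added below.

  # Hypothetical sketch only: field names are taken from the ollama_model
  # fabricator introduced in this PR; the URL points at a locally running
  # Ollama instance and is illustrative, not a required value.
  LlmModel.create!(
    display_name: "Ollama llama 3.1",
    name: "llama-3.1",
    provider: "ollama",
    tokenizer: "DiscourseAi::Tokenizer::Llama3Tokenizer",
    url: "http://localhost:11434/api/chat",
    api_key: "ABC",
  )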
@@ -15,6 +15,7 @@ module DiscourseAi
         DiscourseAi::Completions::Dialects::Gemini,
         DiscourseAi::Completions::Dialects::Claude,
         DiscourseAi::Completions::Dialects::Command,
+        DiscourseAi::Completions::Dialects::Ollama,
         DiscourseAi::Completions::Dialects::OpenAiCompatible,
       ]
     end
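Registering the new dialect in this list is what lets the completions layer pick it for models whose provider is "ollama". A minimal sketch of that selection (not the actual registry code, which lives in the Dialect base class):

  # Minimal sketch, assuming each dialect class answers can_translate? the way
  # the new Ollama dialect below does.
  dialects = [
    DiscourseAi::Completions::Dialects::Gemini,
    DiscourseAi::Completions::Dialects::Claude,
    DiscourseAi::Completions::Dialects::Command,
    DiscourseAi::Completions::Dialects::Ollama,
    DiscourseAi::Completions::Dialects::OpenAiCompatible,
  ]

  dialects.find { |dialect| dialect.can_translate?("ollama") }
  # => DiscourseAi::Completions::Dialects::Ollama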
lib/completions/dialects/ollama.rb (new file, 44 lines)
@@ -0,0 +1,44 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Ollama < Dialect
        class << self
          def can_translate?(model_provider)
            model_provider == "ollama"
          end
        end

        # TODO: Add tool support

        def max_prompt_tokens
          llm_model.max_prompt_tokens
        end

        private

        def tokenizer
          llm_model.tokenizer_class
        end

        def model_msg(msg)
          { role: "assistant", content: msg[:content] }
        end

        def system_msg(msg)
          { role: "system", content: msg[:content] }
        end

        def user_msg(msg)
          user_message = { role: "user", content: msg[:content] }

          # TODO: Add support for user messages with embedded user ids
          # TODO: Add support for user messages with attachments

          user_message
        end
      end
    end
  end
end
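The helpers above map each message in our generic prompt format onto Ollama's chat roles, so a translated prompt ends up as a flat array of role/content hashes. A rough illustration (values are made up; the translate entry point itself lives in the base Dialect class and is not part of this diff):

  # Illustrative output of the Ollama dialect for a system + user + model exchange.
  [
    { role: "system", content: "You are a helpful bot" },
    { role: "user", content: "Write a poem" },
    { role: "assistant", content: "Here is a short poem about Discourse..." },
  ]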
@@ -41,7 +41,7 @@ module DiscourseAi
         default_options
           .merge(model_params)
           .merge(messages: prompt)
-          .tap { |payload| payload[:stream] = true if @streaming_mode }
+          .tap { |payload| payload[:stream] = false if !@streaming_mode }
       end

       def prepare_request(payload)
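The change above reflects that Ollama's chat endpoint streams by default, so the payload only needs an explicit stream flag to turn streaming off. A sketch of the resulting request body (hypothetical values):

  # Rough shape of the JSON body POSTed to the /api/chat endpoint; "stream" is
  # only present (and false) when the endpoint is not in streaming mode.
  {
    model: "llama-3.1",
    messages: [
      { role: "system", content: "You are a helpful bot" },
      { role: "user", content: "Write a poem" },
    ],
    stream: false,
  }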
@@ -51,23 +51,14 @@ module DiscourseAi
       end

       def partials_from(decoded_chunk)
-        decoded_chunk
-          .split("\n")
-          .map do |line|
-            data = line.split("data: ", 2)[1]
-            data == "[DONE]" ? nil : data
-          end
-          .compact
+        decoded_chunk.split("\n").compact
       end

       def extract_completion_from(response_raw)
-        parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-        # half a line sent here
+        parsed = JSON.parse(response_raw, symbolize_names: true)
         return if !parsed

-        response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-        response_h.dig(:content)
+        parsed.dig(:message, :content)
       end
     end
   end
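Since Ollama streams newline-delimited JSON objects rather than SSE "data:" lines, the corrected methods above just split each chunk on newlines and dig message.content out of every parsed object. A small self-contained sketch:

  # Minimal sketch of the streaming parse path after this change, using a
  # hypothetical two-line chunk shaped like Ollama's /api/chat stream output.
  require "json"

  chunk = <<~CHUNK
    {"model":"llama3.1","message":{"role":"assistant","content":"Hello"},"done":false}
    {"model":"llama3.1","message":{"role":"assistant","content":" world"},"done":true}
  CHUNK

  partials = chunk.split("\n").compact
  partials.map { |raw| JSON.parse(raw, symbolize_names: true).dig(:message, :content) }
  # => ["Hello", " world"]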
@@ -79,3 +79,12 @@ Fabricator(:samba_nova_model, from: :llm_model) do
   api_key "ABC"
   url "https://api.sambanova.ai/v1/chat/completions"
 end
+
+Fabricator(:ollama_model, from: :llm_model) do
+  display_name "Ollama llama 3.1"
+  name "llama-3.1"
+  provider "ollama"
+  api_key "ABC"
+  tokenizer "DiscourseAi::Tokenizer::Llama3Tokenizer"
+  url "http://api.ollama.ai/api/chat"
+end
spec/lib/completions/dialects/ollama_spec.rb (new file, 36 lines)
@@ -0,0 +1,36 @@
# frozen_string_literal: true

require_relative "dialect_context"

RSpec.describe DiscourseAi::Completions::Dialects::Ollama do
  fab!(:model) { Fabricate(:ollama_model) }
  let(:context) { DialectContext.new(described_class, model) }

  describe "#translate" do
    it "translates a prompt written in our generic format to the Ollama format" do
      ollama_version = [
        { role: "system", content: context.system_insts },
        { role: "user", content: context.simple_user_input },
      ]

      translated = context.system_user_scenario

      expect(translated).to eq(ollama_version)
    end

    it "trims content if it's getting too long" do
      model.max_prompt_tokens = 5000
      translated = context.long_user_input_scenario

      expect(translated.last[:role]).to eq("user")
      expect(translated.last[:content].length).to be < context.long_message_text.length
    end
  end

  describe "#max_prompt_tokens" do
    it "returns the max_prompt_tokens from the llm_model" do
      model.max_prompt_tokens = 10_000
      expect(context.dialect(nil).max_prompt_tokens).to eq(10_000)
    end
  end
end
spec/lib/completions/endpoints/ollama_spec.rb (new file, 112 lines)
@@ -0,0 +1,112 @@
# frozen_string_literal: true

require_relative "endpoint_compliance"

class OllamaMock < EndpointMock
  def response(content)
    message_content = { content: content }

    {
      created_at: "2024-09-25T06:47:21.283028Z",
      model: "llama3.1",
      message: { role: "assistant" }.merge(message_content),
      done: true,
      done_reason: "stop",
      total_duration: 7_639_718_541,
      load_duration: 299_886_663,
      prompt_eval_count: 18,
      prompt_eval_duration: 220_447_000,
      eval_count: 18,
      eval_duration: 220_447_000,
    }
  end

  def stub_response(prompt, response_text)
    WebMock
      .stub_request(:post, "http://api.ollama.ai/api/chat")
      .with(body: request_body(prompt))
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end

  def stream_line(delta)
    message_content = { content: delta }

    +{
      model: "llama3.1",
      created_at: "2024-09-25T06:47:21.283028Z",
      message: { role: "assistant" }.merge(message_content),
      done: false,
    }.to_json
  end

  def stub_raw(chunks)
    WebMock.stub_request(:post, "http://api.ollama.ai/api/chat").to_return(
      status: 200,
      body: chunks,
    )
  end

  def stub_streamed_response(prompt, deltas)
    chunks = deltas.each_with_index.map { |_, index| stream_line(deltas[index]) }

    chunks =
      (
        chunks.join("\n\n") << {
          model: "llama3.1",
          created_at: "2024-09-25T06:47:21.283028Z",
          message: {
            role: "assistant",
            content: "",
          },
          done: true,
          done_reason: "stop",
          total_duration: 7_639_718_541,
          load_duration: 299_886_663,
          prompt_eval_count: 18,
          prompt_eval_duration: 220_447_000,
          eval_count: 18,
          eval_duration: 220_447_000,
        }.to_json
      ).split("")

    WebMock
      .stub_request(:post, "http://api.ollama.ai/api/chat")
      .with(body: request_body(prompt, stream: true))
      .to_return(status: 200, body: chunks)

    yield if block_given?
  end

  def request_body(prompt, stream: false)
    model.default_options.merge(messages: prompt).tap { |b| b[:stream] = false if !stream }.to_json
  end
end

RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
  subject(:endpoint) { described_class.new(model) }

  fab!(:user)
  fab!(:model) { Fabricate(:ollama_model) }

  let(:ollama_mock) { OllamaMock.new(endpoint) }

  let(:compliance) do
    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Ollama, user)
  end

  describe "#perform_completion!" do
    context "when using regular mode" do
      it "completes a trivial prompt and logs the response" do
        compliance.regular_mode_simple_prompt(ollama_mock)
      end
    end
  end

  describe "when using streaming mode" do
    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
        compliance.streaming_mode_simple_prompt(ollama_mock)
      end
    end
  end
end