FEATURE: add a SambaNova LLM provider (#797)

Note: at the moment the context window is quite small, so it is mainly useful as a helper backend or HyDE (hypothetical document embeddings) generator.
Sam 2024-09-12 11:28:08 +10:00 committed by GitHub
parent 22d1e71dc9
commit 5b9add0ac8
8 changed files with 162 additions and 11 deletions
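For orientation, the new provider hooks into the existing LlmModel / endpoint plumbing and is driven exactly like the other providers. Below is a minimal sketch of the flow exercised by the spec at the end of this diff (test-only seeding via the :samba_nova_model fabricator added in this commit; a live site would configure the model through the admin UI instead):

llm_model = Fabricate(:samba_nova_model) # fabricator defined further down in this diff
llm = llm_model.to_llm

# Streaming: partials are yielded as they are decoded from the SSE response.
llm.generate("who are you?", user: Discourse.system_user) { |partial| print partial }

# Non-streaming: returns the full completion string.
llm.generate("who are you?", user: Discourse.system_user)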

@@ -12,6 +12,7 @@ class AiApiAuditLog < ActiveRecord::Base
Vllm = 5
Cohere = 6
Ollama = 7
SambaNova = 8
end
end

@@ -275,6 +275,7 @@ en:
azure: "Azure"
ollama: "Ollama"
CDCK: "CDCK"
samba_nova: "SambaNova"
provider_fields:
access_key_id: "AWS Bedrock Access key ID"

@@ -85,15 +85,14 @@ module DiscourseAi
encoded_uploads = prompt.encoded_uploads(message)
return content if encoded_uploads.blank?
- content_w_imgs =
-   encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
-     memo << {
-       type: "image_url",
-       image_url: {
-         url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
-       },
-     }
-   end
+ encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
+   memo << {
+     type: "image_url",
+     image_url: {
+       url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
+     },
+   }
+ end
end
end
end
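For reference, after this refactor the dialect hands OpenAI-compatible providers a mixed content array with the text part first and one image_url part per upload; the values below are illustrative only:

[
  { type: "text", text: "what is in this picture?" },
  {
    type: "image_url",
    image_url: {
      url: "data:image/png;base64,iVBORw0KGgo...", # truncated example payload
    },
  },
]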

@@ -17,6 +17,7 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::Vllm,
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::Cohere,
DiscourseAi::Completions::Endpoints::SambaNova,
]
endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
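Each endpoint class in this list answers can_contact? for its provider string, which is how the plugin picks the right backend for a configured model. The snippet below is only a hypothetical illustration of that lookup; the actual dispatch lives in the Base endpoint and is not part of this diff:

# Hypothetical sketch: scan the registry for the class that claims the provider.
endpoint_class = endpoints.detect { |klass| klass.can_contact?("samba_nova") }
endpoint_class # => DiscourseAi::Completions::Endpoints::SambaNova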

@@ -0,0 +1,84 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class SambaNova < Base
        def self.can_contact?(model_provider)
          model_provider == "samba_nova"
        end

        def normalize_model_params(model_params)
          model_params = model_params.dup

          # max_tokens, temperature are already supported
          if model_params[:stop_sequences]
            model_params[:stop] = model_params.delete(:stop_sequences)
          end

          model_params
        end

        def default_options
          { model: llm_model.name }
        end

        def provider_id
          AiApiAuditLog::Provider::SambaNova
        end

        private

        def model_uri
          URI(llm_model.url)
        end

        def prepare_payload(prompt, model_params, dialect)
          payload = default_options.merge(model_params).merge(messages: prompt)

          payload[:stream] = true if @streaming_mode

          payload
        end

        def prepare_request(payload)
          # Plain bearer-token auth against the configured chat completions URL.
          headers = { "Content-Type" => "application/json" }

          api_key = llm_model.api_key
          headers["Authorization"] = "Bearer #{api_key}"

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def final_log_update(log)
          log.request_tokens = @prompt_tokens if @prompt_tokens
          log.response_tokens = @completion_tokens if @completion_tokens
        end

        def extract_completion_from(response_raw)
          json = JSON.parse(response_raw, symbolize_names: true)

          if @streaming_mode
            # SambaNova reports usage on the final streamed chunk; capture it when present.
            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
            @completion_tokens ||= json.dig(:usage, :completion_tokens)
          end

          parsed = json.dig(:choices, 0)
          return if !parsed

          @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
        end

        def partials_from(decoded_chunk)
          # The stream is OpenAI-style SSE: one `data: <json>` payload per line,
          # terminated by `data: [DONE]`.
          decoded_chunk
            .split("\n")
            .map do |line|
              data = line.split("data: ", 2)[1]
              data == "[DONE]" ? nil : data
            end
            .compact
        end
      end
    end
  end
end
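To make the streaming path concrete, here is a small standalone sketch of what partials_from and the streaming branch of extract_completion_from do with a decoded SambaNova chunk; the chunk text is a trimmed-down, illustrative version of the fixture used in the spec below:

require "json"

# One decoded network chunk: OpenAI-style SSE lines plus the [DONE] sentinel.
chunk = <<~SSE
  data: {"choices": [{"delta": {"content": "I am"}}]}
  data: {"choices": [{"delta": {"content": " a bot"}}]}
  data: [DONE]
SSE

# partials_from: strip the "data: " prefix and drop the [DONE] sentinel.
payloads =
  chunk
    .split("\n")
    .map { |line| line.split("data: ", 2)[1] }
    .compact
    .reject { |data| data == "[DONE]" }

# extract_completion_from (streaming): dig the delta content out of each payload.
text =
  payloads
    .map { |raw| JSON.parse(raw, symbolize_names: true).dig(:choices, 0, :delta, :content) }
    .join

text # => "I am a bot"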

@@ -76,7 +76,17 @@ module DiscourseAi
end
def provider_names
- providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+ providers = %w[
+   aws_bedrock
+   anthropic
+   vllm
+   hugging_face
+   cohere
+   open_ai
+   google
+   azure
+   samba_nova
+ ]
if !Rails.env.production?
providers << "fake"
providers << "ollama"

@@ -71,3 +71,11 @@ Fabricator(:cohere_model, from: :llm_model) do
api_key "ABC"
url "https://api.cohere.ai/v1/chat"
end
Fabricator(:samba_nova_model, from: :llm_model) do
display_name "Samba Nova"
name "samba-nova"
provider "samba_nova"
api_key "ABC"
url "https://api.sambanova.ai/v1/chat/completions"
end

@@ -0,0 +1,47 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::Endpoints::SambaNova do
  fab!(:llm_model) { Fabricate(:samba_nova_model) }

  let(:llm) { llm_model.to_llm }

  it "can stream completions" do
    body = <<~PARTS
      data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [{"index": 0, "delta": {"content": "I am a bot"}, "logprobs": null, "finish_reason": null}]}
      data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [], "usage": {"is_last_response": true, "total_tokens": 62, "prompt_tokens": 21, "completion_tokens": 41, "time_to_first_token": 0.09152531623840332, "end_time": 1721336361.582011, "start_time": 1721336361.413994, "total_latency": 0.16801691055297852, "total_tokens_per_sec": 369.010475171488, "completion_tokens_per_sec": 244.02305616179046, "completion_tokens_after_first_per_sec": 522.9332759819093, "completion_tokens_after_first_per_sec_first_ten": 1016.0004844667837}}
      data: [DONE]
    PARTS

    stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
      body:
        "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}],\"stream\":true}",
      headers: {
        "Authorization" => "Bearer ABC",
        "Content-Type" => "application/json",
      },
    ).to_return(status: 200, body: body, headers: {})

    response = +""
    llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }

    expect(response).to eq("I am a bot")
  end

  it "can perform regular completions" do
    body = { choices: [message: { content: "I am a bot" }] }.to_json

    stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
      body:
        "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}]}",
      headers: {
        "Authorization" => "Bearer ABC",
        "Content-Type" => "application/json",
      },
    ).to_return(status: 200, body: body, headers: {})

    response = llm.generate("who are you?", user: Discourse.system_user)

    expect(response).to eq("I am a bot")
  end
end