FEATURE: add a SambaNova LLM provider (#797)

Note: at the moment the context window is quite small, so this provider is mainly useful as a helper backend or a HyDE generator.
parent 22d1e71dc9
commit 5b9add0ac8
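For context on the note above: HyDE (Hypothetical Document Embeddings) retrieval has a model write a short hypothetical answer to a query, then embeds that answer and searches with the resulting vector instead of the raw query. A minimal sketch of the idea, not code from this commit; the helper names (hyde_search, search_by_vector, vector_from) are hypothetical, while the generate(prompt, user:) call shape matches the spec at the bottom of this diff:

# Hypothetical sketch of HyDE-style retrieval; hyde_llm and embeddings
# stand in for whatever clients the application provides.
def hyde_search(hyde_llm, embeddings, query)
  # A small context window is enough here: the prompt and the short
  # hypothetical answer both fit comfortably.
  hypothetical = hyde_llm.generate(<<~PROMPT, user: Discourse.system_user)
    Write a short passage that could answer: #{query}
  PROMPT

  # Embed the hypothetical answer and search with that vector
  # instead of the raw query.
  embeddings.search_by_vector(embeddings.vector_from(hypothetical))
end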
@@ -12,6 +12,7 @@ class AiApiAuditLog < ActiveRecord::Base
     Vllm = 5
     Cohere = 6
     Ollama = 7
+    SambaNova = 8
   end
 end
 
@@ -243,7 +243,7 @@ en:
       confirm_delete: Are you sure you want to delete this model?
       delete: Delete
       seeded_warning: "This model is pre-configured on your site and cannot be edited."
       in_use_warning:
         one: "This model is currently used by the %{settings} setting. If misconfigured, the feature won't work as expected."
         other: "This model is currently used by the following settings: %{settings}. If misconfigured, features won't work as expected. "
 
@@ -275,6 +275,7 @@ en:
           azure: "Azure"
           ollama: "Ollama"
           CDCK: "CDCK"
+          samba_nova: "SambaNova"
 
         provider_fields:
           access_key_id: "AWS Bedrock Access key ID"
@@ -85,15 +85,14 @@ module DiscourseAi
         encoded_uploads = prompt.encoded_uploads(message)
         return content if encoded_uploads.blank?
 
-        content_w_imgs =
         encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
           memo << {
             type: "image_url",
             image_url: {
               url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
             },
           }
         end
       end
     end
   end
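For reference, the reduce above (now returned directly rather than assigned to an intermediate variable) turns a message with attached uploads into the OpenAI-style mixed-content array. Given one image, the result looks roughly like this; the text and base64 payload are made-up placeholder values:

# Illustrative result for a message with one attached PNG upload.
[
  { type: "text", text: "what is in this image?" },
  {
    type: "image_url",
    image_url: {
      url: "data:image/png;base64,iVBORw0KGgo...",
    },
  },
]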
@@ -17,6 +17,7 @@ module DiscourseAi
           DiscourseAi::Completions::Endpoints::Vllm,
           DiscourseAi::Completions::Endpoints::Anthropic,
           DiscourseAi::Completions::Endpoints::Cohere,
+          DiscourseAi::Completions::Endpoints::SambaNova,
         ]
 
         endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
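Each endpoint class implements can_contact? (see the new SambaNova class below), and this list is scanned to find the class that handles a given provider name. A minimal sketch of that dispatch pattern; the endpoint_for method name is an assumption for illustration, not necessarily what the plugin calls it:

# Illustrative lookup over the registered endpoint classes.
def endpoint_for(provider_name)
  endpoints.find { |klass| klass.can_contact?(provider_name) }
end

endpoint_for("samba_nova") # => DiscourseAi::Completions::Endpoints::SambaNova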
@@ -0,0 +1,84 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Endpoints
+      class SambaNova < Base
+        def self.can_contact?(model_provider)
+          model_provider == "samba_nova"
+        end
+
+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # max_tokens, temperature are already supported
+          if model_params[:stop_sequences]
+            model_params[:stop] = model_params.delete(:stop_sequences)
+          end
+
+          model_params
+        end
+
+        def default_options
+          { model: llm_model.name }
+        end
+
+        def provider_id
+          AiApiAuditLog::Provider::SambaNova
+        end
+
+        private
+
+        def model_uri
+          URI(llm_model.url)
+        end
+
+        def prepare_payload(prompt, model_params, dialect)
+          payload = default_options.merge(model_params).merge(messages: prompt)
+
+          payload[:stream] = true if @streaming_mode
+
+          payload
+        end
+
+        def prepare_request(payload)
+          headers = { "Content-Type" => "application/json" }
+          api_key = llm_model.api_key
+
+          headers["Authorization"] = "Bearer #{api_key}"
+
+          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
+        end
+
+        def final_log_update(log)
+          log.request_tokens = @prompt_tokens if @prompt_tokens
+          log.response_tokens = @completion_tokens if @completion_tokens
+        end
+
+        def extract_completion_from(response_raw)
+          json = JSON.parse(response_raw, symbolize_names: true)
+
+          if @streaming_mode
+            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+            @completion_tokens ||= json.dig(:usage, :completion_tokens)
+          end
+
+          parsed = json.dig(:choices, 0)
+          return if !parsed
+
+          @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
+        end
+
+        def partials_from(decoded_chunk)
+          decoded_chunk
+            .split("\n")
+            .map do |line|
+              data = line.split("data: ", 2)[1]
+              data == "[DONE]" ? nil : data
+            end
+            .compact
+        end
+      end
+    end
+  end
+end
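The partials_from method above splits a server-sent-events chunk into individual JSON payloads, dropping blank lines and the terminal [DONE] marker; each surviving payload is then handed to extract_completion_from, which digs out choices[0].delta.content when streaming. A quick illustration of that split, with payloads shaped like (but much shorter than) the ones in the spec below:

chunk = <<~SSE
  data: {"choices": [{"delta": {"content": "I am"}}]}
  data: {"choices": [{"delta": {"content": " a bot"}}]}
  data: [DONE]
SSE

chunk
  .split("\n")
  .map { |line| (data = line.split("data: ", 2)[1]) == "[DONE]" ? nil : data }
  .compact
# => the two JSON payload strings; "[DONE]" is dropped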
@@ -76,7 +76,17 @@ module DiscourseAi
     end
 
     def provider_names
-      providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+      providers = %w[
+        aws_bedrock
+        anthropic
+        vllm
+        hugging_face
+        cohere
+        open_ai
+        google
+        azure
+        samba_nova
+      ]
       if !Rails.env.production?
         providers << "fake"
         providers << "ollama"
@@ -71,3 +71,11 @@ Fabricator(:cohere_model, from: :llm_model) do
   api_key "ABC"
   url "https://api.cohere.ai/v1/chat"
 end
+
+Fabricator(:samba_nova_model, from: :llm_model) do
+  display_name "Samba Nova"
+  name "samba-nova"
+  provider "samba_nova"
+  api_key "ABC"
+  url "https://api.sambanova.ai/v1/chat/completions"
+end
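For reference, this fabricator is how the spec below obtains a configured model; the to_llm call comes straight from the new spec's let(:llm):

llm_model = Fabricate(:samba_nova_model)
llm = llm_model.to_llm # an LLM client wired to the samba_nova endpoint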
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::Endpoints::SambaNova do
+  fab!(:llm_model) { Fabricate(:samba_nova_model) }
+  let(:llm) { llm_model.to_llm }
+
+  it "can stream completions" do
+    body = <<~PARTS
+      data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [{"index": 0, "delta": {"content": "I am a bot"}, "logprobs": null, "finish_reason": null}]}
+
+      data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [], "usage": {"is_last_response": true, "total_tokens": 62, "prompt_tokens": 21, "completion_tokens": 41, "time_to_first_token": 0.09152531623840332, "end_time": 1721336361.582011, "start_time": 1721336361.413994, "total_latency": 0.16801691055297852, "total_tokens_per_sec": 369.010475171488, "completion_tokens_per_sec": 244.02305616179046, "completion_tokens_after_first_per_sec": 522.9332759819093, "completion_tokens_after_first_per_sec_first_ten": 1016.0004844667837}}
+
+      data: [DONE]
+    PARTS
+
+    stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
+      body:
+        "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}],\"stream\":true}",
+      headers: {
+        "Authorization" => "Bearer ABC",
+        "Content-Type" => "application/json",
+      },
+    ).to_return(status: 200, body: body, headers: {})
+
+    response = +""
+    llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }
+
+    expect(response).to eq("I am a bot")
+  end
+
+  it "can perform regular completions" do
+    body = { choices: [message: { content: "I am a bot" }] }.to_json
+
+    stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
+      body:
+        "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}]}",
+      headers: {
+        "Authorization" => "Bearer ABC",
+        "Content-Type" => "application/json",
+      },
+    ).to_return(status: 200, body: body, headers: {})
+
+    response = llm.generate("who are you?", user: Discourse.system_user)
+
+    expect(response).to eq("I am a bot")
+  end
+end