FEATURE: add a SambaNova LLM provider (#797)
Note, at the moment the context window is quite small, it is mainly useful as a helper backend or hyde generator
This commit is contained in:
parent
22d1e71dc9
commit
5b9add0ac8
|
@ -12,6 +12,7 @@ class AiApiAuditLog < ActiveRecord::Base
|
|||
Vllm = 5
|
||||
Cohere = 6
|
||||
Ollama = 7
|
||||
SambaNova = 8
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -243,7 +243,7 @@ en:
|
|||
confirm_delete: Are you sure you want to delete this model?
|
||||
delete: Delete
|
||||
seeded_warning: "This model is pre-configured on your site and cannot be edited."
|
||||
in_use_warning:
|
||||
in_use_warning:
|
||||
one: "This model is currently used by the %{settings} setting. If misconfigured, the feature won't work as expected."
|
||||
other: "This model is currently used by the following settings: %{settings}. If misconfigured, features won't work as expected. "
|
||||
|
||||
|
@ -275,6 +275,7 @@ en:
|
|||
azure: "Azure"
|
||||
ollama: "Ollama"
|
||||
CDCK: "CDCK"
|
||||
samba_nova: "SambaNova"
|
||||
|
||||
provider_fields:
|
||||
access_key_id: "AWS Bedrock Access key ID"
|
||||
|
|
|
@ -85,15 +85,14 @@ module DiscourseAi
|
|||
encoded_uploads = prompt.encoded_uploads(message)
|
||||
return content if encoded_uploads.blank?
|
||||
|
||||
content_w_imgs =
|
||||
encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
|
||||
memo << {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
|
||||
},
|
||||
}
|
||||
end
|
||||
encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details|
|
||||
memo << {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: "data:#{details[:mime_type]};base64,#{details[:base64]}",
|
||||
},
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -17,6 +17,7 @@ module DiscourseAi
|
|||
DiscourseAi::Completions::Endpoints::Vllm,
|
||||
DiscourseAi::Completions::Endpoints::Anthropic,
|
||||
DiscourseAi::Completions::Endpoints::Cohere,
|
||||
DiscourseAi::Completions::Endpoints::SambaNova,
|
||||
]
|
||||
|
||||
endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Endpoints
|
||||
class SambaNova < Base
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "samba_nova"
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
model_params = model_params.dup
|
||||
|
||||
# max_tokens, temperature are already supported
|
||||
if model_params[:stop_sequences]
|
||||
model_params[:stop] = model_params.delete(:stop_sequences)
|
||||
end
|
||||
|
||||
model_params
|
||||
end
|
||||
|
||||
def default_options
|
||||
{ model: llm_model.name }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
AiApiAuditLog::Provider::SambaNova
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def model_uri
|
||||
URI(llm_model.url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
payload = default_options.merge(model_params).merge(messages: prompt)
|
||||
|
||||
payload[:stream] = true if @streaming_mode
|
||||
|
||||
payload
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
api_key = llm_model.api_key
|
||||
|
||||
headers["Authorization"] = "Bearer #{api_key}"
|
||||
|
||||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @prompt_tokens if @prompt_tokens
|
||||
log.response_tokens = @completion_tokens if @completion_tokens
|
||||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
json = JSON.parse(response_raw, symbolize_names: true)
|
||||
|
||||
if @streaming_mode
|
||||
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
|
||||
@completion_tokens ||= json.dig(:usage, :completion_tokens)
|
||||
end
|
||||
|
||||
parsed = json.dig(:choices, 0)
|
||||
return if !parsed
|
||||
|
||||
@streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
decoded_chunk
|
||||
.split("\n")
|
||||
.map do |line|
|
||||
data = line.split("data: ", 2)[1]
|
||||
data == "[DONE]" ? nil : data
|
||||
end
|
||||
.compact
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -76,7 +76,17 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def provider_names
|
||||
providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
|
||||
providers = %w[
|
||||
aws_bedrock
|
||||
anthropic
|
||||
vllm
|
||||
hugging_face
|
||||
cohere
|
||||
open_ai
|
||||
google
|
||||
azure
|
||||
samba_nova
|
||||
]
|
||||
if !Rails.env.production?
|
||||
providers << "fake"
|
||||
providers << "ollama"
|
||||
|
|
|
@ -71,3 +71,11 @@ Fabricator(:cohere_model, from: :llm_model) do
|
|||
api_key "ABC"
|
||||
url "https://api.cohere.ai/v1/chat"
|
||||
end
|
||||
|
||||
Fabricator(:samba_nova_model, from: :llm_model) do
|
||||
display_name "Samba Nova"
|
||||
name "samba-nova"
|
||||
provider "samba_nova"
|
||||
api_key "ABC"
|
||||
url "https://api.sambanova.ai/v1/chat/completions"
|
||||
end
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::SambaNova do
|
||||
fab!(:llm_model) { Fabricate(:samba_nova_model) }
|
||||
let(:llm) { llm_model.to_llm }
|
||||
|
||||
it "can stream completions" do
|
||||
body = <<~PARTS
|
||||
data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [{"index": 0, "delta": {"content": "I am a bot"}, "logprobs": null, "finish_reason": null}]}
|
||||
|
||||
data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [], "usage": {"is_last_response": true, "total_tokens": 62, "prompt_tokens": 21, "completion_tokens": 41, "time_to_first_token": 0.09152531623840332, "end_time": 1721336361.582011, "start_time": 1721336361.413994, "total_latency": 0.16801691055297852, "total_tokens_per_sec": 369.010475171488, "completion_tokens_per_sec": 244.02305616179046, "completion_tokens_after_first_per_sec": 522.9332759819093, "completion_tokens_after_first_per_sec_first_ten": 1016.0004844667837}}
|
||||
|
||||
data: [DONE]
|
||||
PARTS
|
||||
|
||||
stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
|
||||
body:
|
||||
"{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}],\"stream\":true}",
|
||||
headers: {
|
||||
"Authorization" => "Bearer ABC",
|
||||
"Content-Type" => "application/json",
|
||||
},
|
||||
).to_return(status: 200, body: body, headers: {})
|
||||
|
||||
response = +""
|
||||
llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }
|
||||
|
||||
expect(response).to eq("I am a bot")
|
||||
end
|
||||
|
||||
it "can perform regular completions" do
|
||||
body = { choices: [message: { content: "I am a bot" }] }.to_json
|
||||
|
||||
stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with(
|
||||
body:
|
||||
"{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}]}",
|
||||
headers: {
|
||||
"Authorization" => "Bearer ABC",
|
||||
"Content-Type" => "application/json",
|
||||
},
|
||||
).to_return(status: 200, body: body, headers: {})
|
||||
|
||||
response = llm.generate("who are you?", user: Discourse.system_user)
|
||||
|
||||
expect(response).to eq("I am a bot")
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue