From 5b9add0ac8f0c116af449109c1e4299bfc9bfd4d Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 12 Sep 2024 11:28:08 +1000 Subject: [PATCH] FEATURE: add a SambaNova LLM provider (#797) Note, at the moment the context window is quite small, it is mainly useful as a helper backend or hyde generator --- app/models/ai_api_audit_log.rb | 1 + config/locales/client.en.yml | 3 +- .../dialects/open_ai_compatible.rb | 17 ++-- lib/completions/endpoints/base.rb | 1 + lib/completions/endpoints/samba_nova.rb | 84 +++++++++++++++++++ lib/completions/llm.rb | 12 ++- spec/fabricators/llm_model_fabricator.rb | 8 ++ .../completions/endpoints/samba_nova_spec.rb | 47 +++++++++++ 8 files changed, 162 insertions(+), 11 deletions(-) create mode 100644 lib/completions/endpoints/samba_nova.rb create mode 100644 spec/lib/completions/endpoints/samba_nova_spec.rb diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb index 024ba44b..c38d68eb 100644 --- a/app/models/ai_api_audit_log.rb +++ b/app/models/ai_api_audit_log.rb @@ -12,6 +12,7 @@ class AiApiAuditLog < ActiveRecord::Base Vllm = 5 Cohere = 6 Ollama = 7 + SambaNova = 8 end end diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 139e3627..342d2801 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -243,7 +243,7 @@ en: confirm_delete: Are you sure you want to delete this model? delete: Delete seeded_warning: "This model is pre-configured on your site and cannot be edited." - in_use_warning: + in_use_warning: one: "This model is currently used by the %{settings} setting. If misconfigured, the feature won't work as expected." other: "This model is currently used by the following settings: %{settings}. If misconfigured, features won't work as expected. " @@ -275,6 +275,7 @@ en: azure: "Azure" ollama: "Ollama" CDCK: "CDCK" + samba_nova: "SambaNova" provider_fields: access_key_id: "AWS Bedrock Access key ID" diff --git a/lib/completions/dialects/open_ai_compatible.rb b/lib/completions/dialects/open_ai_compatible.rb index 5a679b71..d22d9708 100644 --- a/lib/completions/dialects/open_ai_compatible.rb +++ b/lib/completions/dialects/open_ai_compatible.rb @@ -85,15 +85,14 @@ module DiscourseAi encoded_uploads = prompt.encoded_uploads(message) return content if encoded_uploads.blank? - content_w_imgs = - encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details| - memo << { - type: "image_url", - image_url: { - url: "data:#{details[:mime_type]};base64,#{details[:base64]}", - }, - } - end + encoded_uploads.reduce([{ type: "text", text: message[:content] }]) do |memo, details| + memo << { + type: "image_url", + image_url: { + url: "data:#{details[:mime_type]};base64,#{details[:base64]}", + }, + } + end end end end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 1bcca512..3782d735 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -17,6 +17,7 @@ module DiscourseAi DiscourseAi::Completions::Endpoints::Vllm, DiscourseAi::Completions::Endpoints::Anthropic, DiscourseAi::Completions::Endpoints::Cohere, + DiscourseAi::Completions::Endpoints::SambaNova, ] endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development? diff --git a/lib/completions/endpoints/samba_nova.rb b/lib/completions/endpoints/samba_nova.rb new file mode 100644 index 00000000..ccb883cc --- /dev/null +++ b/lib/completions/endpoints/samba_nova.rb @@ -0,0 +1,84 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Endpoints + class SambaNova < Base + def self.can_contact?(model_provider) + model_provider == "samba_nova" + end + + def normalize_model_params(model_params) + model_params = model_params.dup + + # max_tokens, temperature are already supported + if model_params[:stop_sequences] + model_params[:stop] = model_params.delete(:stop_sequences) + end + + model_params + end + + def default_options + { model: llm_model.name } + end + + def provider_id + AiApiAuditLog::Provider::SambaNova + end + + private + + def model_uri + URI(llm_model.url) + end + + def prepare_payload(prompt, model_params, dialect) + payload = default_options.merge(model_params).merge(messages: prompt) + + payload[:stream] = true if @streaming_mode + + payload + end + + def prepare_request(payload) + headers = { "Content-Type" => "application/json" } + api_key = llm_model.api_key + + headers["Authorization"] = "Bearer #{api_key}" + + Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } + end + + def final_log_update(log) + log.request_tokens = @prompt_tokens if @prompt_tokens + log.response_tokens = @completion_tokens if @completion_tokens + end + + def extract_completion_from(response_raw) + json = JSON.parse(response_raw, symbolize_names: true) + + if @streaming_mode + @prompt_tokens ||= json.dig(:usage, :prompt_tokens) + @completion_tokens ||= json.dig(:usage, :completion_tokens) + end + + parsed = json.dig(:choices, 0) + return if !parsed + + @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content) + end + + def partials_from(decoded_chunk) + decoded_chunk + .split("\n") + .map do |line| + data = line.split("data: ", 2)[1] + data == "[DONE]" ? nil : data + end + .compact + end + end + end + end +end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 414b1caa..527ec87f 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -76,7 +76,17 @@ module DiscourseAi end def provider_names - providers = %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure] + providers = %w[ + aws_bedrock + anthropic + vllm + hugging_face + cohere + open_ai + google + azure + samba_nova + ] if !Rails.env.production? providers << "fake" providers << "ollama" diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb index 5dd4adcb..bea2d72c 100644 --- a/spec/fabricators/llm_model_fabricator.rb +++ b/spec/fabricators/llm_model_fabricator.rb @@ -71,3 +71,11 @@ Fabricator(:cohere_model, from: :llm_model) do api_key "ABC" url "https://api.cohere.ai/v1/chat" end + +Fabricator(:samba_nova_model, from: :llm_model) do + display_name "Samba Nova" + name "samba-nova" + provider "samba_nova" + api_key "ABC" + url "https://api.sambanova.ai/v1/chat/completions" +end diff --git a/spec/lib/completions/endpoints/samba_nova_spec.rb b/spec/lib/completions/endpoints/samba_nova_spec.rb new file mode 100644 index 00000000..0f1f68ac --- /dev/null +++ b/spec/lib/completions/endpoints/samba_nova_spec.rb @@ -0,0 +1,47 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Completions::Endpoints::SambaNova do + fab!(:llm_model) { Fabricate(:samba_nova_model) } + let(:llm) { llm_model.to_llm } + + it "can stream completions" do + body = <<~PARTS + data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [{"index": 0, "delta": {"content": "I am a bot"}, "logprobs": null, "finish_reason": null}]} + +data: {"id": "4c5e4a44-e847-467d-b9cd-d2f6530678cd", "object": "chat.completion.chunk", "created": 1721336361, "model": "llama3-8b", "system_fingerprint": "fastcoe", "choices": [], "usage": {"is_last_response": true, "total_tokens": 62, "prompt_tokens": 21, "completion_tokens": 41, "time_to_first_token": 0.09152531623840332, "end_time": 1721336361.582011, "start_time": 1721336361.413994, "total_latency": 0.16801691055297852, "total_tokens_per_sec": 369.010475171488, "completion_tokens_per_sec": 244.02305616179046, "completion_tokens_after_first_per_sec": 522.9332759819093, "completion_tokens_after_first_per_sec_first_ten": 1016.0004844667837}} + +data: [DONE] + PARTS + + stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with( + body: + "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}],\"stream\":true}", + headers: { + "Authorization" => "Bearer ABC", + "Content-Type" => "application/json", + }, + ).to_return(status: 200, body: body, headers: {}) + + response = +"" + llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial } + + expect(response).to eq("I am a bot") + end + + it "can perform regular completions" do + body = { choices: [message: { content: "I am a bot" }] }.to_json + + stub_request(:post, "https://api.sambanova.ai/v1/chat/completions").with( + body: + "{\"model\":\"samba-nova\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful bot\"},{\"role\":\"user\",\"content\":\"who are you?\"}]}", + headers: { + "Authorization" => "Bearer ABC", + "Content-Type" => "application/json", + }, + ).to_return(status: 200, body: body, headers: {}) + + response = llm.generate("who are you?", user: Discourse.system_user) + + expect(response).to eq("I am a bot") + end +end