From d1ab79e82ff723c3b73f4268143222ed0f930be8 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 21 Jun 2023 10:39:51 +1000 Subject: [PATCH] FEATURE: Add Azure cognitive service support (#93) The new site settings: ai_openai_gpt35_url : distribution for GPT 16k ai_openai_gpt4_url: distribution for GPT 4 ai_openai_embeddings_url: distribution for ada2 If untouched we will simply use OpenAI endpoints. Azure requires 1 URL per model, OpenAI allows a single URL to serve multiple models. Hence the new settings. --- config/locales/server.en.yml | 3 + config/settings.yml | 3 + lib/shared/inference/openai_completions.rb | 18 ++++-- lib/shared/inference/openai_embeddings.rb | 21 ++++--- .../inference/openai_completions_spec.rb | 51 ++++++++++++++++ .../inference/openai_embeddings_spec.rb | 59 +++++++++++++++++++ 6 files changed, 143 insertions(+), 12 deletions(-) create mode 100644 spec/shared/inference/openai_embeddings_spec.rb diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 5eebaf66..5315a258 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -31,6 +31,9 @@ en: ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW." ai_nsfw_models: "Models to use for NSFW inference." + ai_openai_gpt35_url: "Custom URL used for GPT 3.5 chat completions. (Azure: MUST support function calling and ideally is a GPT3.5 16K endpoint)" + ai_openai_gpt4_url: "Custom URL used for GPT 4 chat completions. (for Azure support)" + ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)" ai_openai_api_key: "API key for OpenAI API" ai_anthropic_api_key: "API key for Anthropic API" diff --git a/config/settings.yml b/config/settings.yml index 077b97ea..13eca12d 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -88,6 +88,9 @@ plugins: - opennsfw2 - nsfw_detector + ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions" + ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions" + ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings" ai_openai_api_key: default: "" secret: true diff --git a/lib/shared/inference/openai_completions.rb b/lib/shared/inference/openai_completions.rb index f0c1a2d1..cdfb826b 100644 --- a/lib/shared/inference/openai_completions.rb +++ b/lib/shared/inference/openai_completions.rb @@ -60,11 +60,19 @@ module ::DiscourseAi functions: nil, user_id: nil ) - url = URI("https://api.openai.com/v1/chat/completions") - headers = { - "Content-Type" => "application/json", - "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}", - } + url = + if model.include?("gpt-4") + URI(SiteSetting.ai_openai_gpt4_url) + else + URI(SiteSetting.ai_openai_gpt35_url) + end + headers = { "Content-Type" => "application/json" } + + if url.host.include? ("azure") + headers["api-key"] = SiteSetting.ai_openai_api_key + else + headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}" + end payload = { model: model, messages: messages } diff --git a/lib/shared/inference/openai_embeddings.rb b/lib/shared/inference/openai_embeddings.rb index 31856a5d..2cdacb1f 100644 --- a/lib/shared/inference/openai_embeddings.rb +++ b/lib/shared/inference/openai_embeddings.rb @@ -4,21 +4,28 @@ module ::DiscourseAi module Inference class OpenAiEmbeddings def self.perform!(content, model = nil) - headers = { - "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}", - "Content-Type" => "application/json", - } + headers = { "Content-Type" => "application/json" } + + if SiteSetting.ai_openai_embeddings_url.include? ("azure") + headers["api-key"] = SiteSetting.ai_openai_api_key + else + headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}" + end model ||= "text-embedding-ada-002" response = Faraday.post( - "https://api.openai.com/v1/embeddings", + SiteSetting.ai_openai_embeddings_url, { model: model, input: content }.to_json, headers, ) - - raise Net::HTTPBadResponse unless response.status == 200 + if response.status != 200 + Rails.logger.warn( + "OpenAI Embeddings failed with status: #{response.status} body: #{response.body}", + ) + raise Net::HTTPBadResponse + end JSON.parse(response.body, symbolize_names: true) end diff --git a/spec/shared/inference/openai_completions_spec.rb b/spec/shared/inference/openai_completions_spec.rb index 6753c70a..dca271ba 100644 --- a/spec/shared/inference/openai_completions_spec.rb +++ b/spec/shared/inference/openai_completions_spec.rb @@ -6,6 +6,57 @@ require_relative "../../support/openai_completions_inference_stubs" describe DiscourseAi::Inference::OpenAiCompletions do before { SiteSetting.ai_openai_api_key = "abc-123" } + context "when configured using Azure" do + it "Supports GPT 3.5 completions" do + SiteSetting.ai_openai_api_key = "12345" + SiteSetting.ai_openai_gpt35_url = + "https://company.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2023-03-15-preview" + + expected = { + id: "chatcmpl-7TfPzOyBGW5K6dyWp3NPU0mYLGZRQ", + object: "chat.completion", + created: 1_687_305_079, + model: "gpt-35-turbo", + choices: [ + { + index: 0, + finish_reason: "stop", + message: { + role: "assistant", + content: "Hi there! How can I assist you today?", + }, + }, + ], + usage: { + completion_tokens: 10, + prompt_tokens: 9, + total_tokens: 19, + }, + } + + stub_request( + :post, + "https://company.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2023-03-15-preview", + ).with( + body: + "{\"model\":\"gpt-3.5-turbo-0613\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}", + headers: { + "Api-Key" => "12345", + "Content-Type" => "application/json", + "Host" => "company.openai.azure.com", + }, + ).to_return(status: 200, body: expected.to_json, headers: {}) + + result = + DiscourseAi::Inference::OpenAiCompletions.perform!( + [role: "user", content: "hello"], + "gpt-3.5-turbo-0613", + ) + + expect(result).to eq(expected) + end + end + it "supports function calling" do prompt = [role: "system", content: "you are weatherbot"] prompt << { role: "user", content: "what is the weather in sydney?" } diff --git a/spec/shared/inference/openai_embeddings_spec.rb b/spec/shared/inference/openai_embeddings_spec.rb new file mode 100644 index 00000000..2b922d1f --- /dev/null +++ b/spec/shared/inference/openai_embeddings_spec.rb @@ -0,0 +1,59 @@ +# frozen_string_literal: true +require "rails_helper" + +describe DiscourseAi::Inference::OpenAiEmbeddings do + it "supports azure embeddings" do + SiteSetting.ai_openai_embeddings_url = + "https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15" + SiteSetting.ai_openai_api_key = "123456" + + body_json = { + usage: { + prompt_tokens: 1, + total_tokens: 1, + }, + data: [{ object: "embedding", embedding: [0.0, 0.1] }], + }.to_json + + stub_request( + :post, + "https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15", + ).with( + body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}", + headers: { + "Api-Key" => "123456", + "Content-Type" => "application/json", + }, + ).to_return(status: 200, body: body_json, headers: {}) + + result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello") + + expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 }) + expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] }) + end + + it "supports openai embeddings" do + SiteSetting.ai_openai_api_key = "123456" + + body_json = { + usage: { + prompt_tokens: 1, + total_tokens: 1, + }, + data: [{ object: "embedding", embedding: [0.0, 0.1] }], + }.to_json + + stub_request(:post, "https://api.openai.com/v1/embeddings").with( + body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}", + headers: { + "Authorization" => "Bearer 123456", + "Content-Type" => "application/json", + }, + ).to_return(status: 200, body: body_json, headers: {}) + + result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello") + + expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 }) + expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] }) + end +end