FEATURE: Add Azure cognitive service support (#93)
The new site settings:

ai_openai_gpt35_url: Azure deployment URL for the GPT-3.5 16k model
ai_openai_gpt4_url: Azure deployment URL for GPT-4
ai_openai_embeddings_url: Azure deployment URL for the ada-002 embeddings model

If left untouched, we simply use the OpenAI endpoints. Azure requires one URL per model, while OpenAI allows a single URL to serve multiple models; hence the new settings.
parent 30778d8af8
commit d1ab79e82f
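For illustration, an Azure-backed site points each setting at its own deployment (the resource and deployment names below are hypothetical; the api-versions match the ones used elsewhere in this commit). Leaving the settings at their defaults keeps all traffic on the stock OpenAI endpoints.

# Hypothetical Azure OpenAI deployments -- Azure needs one URL per model
SiteSetting.ai_openai_gpt35_url =
  "https://COMPANY.openai.azure.com/openai/deployments/gpt35-16k/chat/completions?api-version=2023-03-15-preview"
SiteSetting.ai_openai_gpt4_url =
  "https://COMPANY.openai.azure.com/openai/deployments/gpt4/chat/completions?api-version=2023-03-15-preview"
SiteSetting.ai_openai_embeddings_url =
  "https://COMPANY.openai.azure.com/openai/deployments/ada2/embeddings?api-version=2023-05-15"
SiteSetting.ai_openai_api_key = "AZURE_OPENAI_KEY" # the existing key setting is reused for Azure
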
@@ -31,6 +31,9 @@ en:
     ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW."
     ai_nsfw_models: "Models to use for NSFW inference."
+    ai_openai_gpt35_url: "Custom URL used for GPT 3.5 chat completions. (Azure: MUST support function calling and ideally is a GPT3.5 16K endpoint)"
+    ai_openai_gpt4_url: "Custom URL used for GPT 4 chat completions. (for Azure support)"
+    ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
     ai_openai_api_key: "API key for OpenAI API"
     ai_anthropic_api_key: "API key for Anthropic API"

@@ -88,6 +88,9 @@ plugins:
       - opennsfw2
       - nsfw_detector
+  ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
+  ai_openai_gpt4_url: "https://api.openai.com/v1/chat/completions"
+  ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings"
   ai_openai_api_key:
     default: ""
     secret: true

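Left at these defaults, the plugin keeps talking to api.openai.com exactly as before; a quick Rails-console sanity check (illustrative):

SiteSetting.ai_openai_gpt35_url      # => "https://api.openai.com/v1/chat/completions"
SiteSetting.ai_openai_gpt4_url       # => "https://api.openai.com/v1/chat/completions"
SiteSetting.ai_openai_embeddings_url # => "https://api.openai.com/v1/embeddings"
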
@@ -60,11 +60,19 @@ module ::DiscourseAi
         functions: nil,
         user_id: nil
       )
-        url = URI("https://api.openai.com/v1/chat/completions")
-        headers = {
-          "Content-Type" => "application/json",
-          "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
-        }
+        url =
+          if model.include?("gpt-4")
+            URI(SiteSetting.ai_openai_gpt4_url)
+          else
+            URI(SiteSetting.ai_openai_gpt35_url)
+          end
+        headers = { "Content-Type" => "application/json" }
+
+        if url.host.include? ("azure")
+          headers["api-key"] = SiteSetting.ai_openai_api_key
+        else
+          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+        end
 
         payload = { model: model, messages: messages }

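The branch taken above depends only on the host of the configured URL and on whether the model name contains "gpt-4"; there is no separate "use Azure" toggle. A minimal sketch of the detection (URLs hypothetical):

URI("https://company.openai.azure.com/openai/deployments/gpt4/chat/completions").host.include?("azure")
# => true  -- the request is authenticated with the Azure-style "api-key" header
URI("https://api.openai.com/v1/chat/completions").host.include?("azure")
# => false -- the request keeps the OpenAI-style "Authorization: Bearer <key>" header
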
@@ -4,21 +4,28 @@ module ::DiscourseAi
   module Inference
     class OpenAiEmbeddings
       def self.perform!(content, model = nil)
-        headers = {
-          "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
-          "Content-Type" => "application/json",
-        }
+        headers = { "Content-Type" => "application/json" }
+
+        if SiteSetting.ai_openai_embeddings_url.include? ("azure")
+          headers["api-key"] = SiteSetting.ai_openai_api_key
+        else
+          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+        end
 
         model ||= "text-embedding-ada-002"
 
         response =
           Faraday.post(
-            "https://api.openai.com/v1/embeddings",
+            SiteSetting.ai_openai_embeddings_url,
             { model: model, input: content }.to_json,
             headers,
           )
 
-        raise Net::HTTPBadResponse unless response.status == 200
+        if response.status != 200
+          Rails.logger.warn(
+            "OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
+          )
+          raise Net::HTTPBadResponse
+        end
 
         JSON.parse(response.body, symbolize_names: true)
       end

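From the caller's point of view nothing changes: perform! still returns the parsed, symbolized API response, it just hits whatever endpoint ai_openai_embeddings_url points at and logs the response body before raising on a non-200. A usage sketch (the embedding values are invented):

result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello world")
result[:data].first[:embedding] # => [0.0023, -0.0091, ...] (1536 floats for text-embedding-ada-002)
result[:usage][:total_tokens]   # => token count reported by the API
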
@@ -6,6 +6,57 @@ require_relative "../../support/openai_completions_inference_stubs"
 describe DiscourseAi::Inference::OpenAiCompletions do
   before { SiteSetting.ai_openai_api_key = "abc-123" }
 
+  context "when configured using Azure" do
+    it "Supports GPT 3.5 completions" do
+      SiteSetting.ai_openai_api_key = "12345"
+      SiteSetting.ai_openai_gpt35_url =
+        "https://company.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2023-03-15-preview"
+
+      expected = {
+        id: "chatcmpl-7TfPzOyBGW5K6dyWp3NPU0mYLGZRQ",
+        object: "chat.completion",
+        created: 1_687_305_079,
+        model: "gpt-35-turbo",
+        choices: [
+          {
+            index: 0,
+            finish_reason: "stop",
+            message: {
+              role: "assistant",
+              content: "Hi there! How can I assist you today?",
+            },
+          },
+        ],
+        usage: {
+          completion_tokens: 10,
+          prompt_tokens: 9,
+          total_tokens: 19,
+        },
+      }
+
+      stub_request(
+        :post,
+        "https://company.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2023-03-15-preview",
+      ).with(
+        body:
+          "{\"model\":\"gpt-3.5-turbo-0613\",\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}",
+        headers: {
+          "Api-Key" => "12345",
+          "Content-Type" => "application/json",
+          "Host" => "company.openai.azure.com",
+        },
+      ).to_return(status: 200, body: expected.to_json, headers: {})
+
+      result =
+        DiscourseAi::Inference::OpenAiCompletions.perform!(
+          [role: "user", content: "hello"],
+          "gpt-3.5-turbo-0613",
+        )
+
+      expect(result).to eq(expected)
+    end
+  end
+
   it "supports function calling" do
     prompt = [role: "system", content: "you are weatherbot"]
     prompt << { role: "user", content: "what is the weather in sydney?" }

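A GPT-4 completion would be routed through ai_openai_gpt4_url by the same model check; a hypothetical spec-style sketch, not part of this commit:

SiteSetting.ai_openai_gpt4_url =
  "https://company.openai.azure.com/openai/deployments/gpt4/chat/completions?api-version=2023-03-15-preview"

stub_request(:post, SiteSetting.ai_openai_gpt4_url).with(
  headers: { "Api-Key" => SiteSetting.ai_openai_api_key },
).to_return(status: 200, body: { choices: [] }.to_json)

DiscourseAi::Inference::OpenAiCompletions.perform!([role: "user", content: "hello"], "gpt-4")
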
@@ -0,0 +1,59 @@
+# frozen_string_literal: true
+require "rails_helper"
+
+describe DiscourseAi::Inference::OpenAiEmbeddings do
+  it "supports azure embeddings" do
+    SiteSetting.ai_openai_embeddings_url =
+      "https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15"
+    SiteSetting.ai_openai_api_key = "123456"
+
+    body_json = {
+      usage: {
+        prompt_tokens: 1,
+        total_tokens: 1,
+      },
+      data: [{ object: "embedding", embedding: [0.0, 0.1] }],
+    }.to_json
+
+    stub_request(
+      :post,
+      "https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15",
+    ).with(
+      body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
+      headers: {
+        "Api-Key" => "123456",
+        "Content-Type" => "application/json",
+      },
+    ).to_return(status: 200, body: body_json, headers: {})
+
+    result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello")
+
+    expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
+    expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
+  end
+
+  it "supports openai embeddings" do
+    SiteSetting.ai_openai_api_key = "123456"
+
+    body_json = {
+      usage: {
+        prompt_tokens: 1,
+        total_tokens: 1,
+      },
+      data: [{ object: "embedding", embedding: [0.0, 0.1] }],
+    }.to_json
+
+    stub_request(:post, "https://api.openai.com/v1/embeddings").with(
+      body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
+      headers: {
+        "Authorization" => "Bearer 123456",
+        "Content-Type" => "application/json",
+      },
+    ).to_return(status: 200, body: body_json, headers: {})
+
+    result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello")
+
+    expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
+    expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
+  end
+end