FEATURE: add internal support for streaming mode (#42)

Also adds some tests around completions and supports additional params
such as top_p, temperature and max_tokens

This also migrates off Faraday to using Net::HTTP directly
This commit is contained in:
Sam 2023-04-21 16:54:25 +10:00 committed by GitHub
parent 14b21b4f4d
commit 057fbe1ce6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 171 additions and 23 deletions

View File

@@ -3,31 +3,88 @@
module ::DiscourseAi
module Inference
class OpenAiCompletions
TIMEOUT = 60
CompletionFailed = Class.new(StandardError)
def self.perform!(messages, model = SiteSetting.ai_helper_model)
headers = {
"Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
"Content-Type" => "application/json",
}
connection_opts = { request: { write_timeout: 60, read_timeout: 60, open_timeout: 60 } }
response =
Faraday.new(nil, connection_opts).post(
"https://api.openai.com/v1/chat/completions",
{ model: model, messages: messages }.to_json,
headers,
def self.perform!(
messages,
model = SiteSetting.ai_helper_model,
temperature: nil,
top_p: nil,
max_tokens: nil,
stream: false,
&blk
)
raise ArgumentError, "block must be supplied in streaming mode" if stream && !blk
if response.status != 200
url = URI("https://api.openai.com/v1/chat/completions")
headers = {
"Content-Type": "application/json",
Authorization: "Bearer #{SiteSetting.ai_openai_api_key}",
}
payload = { model: model, messages: messages }
payload[:temperature] = temperature if temperature
payload[:top_p] = top_p if top_p
payload[:max_tokens] = max_tokens if max_tokens
payload[:stream] = true if stream
Net::HTTP.start(
url.host,
url.port,
use_ssl: true,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
request = Net::HTTP::Post.new(url, headers)
request.body = payload.to_json
response = http.request(request)
if response.code.to_i != 200
Rails.logger.error(
"OpenAiCompletions: status: #{response.status} - body: #{response.body}",
"OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
)
raise CompletionFailed
end
JSON.parse(response.body, symbolize_names: true)
if stream
stream(http, response, &blk)
else
JSON.parse(response.read_body, symbolize_names: true)
end
end
end
def self.stream(http, response)
cancelled = false
cancel = lambda { cancelled = true }
response.read_body do |chunk|
if cancelled
http.finish
break
end
chunk
.split("\n")
.each do |line|
data = line.split("data: ", 2)[1]
next if data == "[DONE]"
yield JSON.parse(data, symbolize_names: true), cancel if data
if cancelled
http.finish
break
end
end
end
rescue IOError
raise if !cancelled
end
end
end

View File

@@ -0,0 +1,92 @@
# frozen_string_literal: true
require "rails_helper"
# Exercises the Net::HTTP-based OpenAI completions client, stubbing the
# chat/completions endpoint with WebMock (stub_request).
describe DiscourseAi::Inference::OpenAiCompletions do
before { SiteSetting.ai_openai_api_key = "abc-123" }
# Non-streaming happy path: optional tuning params are serialized into the
# request body and the parsed response hash is returned with symbolized keys.
it "can complete a trivial prompt" do
body = <<~JSON
{"id":"chatcmpl-74OT0yKnvbmTkqyBINbHgAW0fpbxc","object":"chat.completion","created":1681281718,"model":"gpt-3.5-turbo-0301","usage":{"prompt_tokens":12,"completion_tokens":13,"total_tokens":25},"choices":[{"message":{"role":"assistant","content":"1. Serenity\\n2. Laughter\\n3. Adventure"},"finish_reason":"stop","index":0}]}
JSON
# The stub matches on the exact serialized body, pinning param order and
# the omission of :stream when not requested.
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
body:
"{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"write 3 words\"}],\"temperature\":0.5,\"top_p\":0.8,\"max_tokens\":700}",
headers: {
"Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
"Content-Type" => "application/json",
},
).to_return(status: 200, body: body, headers: {})
prompt = [role: "user", content: "write 3 words"]
completions =
DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt,
"gpt-3.5-turbo",
temperature: 0.5,
top_p: 0.8,
max_tokens: 700,
)
expect(completions[:choices][0][:message][:content]).to eq(
"1. Serenity\n2. Laughter\n3. Adventure",
)
end
# Guard clause: streaming mode is meaningless without a consumer block.
it "raises an error if attempting to stream without a block" do
expect do
DiscourseAi::Inference::OpenAiCompletions.perform!([], "gpt-3.5-turbo", stream: true)
end.to raise_error(ArgumentError)
end
# Builds one SSE line ("data: {...}") mimicking an OpenAI streaming chunk.
# NOTE(review): real OpenAI chunks nest finish_reason/index inside choices[0];
# here they sit at the top level — harmless since the client only reads
# choices[0][:delta], but confirm if fixtures are ever reused elsewhere.
def stream_line(finish_reason: nil, delta: {})
+"data: " << {
id: "chatcmpl-#{SecureRandom.hex}",
object: "chat.completion.chunk",
created: 1_681_283_881,
model: "gpt-3.5-turbo-0301",
choices: [{ delta: delta }],
finish_reason: finish_reason,
index: 0,
}.to_json
end
# Streaming path: chunks are yielded as they arrive and the cancel lambda
# stops consumption mid-stream (after the second word below), so the final
# "Frog" delta and the stop/[DONE] records are never appended.
it "can operate in streaming mode" do
payload = [
stream_line(delta: { role: "assistant" }),
stream_line(delta: { content: "Mount" }),
stream_line(delta: { content: "ain" }),
stream_line(delta: { content: " " }),
stream_line(delta: { content: "Tree " }),
stream_line(delta: { content: "Frog" }),
stream_line(finish_reason: "stop"),
"[DONE]",
].join("\n\n")
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
body:
"{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"write 3 words\"}],\"stream\":true}",
headers: {
"Accept" => "*/*",
"Authorization" => "Bearer abc-123",
"Content-Type" => "application/json",
"Host" => "api.openai.com",
},
).to_return(status: 200, body: payload, headers: {})
prompt = [role: "user", content: "write 3 words"]
content = +""
DiscourseAi::Inference::OpenAiCompletions.perform!(
prompt,
"gpt-3.5-turbo",
stream: true,
) do |partial, cancel|
data = partial[:choices][0].dig(:delta, :content)
content << data if data
cancel.call if content.split(" ").length == 2
end
expect(content).to eq("Mountain Tree ")
end
end

View File

@@ -94,7 +94,6 @@ class OpenAiCompletionsInferenceStubs
end
def stub_prompt(type)
prompt_builder = DiscourseAi::AiHelper::LlmPrompt.new
text = type == TRANSLATE ? spanish_text : translated_response
prompt_messages = CompletionPrompt.find_by(name: type).messages_with_user_input(text)