FEATURE: add a table to audit OpenAI usage (#45)

Still need to build a job to purge logs
Sam 2023-04-26 11:44:29 +10:00 committed by GitHub
parent f6c30e8df9
commit 2cd60a4b3b
4 changed files with 96 additions and 5 deletions
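
The commit note above flags a follow-up job to purge old logs. As a rough sketch only (hypothetical class name and a placeholder 90-day retention, assuming Discourse's ::Jobs::Scheduled base class; none of this is part of this commit), such a job could look like:

# Hypothetical follow-up, not included in this commit.
module Jobs
  class PurgeAiApiAuditLogs < ::Jobs::Scheduled
    every 1.day

    def execute(_args)
      # created_at comes from t.timestamps in the migration below;
      # the 90-day retention window is an arbitrary placeholder.
      AiApiAuditLog.where("created_at < ?", 90.days.ago).delete_all
    end
  end
end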


@@ -0,0 +1,7 @@
# frozen_string_literal: true

class AiApiAuditLog < ActiveRecord::Base
  module Provider
    OpenAI = 1
  end
end


@@ -0,0 +1,15 @@
# frozen_string_literal: true

class CreateAiApiAuditLogs < ActiveRecord::Migration[7.0]
  def change
    create_table :ai_api_audit_logs do |t|
      t.integer :provider_id, null: false
      t.integer :user_id
      t.integer :request_tokens
      t.integer :response_tokens
      t.string :raw_request_payload
      t.string :raw_response_payload
      t.timestamps
    end
  end
end
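
With the table in place, recorded usage can be aggregated later. An illustrative query (not part of this commit, assuming only the columns above): total tokens per user for the OpenAI provider.

# Illustrative query against the new table; not part of this commit.
usage_per_user =
  AiApiAuditLog
    .where(provider_id: AiApiAuditLog::Provider::OpenAI)
    .group(:user_id)
    .sum("COALESCE(request_tokens, 0) + COALESCE(response_tokens, 0)")
# => e.g. { 183 => 25 }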


@@ -14,6 +14,7 @@ module ::DiscourseAi
        top_p: nil,
        max_tokens: nil,
        stream: false,
        user_id: nil,
        &blk
      )
        raise ArgumentError, "block must be supplied in streaming mode" if stream && !blk
@@ -39,7 +40,8 @@ module ::DiscourseAi
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url, headers)
          request.body = payload.to_json
          request_body = payload.to_json
          request.body = request_body
          response = http.request(request)
@@ -50,24 +52,44 @@ module ::DiscourseAi
            raise CompletionFailed
          end
          log =
            AiApiAuditLog.create!(
              provider_id: AiApiAuditLog::Provider::OpenAI,
              raw_request_payload: request_body,
              user_id: user_id,
            )
          if stream
            stream(http, response, &blk)
            stream(http, response, messages, log, &blk)
          else
            JSON.parse(response.read_body, symbolize_names: true)
            response_body = response.body
            parsed = JSON.parse(response_body, symbolize_names: true)
            log.update!(
              raw_response_payload: response_body,
              request_tokens: parsed.dig(:usage, :prompt_tokens),
              response_tokens: parsed.dig(:usage, :completion_tokens),
            )
            parsed
          end
        end
      end
      def self.stream(http, response)
      def self.stream(http, response, messages, log)
        cancelled = false
        cancel = lambda { cancelled = true }
        response_data = +""
        response_raw = +""
        response.read_body do |chunk|
          if cancelled
            http.finish
            break
          end
          response_raw << chunk
          chunk
            .split("\n")
            .each do |line|
@@ -75,7 +97,15 @@ module ::DiscourseAi
              next if data == "[DONE]"
              yield JSON.parse(data, symbolize_names: true), cancel if data
              if data
                json = JSON.parse(data, symbolize_names: true)
                choices = json[:choices]
                if choices && choices[0]
                  delta = choices[0].dig(:delta, :content)
                  response_data << delta if delta
                end
                yield json, cancel
              end
              if cancelled
                http.finish
@@ -85,6 +115,16 @@ module ::DiscourseAi
        end
      rescue IOError
        raise if !cancelled
      ensure
        log.update!(
          raw_response_payload: response_raw,
          request_tokens: DiscourseAi::Tokenizer.size(extract_prompt(messages)),
          response_tokens: DiscourseAi::Tokenizer.size(response_data),
        )
      end

      def self.extract_prompt(messages)
        messages.map { |message| message[:content] || message["content"] || "" }.join("\n")
      end
    end
  end
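
Callers can now attribute a request to a user via the new user_id keyword: the audit row is written with the raw request up front, then back-filled with the raw response and token counts once the call finishes (taken from the API's usage block for regular calls, or estimated with DiscourseAi::Tokenizer when streaming). An illustrative call site, not taken from this commit (current_user is a placeholder; the argument shapes mirror the spec below):

# Illustrative only; mirrors the spec below rather than documenting a final API.
DiscourseAi::Inference::OpenAiCompletions.perform!(
  [{ role: "user", content: "write 3 words" }],
  "gpt-3.5-turbo",
  user_id: current_user&.id, # placeholder for whichever user triggered the call
)

log = AiApiAuditLog.last
log.request_tokens  # usage.prompt_tokens, or a tokenizer estimate of the prompt when streaming
log.response_tokens # usage.completion_tokens, or a tokenizer estimate of the streamed content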


@@ -18,6 +18,8 @@ describe DiscourseAi::Inference::OpenAiCompletions do
      },
    ).to_return(status: 200, body: body, headers: {})
    user_id = 183
    prompt = [role: "user", content: "write 3 words"]
    completions =
      DiscourseAi::Inference::OpenAiCompletions.perform!(
@@ -26,10 +28,24 @@ describe DiscourseAi::Inference::OpenAiCompletions do
        temperature: 0.5,
        top_p: 0.8,
        max_tokens: 700,
        user_id: user_id,
      )
    expect(completions[:choices][0][:message][:content]).to eq(
      "1. Serenity\n2. Laughter\n3. Adventure",
    )
    expect(AiApiAuditLog.count).to eq(1)
    log = AiApiAuditLog.first
    request_body = (<<~JSON).strip
      {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"write 3 words"}],"temperature":0.5,"top_p":0.8,"max_tokens":700}
    JSON
    expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI)
    expect(log.request_tokens).to eq(12)
    expect(log.response_tokens).to eq(13)
    expect(log.raw_request_payload).to eq(request_body)
    expect(log.raw_response_payload).to eq(body)
  end

  it "raises an error if attempting to stream without a block" do
@@ -88,5 +104,18 @@ describe DiscourseAi::Inference::OpenAiCompletions do
    end
    expect(content).to eq("Mountain Tree ")
    expect(AiApiAuditLog.count).to eq(1)
    log = AiApiAuditLog.first
    request_body = (<<~JSON).strip
      {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"write 3 words"}],"stream":true}
    JSON
    expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI)
    expect(log.request_tokens).to eq(5)
    expect(log.response_tokens).to eq(4)
    expect(log.raw_request_payload).to eq(request_body)
    expect(log.raw_response_payload).to eq(payload)
  end
end